fix(cache): pass textOnly to getSessionsConfig so is_pipeline_cached skips vision encoder #1608
s-zx wants to merge 2 commits into huggingface:main
…skips vision encoder

is_pipeline_cached incorrectly returns false for text-generation pipelines on models like gemma-3-4b-it-ONNX because get_model_files checks for vision_encoder files that text-generation never downloads.

Root cause: getSessionsConfig only forwarded two arguments to the sessions factory, so the textOnly parameter was always undefined. The ImageTextToText and ImageAudioTextToText factories include vision_encoder/audio_encoder when textOnly is falsy.

- Add textOnly parameter to getSessionsConfig and forward it to the sessions factory
- Export resolveTypeConfig for reuse
- Compute textOnly in get_model_files using the same cross-architecture detection as from_pretrained (ForCausalLM loading a ForConditionalGeneration model)

Closes huggingface#1606
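The forwarding bug described in this commit message can be illustrated with a minimal, self-contained sketch. Everything below is a simplified stand-in rather than the actual transformers.js code: the factory function, the `embed_tokens` and `decoder_model_merged` session names, and the `Before`/`After` wrapper names are assumptions made for illustration, and the real `getSessionsConfig` takes a `modelType` rather than a factory.

```javascript
// Hypothetical stand-in for an ImageTextToText-style sessions factory.
// The encoder session is included whenever `textOnly` is falsy.
function imageTextToTextSessions(config, options, textOnly) {
  const sessions = ['embed_tokens', 'decoder_model_merged']; // assumed names
  if (!textOnly) sessions.push('vision_encoder');
  return sessions;
}

// Before the fix: only two arguments reach the factory, so `textOnly`
// is always undefined inside it.
function getSessionsConfigBefore(factory, config, options = {}) {
  return factory(config, options);
}

// After the fix: `textOnly` is accepted and forwarded as a third argument.
function getSessionsConfigAfter(factory, config, options = {}, textOnly = false) {
  return factory(config, options, textOnly);
}

const before = getSessionsConfigBefore(imageTextToTextSessions, {});
const after = getSessionsConfigAfter(imageTextToTextSessions, {}, {}, true);

console.log(before.includes('vision_encoder')); // true
console.log(after.includes('vision_encoder')); // false
```

With the two-argument wrapper, no caller can opt out of the encoder sessions, which is why the cache check kept expecting files that a text-generation load never fetched.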
Thanks for the PR 👍 Cross-architecture loading detection is now duplicated in two places. Can you make sure that the model registry logic uses the same helper functions as defined in modeling utils?
Per review feedback, move the cross-architecture detection into a shared isTextOnlyConfig() helper in modeling_utils.js instead of duplicating the logic in get_model_files.js.
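A shared helper of this kind might look roughly like the sketch below. This is an assumption-laden illustration, not the code from the PR: the parameter names, the use of `config.architectures[0]` as the native architecture, and the `ForCausalLM` check on the requested class are all guesses based on the description "ForCausalLM loading a ForConditionalGeneration model". The architecture strings in the usage lines are examples only.

```javascript
// Sketch only: the real helper lives in modeling_utils.js and may differ.
// `requestedArch` is the class name used to load the model (assumed input);
// `config.architectures` is assumed to list the native architecture first.
function isTextOnlyConfig(requestedArch, config) {
  const nativeArch = config?.architectures?.[0] ?? '';
  return (
    requestedArch.endsWith('ForCausalLM') &&
    nativeArch.endsWith('ForConditionalGeneration')
  );
}

// Example configs (illustrative architecture names).
const multimodalConfig = { architectures: ['Gemma3ForConditionalGeneration'] };
const textConfig = { architectures: ['LlamaForCausalLM'] };

// Cross-architecture load: a causal LM class on a multimodal checkpoint.
const crossArch = isTextOnlyConfig('Gemma3ForCausalLM', multimodalConfig);
// Same-architecture load: nothing to skip.
const sameArch = isTextOnlyConfig('LlamaForCausalLM', textConfig);

console.log(crossArch, sameArch);
```

Keeping the check in one function means the loader and the cache check cannot drift apart, which is the concern raised in the review.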
Good point! I've extracted the cross-architecture detection into a shared …
    return nativeArch.endsWith('ForConditionalGeneration');
}

export function getSessionsConfig(modelType, config, options = {}, textOnly = false) {
you also messed up the location of the isTextOnlyConfig function (it is in between getSessionsConfig and its jsdoc)
this also duplicates logic between this and resolveTypeConfig
// Use the shared helper to detect cross-architecture loading (e.g.
// ForCausalLM loading a ForConditionalGeneration model). In text-only
// mode the sessions factory skips vision/audio encoder files.
no need for these comments :)
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
 * @returns {{ typeConfig: Object, textOnly: boolean, modelType: number|undefined }}
 */
-function resolveTypeConfig(modelName, config) {
+export function resolveTypeConfig(modelName, config) {
Not used elsewhere? It was probably meant to be exported and used above.
…eration task inspired by #1608 Co-Authored-By: zxshen <zshen339@gatech.edu>
Inspired by this PR, I opened #1614, which is a more robust fix for this problem. I added you as a co-author for the inspiration, so I'll close this PR.
…ration pipeline (#1614)

* Add unit test for text-generation on multimodal model
* add more multimodal text-generation unit tests
* Exclude certain sessions when loading multimodal models with text-generation task inspired by #1608

Co-Authored-By: zxshen <zshen339@gatech.edu>

* simplify multimodal text-generation pipeline logic
* invert logic to keep non-model files
* cleanup

---------

Co-authored-by: zxshen <zshen339@gatech.edu>
Summary
is_pipeline_cached() returns false after successfully loading a text-generation pipeline for models like onnx-community/gemma-3-4b-it-ONNX because it checks for vision_encoder ONNX files that were never downloaded.

Root Cause
getSessionsConfig() forwarded only 2 arguments to the sessions factory, so the textOnly parameter was always undefined. The ImageTextToText/ImageAudioTextToText session factories include vision_encoder/audio_encoder when textOnly is falsy, but from_pretrained() correctly detects cross-architecture loading and sets textOnly = true to skip those files.

| | textOnly computed? | vision_encoder included for text-gen? |
| --- | --- | --- |
| from_pretrained (actual load) | Yes | No |
| get_model_files (cache check) | No | Yes |

Fix
- Add textOnly parameter to getSessionsConfig() and forward it to the sessions factory
- In get_model_files(), detect cross-architecture loading (same logic as resolveTypeConfig) and pass textOnly = true when appropriate
- Export resolveTypeConfig for reuse

Test plan
- is_pipeline_cached("text-generation", "onnx-community/gemma-3-4b-it-ONNX", { device: "webgpu", dtype: "q4f16" }) now returns true after the pipeline has been loaded

Closes #1606
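To make the symptom concrete, the sketch below models a cache check as "every expected file is present". It is a toy model under stated assumptions, not the library's implementation: `isCached` is a hypothetical function, and the file names are illustrative rather than the repository's actual layout.

```javascript
// Hypothetical cache check: a pipeline counts as cached only if every
// expected file is present among the cached files.
function isCached(expectedFiles, cachedFiles) {
  const cached = new Set(cachedFiles);
  return expectedFiles.every((file) => cached.has(file));
}

// Files a text-generation load actually downloaded (illustrative names).
const downloaded = ['config.json', 'onnx/embed_tokens.onnx', 'onnx/decoder.onnx'];

// Before the fix the expected list also contained the vision encoder;
// after the fix it matches what the text-only load fetched.
const expectedBefore = [...downloaded, 'onnx/vision_encoder.onnx'];
const expectedAfter = [...downloaded];

const cachedBefore = isCached(expectedBefore, downloaded);
const cachedAfter = isCached(expectedAfter, downloaded);

console.log(cachedBefore); // false
console.log(cachedAfter); // true
```

A single unexpected entry in the expected list is enough to flip the result, which matches the reported behaviour of a false negative after a successful load.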