-
-
Notifications
You must be signed in to change notification settings - Fork 5.3k
[studio] Fix VLM detection for transformers v5 #4868
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -481,8 +481,11 @@ def load_model_config( | |||||
| ) | ||||||
|
|
||||||
|
|
||||||
| # VLM architecture suffixes and known VLM model_type values. | ||||||
| _VLM_ARCH_SUFFIXES = ("ForConditionalGeneration", "ForVisionText2Text") | ||||||
| # Known VLM model_type values. Used as a fallback when explicit vision signals | ||||||
| # (vision_config, img_processor, image_token_index, image_token_id) are not | ||||||
| # present in the config. In practice, all these models DO have explicit vision | ||||||
| # signals in their real configs; the model_type list provides a safety net for | ||||||
| # partial/mock configs. | ||||||
| _VLM_MODEL_TYPES = { | ||||||
| "phi3_v", | ||||||
| "llava", | ||||||
|
|
@@ -491,7 +494,29 @@ def load_model_config( | |||||
| "internvl_chat", | ||||||
| "cogvlm2", | ||||||
| "minicpmv", | ||||||
| "gemma3", | ||||||
| "gemma3n", | ||||||
| "gemma4", | ||||||
| "qwen2_vl", | ||||||
| "qwen2_5_vl", | ||||||
| "qwen3_5", | ||||||
| "qwen3_vl", | ||||||
| "qwen3_vl_moe", | ||||||
| "paligemma", | ||||||
| "pix2struct", | ||||||
| "video_llava", | ||||||
| "blip-2", | ||||||
| "blip_2", | ||||||
| "idefics2", | ||||||
| "idefics3", | ||||||
| "mllama", | ||||||
| "chameleon", | ||||||
| "xgenmm", | ||||||
| "smolvlm", | ||||||
| "molmo", | ||||||
| "fuyu", | ||||||
| } | ||||||
| _AUDIO_ONLY_MODEL_TYPES = {"csm", "whisper"} | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The list of excluded model types should be expanded. The ForConditionalGeneration architecture suffix is used by many non-vision Seq2Seq models (such as T5, BART, Marian, etc.), which leads to false positives in VLM detection. Renaming this to a more general _NON_VLM_MODEL_TYPES and including common Seq2Seq families is recommended to improve detection accuracy.
Suggested change
|
||||||
|
|
||||||
| # Pre-computed .venv_t5 path and backend dir for subprocess version switching. | ||||||
| _VENV_T5_DIR = str(Path.home() / ".unsloth" / "studio" / ".venv_t5") | ||||||
|
|
@@ -520,25 +545,35 @@ def load_model_config( | |||||
| kwargs["token"] = token | ||||||
| config = AutoConfig.from_pretrained(model_name, **kwargs) | ||||||
|
|
||||||
| model_type = getattr(config, "model_type", None) | ||||||
| audio_only_types = {"csm", "whisper"} | ||||||
|
|
||||||
| is_vlm = False | ||||||
| if hasattr(config, "architectures"): | ||||||
| is_vlm = any( | ||||||
| x.endswith(("ForConditionalGeneration", "ForVisionText2Text")) | ||||||
| for x in config.architectures | ||||||
| ) | ||||||
| if not is_vlm and hasattr(config, "vision_config"): | ||||||
| is_vlm = True | ||||||
| if not is_vlm and hasattr(config, "img_processor"): | ||||||
| is_vlm = True | ||||||
| if not is_vlm and hasattr(config, "image_token_index"): | ||||||
| is_vlm = True | ||||||
| if not is_vlm and hasattr(config, "model_type"): | ||||||
| vlm_types = {"phi3_v","llava","llava_next","llava_onevision", | ||||||
| "internvl_chat","cogvlm2","minicpmv"} | ||||||
| if config.model_type in vlm_types: | ||||||
| if model_type not in audio_only_types: | ||||||
| if getattr(config, "vision_config", None) is not None: | ||||||
| is_vlm = True | ||||||
| if not is_vlm and getattr(config, "img_processor", None) is not None: | ||||||
| is_vlm = True | ||||||
| if not is_vlm and getattr(config, "image_token_index", None) is not None: | ||||||
| is_vlm = True | ||||||
| if not is_vlm and getattr(config, "image_token_id", None) is not None: | ||||||
| is_vlm = True | ||||||
| if not is_vlm and getattr(config, "architectures", None): | ||||||
| is_vlm = any( | ||||||
| x.endswith("ForVisionText2Text") | ||||||
| for x in config.architectures | ||||||
| ) | ||||||
| if not is_vlm and model_type in { | ||||||
| "phi3_v", "llava", "llava_next", "llava_onevision", | ||||||
| "internvl_chat", "cogvlm2", "minicpmv", "gemma3", "gemma3n", | ||||||
| "gemma4", "qwen2_vl", "qwen2_5_vl", "qwen3_5", "qwen3_vl", | ||||||
| "qwen3_vl_moe", "paligemma", "pix2struct", "video_llava", | ||||||
| "blip-2", "blip_2", "idefics2", "idefics3", "mllama", | ||||||
| "chameleon", "xgenmm", "smolvlm", "molmo", "fuyu", | ||||||
| }: | ||||||
| is_vlm = True | ||||||
|
|
||||||
| model_type = getattr(config, "model_type", "unknown") | ||||||
| model_type = model_type or "unknown" | ||||||
| archs = getattr(config, "architectures", []) | ||||||
| print(json.dumps({"is_vision": is_vlm, "model_type": model_type, | ||||||
| "architectures": archs})) | ||||||
|
|
@@ -617,6 +652,111 @@ def _is_vision_model_subprocess( | |||||
| return None | ||||||
|
|
||||||
|
|
||||||
| def _is_vlm_config(config: Any) -> bool: | ||||||
| if isinstance(config, dict): | ||||||
| model_type = config.get("model_type") | ||||||
| architectures = config.get("architectures") | ||||||
| vision_config = config.get("vision_config") | ||||||
| img_processor = config.get("img_processor") | ||||||
| image_token_index = config.get("image_token_index") | ||||||
| image_token_id = config.get("image_token_id") | ||||||
| else: | ||||||
| model_type = getattr(config, "model_type", None) | ||||||
| architectures = getattr(config, "architectures", None) | ||||||
| vision_config = getattr(config, "vision_config", None) | ||||||
| img_processor = getattr(config, "img_processor", None) | ||||||
| image_token_index = getattr(config, "image_token_index", None) | ||||||
| image_token_id = getattr(config, "image_token_id", None) | ||||||
|
|
||||||
| if model_type in _AUDIO_ONLY_MODEL_TYPES: | ||||||
| return False | ||||||
|
Comment on lines
+671
to
672
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||
|
|
||||||
| # Explicit vision signals are definitive (must be non-None to count) | ||||||
| if any( | ||||||
| v is not None | ||||||
| for v in (vision_config, img_processor, image_token_index, image_token_id) | ||||||
| ): | ||||||
| return True | ||||||
|
|
||||||
| # ForVisionText2Text is a VLM-specific architecture suffix | ||||||
| if architectures and any(x.endswith("ForVisionText2Text") for x in architectures): | ||||||
| return True | ||||||
|
Comment on lines
+682
to
+683
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
In the new raw- Useful? React with 👍 / 👎. |
||||||
|
|
||||||
| return model_type in _VLM_MODEL_TYPES | ||||||
|
|
||||||
|
|
||||||
| def _load_model_config_metadata( | ||||||
| model_name: str, hf_token: Optional[str] = None | ||||||
| ) -> Tuple[Optional[Dict[str, Any]], Optional[Exception]]: | ||||||
| """Load raw config.json as a plain dict, bypassing AutoConfig. | ||||||
|
|
||||||
| Returns (config_dict, None) on success, or (None, exception) on failure. | ||||||
| The exception is returned so callers can classify permanent vs transient. | ||||||
| """ | ||||||
| try: | ||||||
| if is_local_path(model_name): | ||||||
| config_path = Path(normalize_path(model_name)) / "config.json" | ||||||
| if config_path.is_file(): | ||||||
| return json.loads(config_path.read_text()), None | ||||||
| return None, None | ||||||
|
|
||||||
| from huggingface_hub import hf_hub_download | ||||||
|
|
||||||
| try: | ||||||
| resolved_name = resolve_cached_repo_id_case(model_name) | ||||||
| except Exception: | ||||||
| resolved_name = model_name | ||||||
|
|
||||||
| download_kwargs: Dict[str, Any] = {} | ||||||
| if hf_token: | ||||||
| download_kwargs["token"] = hf_token | ||||||
|
|
||||||
| config_path = hf_hub_download( | ||||||
| repo_id = resolved_name, | ||||||
| filename = "config.json", | ||||||
| **download_kwargs, | ||||||
| ) | ||||||
| return json.loads(Path(config_path).read_text()), None | ||||||
| except Exception as exc: | ||||||
| logger.warning("Could not load raw config metadata for %s: %s", model_name, exc) | ||||||
| return None, exc | ||||||
|
|
||||||
|
|
||||||
| def _classify_detection_error(exc: Exception) -> Optional[bool]: | ||||||
| """Classify a detection exception as permanent (False) or transient (None). | ||||||
|
|
||||||
| Permanent failures (model not found, gated, bad config) are safe to cache | ||||||
| as False. Transient failures (network, timeout) should not be cached so | ||||||
| they can be retried on the next call. | ||||||
| """ | ||||||
| try: | ||||||
| from huggingface_hub.errors import RepositoryNotFoundError, GatedRepoError | ||||||
| except ImportError: | ||||||
| try: | ||||||
| from huggingface_hub.utils import ( | ||||||
| RepositoryNotFoundError, | ||||||
| GatedRepoError, | ||||||
| ) | ||||||
| except ImportError: | ||||||
| RepositoryNotFoundError = GatedRepoError = None | ||||||
| # EntryNotFoundError means config.json doesn't exist in the repo -- permanent | ||||||
| try: | ||||||
| from huggingface_hub.errors import EntryNotFoundError, RevisionNotFoundError | ||||||
| except ImportError: | ||||||
| try: | ||||||
| from huggingface_hub.utils import EntryNotFoundError, RevisionNotFoundError | ||||||
| except ImportError: | ||||||
| EntryNotFoundError = RevisionNotFoundError = None | ||||||
| permanent_types = [ValueError, json.JSONDecodeError] | ||||||
| if RepositoryNotFoundError is not None: | ||||||
| permanent_types.extend([RepositoryNotFoundError, GatedRepoError]) | ||||||
| if EntryNotFoundError is not None: | ||||||
| permanent_types.extend([EntryNotFoundError, RevisionNotFoundError]) | ||||||
| if isinstance(exc, tuple(permanent_types)): | ||||||
| return False | ||||||
| return None | ||||||
|
|
||||||
|
|
||||||
| def _token_fingerprint(token: Optional[str]) -> Optional[str]: | ||||||
| """Return a SHA256 digest of the token for use as a cache key. | ||||||
|
|
||||||
|
|
@@ -711,73 +851,38 @@ def _is_vision_model_uncached( | |||||
| "Model '%s' needs transformers 5.x -- checking vision via subprocess", | ||||||
| model_name, | ||||||
| ) | ||||||
| return _is_vision_model_subprocess(model_name, hf_token = hf_token) | ||||||
|
|
||||||
| try: | ||||||
| config = load_model_config(model_name, use_auth = True, token = hf_token) | ||||||
|
|
||||||
| # Exclude audio-only models that share ForConditionalGeneration suffix | ||||||
| # (e.g. CsmForConditionalGeneration, WhisperForConditionalGeneration) | ||||||
| _audio_only_model_types = {"csm", "whisper"} | ||||||
| model_type = getattr(config, "model_type", None) | ||||||
| if model_type in _audio_only_model_types: | ||||||
| return False | ||||||
| subprocess_result = _is_vision_model_subprocess(model_name, hf_token = hf_token) | ||||||
| if subprocess_result is not None: | ||||||
| return subprocess_result | ||||||
|
|
||||||
| # Check 1: Architecture class name patterns | ||||||
| if hasattr(config, "architectures"): | ||||||
| is_vlm = any(x.endswith(_VLM_ARCH_SUFFIXES) for x in config.architectures) | ||||||
| # Subprocess failed (transient) -- fall back to raw config.json metadata | ||||||
| config_data, metadata_error = _load_model_config_metadata( | ||||||
| model_name, hf_token = hf_token | ||||||
| ) | ||||||
| if config_data is not None: | ||||||
| is_vlm = _is_vlm_config(config_data) | ||||||
| if is_vlm: | ||||||
| logger.info( | ||||||
| f"Model {model_name} detected as VLM: architecture {config.architectures}" | ||||||
| "Model %s detected as VLM from raw config metadata: " | ||||||
| "model_type=%s architectures=%s", | ||||||
| model_name, | ||||||
| config_data.get("model_type"), | ||||||
| config_data.get("architectures", []), | ||||||
| ) | ||||||
| return True | ||||||
|
|
||||||
| # Check 2: Has vision_config (most VLMs: LLaVA, Gemma-3, Qwen2-VL, etc.) | ||||||
| if hasattr(config, "vision_config"): | ||||||
| logger.info(f"Model {model_name} detected as VLM: has vision_config") | ||||||
| return True | ||||||
|
|
||||||
| # Check 3: Has img_processor (Phi-3.5 Vision uses this instead of vision_config) | ||||||
| if hasattr(config, "img_processor"): | ||||||
| logger.info(f"Model {model_name} detected as VLM: has img_processor") | ||||||
| return True | ||||||
|
|
||||||
| # Check 4: Has image_token_index (common in VLMs for image placeholder tokens) | ||||||
| if hasattr(config, "image_token_index"): | ||||||
| logger.info(f"Model {model_name} detected as VLM: has image_token_index") | ||||||
| return True | ||||||
|
|
||||||
| # Check 5: Known VLM model_type values that may not match above checks | ||||||
| if hasattr(config, "model_type"): | ||||||
| if config.model_type in _VLM_MODEL_TYPES: | ||||||
| logger.info( | ||||||
| f"Model {model_name} detected as VLM: model_type={config.model_type}" | ||||||
| ) | ||||||
| return True | ||||||
| return is_vlm | ||||||
|
|
||||||
| return False | ||||||
| # Both subprocess and raw config failed -- classify the error. | ||||||
| # Permanent failures should be cached as False; transient as None. | ||||||
| if metadata_error is not None: | ||||||
| return _classify_detection_error(metadata_error) | ||||||
| return None | ||||||
|
|
||||||
| try: | ||||||
| config = load_model_config(model_name, use_auth = True, token = hf_token) | ||||||
| return _is_vlm_config(config) | ||||||
| except Exception as e: | ||||||
| logger.warning(f"Could not determine if {model_name} is vision model: {e}") | ||||||
| # Permanent failures (model not found, gated, bad config) should be | ||||||
| # cached as False. Transient failures (network, timeout) should not. | ||||||
| try: | ||||||
| from huggingface_hub.errors import RepositoryNotFoundError, GatedRepoError | ||||||
| except ImportError: | ||||||
| try: | ||||||
| from huggingface_hub.utils import ( | ||||||
| RepositoryNotFoundError, | ||||||
| GatedRepoError, | ||||||
| ) | ||||||
| except ImportError: | ||||||
| RepositoryNotFoundError = GatedRepoError = None | ||||||
| if RepositoryNotFoundError is not None and isinstance( | ||||||
| e, (RepositoryNotFoundError, GatedRepoError) | ||||||
| ): | ||||||
| return False | ||||||
| if isinstance(e, (ValueError, json.JSONDecodeError)): | ||||||
| return False | ||||||
| return None | ||||||
| return _classify_detection_error(e) | ||||||
|
|
||||||
|
|
||||||
| VALID_AUDIO_TYPES = ("snac", "csm", "bicodec", "dac", "whisper", "audio_vlm") | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test name is ambiguous because 't5' typically refers to the T5 model family, whereas here it refers to models requiring Transformers v5. Renaming it to test_is_vision_model_falls_back_to_raw_metadata_for_v5_models would clarify that the test is about version-based fallbacks rather than the T5 architecture.