Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions studio/backend/tests/test_transformers_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
_tokenizer_class_cache,
needs_transformers_5,
)
from utils.models import model_config


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -188,3 +189,53 @@ def test_local_checkpoint_resolved_via_config(self, tmp_path: Path):
# We test the full resolution chain here:
resolved = _resolve_base_model(str(tmp_path))
assert needs_transformers_5(resolved) is True


class TestVisionModelDetection:
    """Unit tests for VLM (vision-language model) detection in model_config."""

    def test_is_vlm_config_accepts_transformers_objects(self):
        # AutoConfig-style objects expose attributes rather than dict keys.
        config = _types.SimpleNamespace(
            model_type = "qwen3_5",
            architectures = ["Qwen3_5ForConditionalGeneration"],
            vision_config = {},
        )
        assert model_config._is_vlm_config(config) is True

    def test_is_vlm_config_accepts_raw_config_dicts(self):
        # Raw config.json payloads are plain dicts (the fallback path).
        config = {
            "model_type": "gemma4",
            "architectures": ["Gemma4ForConditionalGeneration"],
            "vision_config": {},
        }
        assert model_config._is_vlm_config(config) is True

    def test_is_vlm_config_rejects_audio_only_conditional_generation(self):
        # Whisper shares the ForConditionalGeneration suffix but is audio-only.
        config = {
            "model_type": "whisper",
            "architectures": ["WhisperForConditionalGeneration"],
        }
        assert model_config._is_vlm_config(config) is False

    def test_is_vision_model_falls_back_to_raw_metadata_for_v5_models(self):
        # "v5" here means models requiring transformers 5.x, not the T5
        # architecture. When both the direct config load and the subprocess
        # probe fail, detection must fall back to the raw config.json
        # metadata and still classify the model as a VLM.
        with (
            patch("utils.transformers_version.needs_transformers_5", return_value = True),
            patch(
                "utils.models.model_config.load_model_config",
                side_effect = RuntimeError("direct load failed"),
            ),
            patch(
                "utils.models.model_config._is_vision_model_subprocess",
                return_value = None,
            ),
            patch(
                "utils.models.model_config._load_model_config_metadata",
                return_value = (
                    {
                        "model_type": "qwen3_5",
                        "architectures": ["Qwen3_5ForConditionalGeneration"],
                        "vision_config": {},
                    },
                    None,
                ),
            ),
        ):
            assert model_config.is_vision_model("unsloth/Qwen3.5-4B") is True
259 changes: 182 additions & 77 deletions studio/backend/utils/models/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,8 +481,11 @@ def load_model_config(
)


# VLM architecture suffixes and known VLM model_type values.
_VLM_ARCH_SUFFIXES = ("ForConditionalGeneration", "ForVisionText2Text")
# Known VLM model_type values. Used as a fallback when explicit vision signals
# (vision_config, img_processor, image_token_index, image_token_id) are not
# present in the config. In practice, all these models DO have explicit vision
# signals in their real configs; the model_type list provides a safety net for
# partial/mock configs.
_VLM_MODEL_TYPES = {
"phi3_v",
"llava",
Expand All @@ -491,7 +494,29 @@ def load_model_config(
"internvl_chat",
"cogvlm2",
"minicpmv",
"gemma3",
"gemma3n",
"gemma4",
"qwen2_vl",
"qwen2_5_vl",
"qwen3_5",
"qwen3_vl",
"qwen3_vl_moe",
"paligemma",
"pix2struct",
"video_llava",
"blip-2",
"blip_2",
"idefics2",
"idefics3",
"mllama",
"chameleon",
"xgenmm",
"smolvlm",
"molmo",
"fuyu",
}
_AUDIO_ONLY_MODEL_TYPES = {"csm", "whisper"}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The list of excluded model types should be expanded. The ForConditionalGeneration architecture suffix is used by many non-vision Seq2Seq models (such as T5, BART, Marian, etc.), which leads to false positives in VLM detection. Renaming this to a more general _NON_VLM_MODEL_TYPES and including common Seq2Seq families is recommended to improve detection accuracy.

Suggested change
_AUDIO_ONLY_MODEL_TYPES = {"csm", "whisper"}
_NON_VLM_MODEL_TYPES = {"csm", "whisper", "t5", "bart", "marian", "pegasus", "blenderbot", "m2m_100"}


# Pre-computed .venv_t5 path and backend dir for subprocess version switching.
_VENV_T5_DIR = str(Path.home() / ".unsloth" / "studio" / ".venv_t5")
Expand Down Expand Up @@ -520,25 +545,35 @@ def load_model_config(
kwargs["token"] = token
config = AutoConfig.from_pretrained(model_name, **kwargs)

model_type = getattr(config, "model_type", None)
audio_only_types = {"csm", "whisper"}

is_vlm = False
if hasattr(config, "architectures"):
is_vlm = any(
x.endswith(("ForConditionalGeneration", "ForVisionText2Text"))
for x in config.architectures
)
if not is_vlm and hasattr(config, "vision_config"):
is_vlm = True
if not is_vlm and hasattr(config, "img_processor"):
is_vlm = True
if not is_vlm and hasattr(config, "image_token_index"):
is_vlm = True
if not is_vlm and hasattr(config, "model_type"):
vlm_types = {"phi3_v","llava","llava_next","llava_onevision",
"internvl_chat","cogvlm2","minicpmv"}
if config.model_type in vlm_types:
if model_type not in audio_only_types:
if getattr(config, "vision_config", None) is not None:
is_vlm = True
if not is_vlm and getattr(config, "img_processor", None) is not None:
is_vlm = True
if not is_vlm and getattr(config, "image_token_index", None) is not None:
is_vlm = True
if not is_vlm and getattr(config, "image_token_id", None) is not None:
is_vlm = True
if not is_vlm and getattr(config, "architectures", None):
is_vlm = any(
x.endswith("ForVisionText2Text")
for x in config.architectures
)
if not is_vlm and model_type in {
"phi3_v", "llava", "llava_next", "llava_onevision",
"internvl_chat", "cogvlm2", "minicpmv", "gemma3", "gemma3n",
"gemma4", "qwen2_vl", "qwen2_5_vl", "qwen3_5", "qwen3_vl",
"qwen3_vl_moe", "paligemma", "pix2struct", "video_llava",
"blip-2", "blip_2", "idefics2", "idefics3", "mllama",
"chameleon", "xgenmm", "smolvlm", "molmo", "fuyu",
}:
is_vlm = True

model_type = getattr(config, "model_type", "unknown")
model_type = model_type or "unknown"
archs = getattr(config, "architectures", [])
print(json.dumps({"is_vision": is_vlm, "model_type": model_type,
"architectures": archs}))
Expand Down Expand Up @@ -617,6 +652,111 @@ def _is_vision_model_subprocess(
return None


def _is_vlm_config(config: Any) -> bool:
    """Return True if *config* describes a vision-language model (VLM).

    Accepts either a transformers config object (attribute access) or a raw
    config.json dict (key access). Detection order:

    1. Audio-only model types are excluded outright -- they share the
       ForConditionalGeneration architecture suffix with VLMs.
    2. Explicit vision signals (vision_config, img_processor,
       image_token_index, image_token_id) are definitive.
    3. The ForVisionText2Text architecture suffix is VLM-specific.
    4. Finally, fall back to the known-VLM model_type allowlist.
    """
    if isinstance(config, dict):
        model_type = config.get("model_type")
        architectures = config.get("architectures")
        vision_config = config.get("vision_config")
        img_processor = config.get("img_processor")
        image_token_index = config.get("image_token_index")
        image_token_id = config.get("image_token_id")
    else:
        model_type = getattr(config, "model_type", None)
        architectures = getattr(config, "architectures", None)
        vision_config = getattr(config, "vision_config", None)
        img_processor = getattr(config, "img_processor", None)
        image_token_index = getattr(config, "image_token_index", None)
        image_token_id = getattr(config, "image_token_id", None)

    if model_type in _AUDIO_ONLY_MODEL_TYPES:
        return False

    # Explicit vision signals are definitive (must be non-None to count)
    if any(
        v is not None
        for v in (vision_config, img_processor, image_token_index, image_token_id)
    ):
        return True

    # ForVisionText2Text is a VLM-specific architecture suffix. Partial or
    # hand-written configs may contain non-string entries (e.g. null), so
    # only suffix-match actual strings instead of raising AttributeError,
    # which would otherwise escape the raw-config fallback path uncaught.
    if architectures and any(
        isinstance(x, str) and x.endswith("ForVisionText2Text")
        for x in architectures
    ):
        return True

    return model_type in _VLM_MODEL_TYPES


def _load_model_config_metadata(
    model_name: str, hf_token: Optional[str] = None
) -> Tuple[Optional[Dict[str, Any]], Optional[Exception]]:
    """Load raw config.json as a plain dict, bypassing AutoConfig.

    Returns (config_dict, None) on success, or (None, exception) on failure.
    The exception is returned so callers can classify permanent vs transient.
    """
    try:
        # Local checkpoints: read config.json straight off disk.
        if is_local_path(model_name):
            local_config = Path(normalize_path(model_name)) / "config.json"
            if not local_config.is_file():
                return None, None
            return json.loads(local_config.read_text()), None

        from huggingface_hub import hf_hub_download

        # Prefer the cached repo id's canonical casing when available;
        # fall back to the caller-supplied name on any resolution error.
        try:
            repo_id = resolve_cached_repo_id_case(model_name)
        except Exception:
            repo_id = model_name

        extra_kwargs: Dict[str, Any] = {"token": hf_token} if hf_token else {}
        downloaded = hf_hub_download(
            repo_id = repo_id,
            filename = "config.json",
            **extra_kwargs,
        )
        return json.loads(Path(downloaded).read_text()), None
    except Exception as exc:
        logger.warning("Could not load raw config metadata for %s: %s", model_name, exc)
        return None, exc


def _classify_detection_error(exc: Exception) -> Optional[bool]:
"""Classify a detection exception as permanent (False) or transient (None).

Permanent failures (model not found, gated, bad config) are safe to cache
as False. Transient failures (network, timeout) should not be cached so
they can be retried on the next call.
"""
try:
from huggingface_hub.errors import RepositoryNotFoundError, GatedRepoError
except ImportError:
try:
from huggingface_hub.utils import (
RepositoryNotFoundError,
GatedRepoError,
)
except ImportError:
RepositoryNotFoundError = GatedRepoError = None
# EntryNotFoundError means config.json doesn't exist in the repo -- permanent
try:
from huggingface_hub.errors import EntryNotFoundError, RevisionNotFoundError
except ImportError:
try:
from huggingface_hub.utils import EntryNotFoundError, RevisionNotFoundError
except ImportError:
EntryNotFoundError = RevisionNotFoundError = None
permanent_types = [ValueError, json.JSONDecodeError]
if RepositoryNotFoundError is not None:
permanent_types.extend([RepositoryNotFoundError, GatedRepoError])
if EntryNotFoundError is not None:
permanent_types.extend([EntryNotFoundError, RevisionNotFoundError])
if isinstance(exc, tuple(permanent_types)):
return False
return None


def _token_fingerprint(token: Optional[str]) -> Optional[str]:
"""Return a SHA256 digest of the token for use as a cache key.

Expand Down Expand Up @@ -711,73 +851,38 @@ def _is_vision_model_uncached(
"Model '%s' needs transformers 5.x -- checking vision via subprocess",
model_name,
)
return _is_vision_model_subprocess(model_name, hf_token = hf_token)

try:
config = load_model_config(model_name, use_auth = True, token = hf_token)

# Exclude audio-only models that share ForConditionalGeneration suffix
# (e.g. CsmForConditionalGeneration, WhisperForConditionalGeneration)
_audio_only_model_types = {"csm", "whisper"}
model_type = getattr(config, "model_type", None)
if model_type in _audio_only_model_types:
return False
subprocess_result = _is_vision_model_subprocess(model_name, hf_token = hf_token)
if subprocess_result is not None:
return subprocess_result

# Check 1: Architecture class name patterns
if hasattr(config, "architectures"):
is_vlm = any(x.endswith(_VLM_ARCH_SUFFIXES) for x in config.architectures)
# Subprocess failed (transient) -- fall back to raw config.json metadata
config_data, metadata_error = _load_model_config_metadata(
model_name, hf_token = hf_token
)
if config_data is not None:
is_vlm = _is_vlm_config(config_data)
if is_vlm:
logger.info(
f"Model {model_name} detected as VLM: architecture {config.architectures}"
"Model %s detected as VLM from raw config metadata: "
"model_type=%s architectures=%s",
model_name,
config_data.get("model_type"),
config_data.get("architectures", []),
)
return True

# Check 2: Has vision_config (most VLMs: LLaVA, Gemma-3, Qwen2-VL, etc.)
if hasattr(config, "vision_config"):
logger.info(f"Model {model_name} detected as VLM: has vision_config")
return True

# Check 3: Has img_processor (Phi-3.5 Vision uses this instead of vision_config)
if hasattr(config, "img_processor"):
logger.info(f"Model {model_name} detected as VLM: has img_processor")
return True

# Check 4: Has image_token_index (common in VLMs for image placeholder tokens)
if hasattr(config, "image_token_index"):
logger.info(f"Model {model_name} detected as VLM: has image_token_index")
return True

# Check 5: Known VLM model_type values that may not match above checks
if hasattr(config, "model_type"):
if config.model_type in _VLM_MODEL_TYPES:
logger.info(
f"Model {model_name} detected as VLM: model_type={config.model_type}"
)
return True
return is_vlm

return False
# Both subprocess and raw config failed -- classify the error.
# Permanent failures should be cached as False; transient as None.
if metadata_error is not None:
return _classify_detection_error(metadata_error)
return None

try:
config = load_model_config(model_name, use_auth = True, token = hf_token)
return _is_vlm_config(config)
except Exception as e:
logger.warning(f"Could not determine if {model_name} is vision model: {e}")
# Permanent failures (model not found, gated, bad config) should be
# cached as False. Transient failures (network, timeout) should not.
try:
from huggingface_hub.errors import RepositoryNotFoundError, GatedRepoError
except ImportError:
try:
from huggingface_hub.utils import (
RepositoryNotFoundError,
GatedRepoError,
)
except ImportError:
RepositoryNotFoundError = GatedRepoError = None
if RepositoryNotFoundError is not None and isinstance(
e, (RepositoryNotFoundError, GatedRepoError)
):
return False
if isinstance(e, (ValueError, json.JSONDecodeError)):
return False
return None
return _classify_detection_error(e)


VALID_AUDIO_TYPES = ("snac", "csm", "bicodec", "dac", "whisper", "audio_vlm")
Expand Down