[ROCm] Fix dynamic lm_head INT8 applying to non-lm_head embeddings#1016
Merged
Conversation
d88994d to
7d553a1
Compare
mgehre-amd
approved these changes
Jun 22, 2026
mgehre-amd
reviewed
Jun 22, 2026
Comment on lines
+251
to
+256
| return leaf in ( | ||
| "embed_tokens", | ||
| "wte", | ||
| "word_embeddings", | ||
| "embedding", | ||
| "", |
There was a problem hiding this comment.
Have you checked that this covers all models that use int8 lm_head in our regression suite?
Author
There was a problem hiding this comment.
@mgehre-amd, the logic was refactored to try and be more robust can you re-review?
Yes. Tested all 5 model families that use dyn-lm-int8 locally on Strix Halo — all pass (Qwen3.5-4B, Qwen3.5-2B, Qwen3-VL-2B, Gemma4 E4B, Gemma4 E2B). The remaining models (Qwen3.5-35B, Qwen3.6-35B, Qwen3-Omni) use the same architectures.
_is_main_embedding() checks embedding_dim == config.hidden_size to distinguish the primary embedding from auxiliary ones like Gemma4's PLE (same vocab size, different dim). Qwen3-Omni has tie_word_embeddings=False so the guard is never called.
Two issues prevented --dynamic-lm-head-quantization int8 on Gemma 4: 1. Gemma 4 Per-Layer Embeddings (embed_tokens_per_layer) are VocabParallelEmbedding instances that were incorrectly identified as lm_head candidates via _has_tied_embeddings(). Fix: add _is_main_embedding() guard that identifies the primary embedding by checking embedding_dim == config.hidden_size (a structural invariant of transformer LMs, not a naming convention). 2. Gemma 4 model code reads embed_tokens.weight.dtype for tensor creation (buffer allocation), which returns int8 after quantization. Fix: cache dtype from vllm_config.model_config.dtype in Gemma4Model and use self.dtype instead of reading from the weight tensor. Same fix applied to gemma4_mm.py. Signed-off-by: Marcus Rosen <marcus.rosen@amd.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
7d553a1 to
0c2e322
Compare
mgehre-amd
approved these changes
Jun 23, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Fix
--dynamic-lm-head-quantization int8crash on Gemma 4 models.Problem
Two issues prevented dynamic lm_head INT8 from working on Gemma 4:
Per-Layer Embeddings misidentified as lm_head. Gemma 4 has
embed_tokens_per_layer(a secondVocabParallelEmbeddingwith dim=8960). Whentie_word_embeddings=True,_has_tied_embeddings()returnedTruefor ALLVocabParallelEmbeddinginstances, soDynamicInt8LMHeadMethodwas applied to the PLE layer. This causedRuntimeError: expected mat1 and mat2 to have the same dtype, but got: signed char != c10::Half.Model reads
embed_tokens.weight.dtypefor buffer allocation.gemma4.pylines 1111/1124 andgemma4_mm.py(4 instances) useself.embed_tokens.weight.dtypeto determine the compute dtype when creating tensors. After INT8 quantization,.weight.dtypereturnstorch.int8instead oftorch.float16, corrupting downstream tensor creation.Fix
vocab_parallel_embedding.py: Add_is_main_embedding()guard that identifies the primary token embedding by checkingembedding_dim == config.hidden_size. This is a structural invariant of transformer LMs — the primary embedding always projects into the hidden dimension. Auxiliary embeddings like PLE share the same vocab size but use a different dimension (num_layers * per_layer_dim), so vocab size alone doesn't distinguish them. The guard is documented with its assumptions and failure modes.gemma4.py: Cacheself.dtype = vllm_config.model_config.dtypeand replaceself.embed_tokens.weight.dtypewithself.dtype. This is the authoritative compute dtype and is stable under weight quantization.gemma4_mm.py: Sameembed_tokens.weight.dtype->self.language_model.model.dtypereplacement (4 instances).Testing
Tested locally on Strix Halo (gfx1151) with
cyankiwi/gemma-4-E2B-it-AWQ-INT4:--dynamic-lm-head-quantization int8Note
gemma3n.pyhas the sameembed_tokens.weight.dtypepattern (6 instances) and would need the same fix for dynamic lm_head INT8 support. Not included in this PR since Gemma 3N is not in our regression matrix.