[ROCm] Fix dynamic lm_head INT8 applying to non-lm_head embeddings #1016

Merged: mgehre-amd merged 1 commit into gfx11 from marcusr/fix-gemma4-dyn-lm-int8 on Jun 23, 2026

marcusr-amd commented on Jun 22, 2026

Summary

Fix --dynamic-lm-head-quantization int8 crash on Gemma 4 models.

Problem

Two issues prevented dynamic lm_head INT8 from working on Gemma 4:

  1. Per-Layer Embeddings misidentified as lm_head. Gemma 4 has embed_tokens_per_layer (a second VocabParallelEmbedding with dim=8960). When tie_word_embeddings=True, _has_tied_embeddings() returned True for ALL VocabParallelEmbedding instances, so DynamicInt8LMHeadMethod was applied to the PLE layer. This caused RuntimeError: expected mat1 and mat2 to have the same dtype, but got: signed char != c10::Half.

  2. Model reads embed_tokens.weight.dtype for buffer allocation. gemma4.py lines 1111/1124 and gemma4_mm.py (4 instances) use self.embed_tokens.weight.dtype to determine the compute dtype when creating tensors. After INT8 quantization, .weight.dtype returns torch.int8 instead of torch.float16, corrupting downstream tensor creation.

Fix

vocab_parallel_embedding.py: Add _is_main_embedding() guard that identifies the primary token embedding by checking embedding_dim == config.hidden_size. This is a structural invariant of transformer LMs — the primary embedding always projects into the hidden dimension. Auxiliary embeddings like PLE share the same vocab size but use a different dimension (num_layers * per_layer_dim), so vocab size alone doesn't distinguish them. The guard is documented with its assumptions and failure modes.
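
The guard described above can be sketched in isolation as follows. This is a minimal illustration, not the code from the PR: `ToyConfig`, `ToyEmbedding`, and `should_quantize_as_lm_head` are hypothetical stand-ins, and the vocab size and hidden size are placeholder values. Only the PLE dimension of 8960 comes from the description.

```python
from dataclasses import dataclass


@dataclass
class ToyConfig:
    # Hypothetical stand-in for the model config.
    hidden_size: int
    tie_word_embeddings: bool


@dataclass
class ToyEmbedding:
    # Hypothetical stand-in for a VocabParallelEmbedding instance.
    name: str
    num_embeddings: int
    embedding_dim: int


def is_main_embedding(layer: ToyEmbedding, config: ToyConfig) -> bool:
    # The primary token embedding projects into the hidden dimension.
    return layer.embedding_dim == config.hidden_size


def should_quantize_as_lm_head(layer: ToyEmbedding, config: ToyConfig) -> bool:
    # With tied weights, only the main embedding doubles as lm_head.
    return config.tie_word_embeddings and is_main_embedding(layer, config)


# Placeholder sizes; 8960 is the PLE dimension quoted in the description.
cfg = ToyConfig(hidden_size=2048, tie_word_embeddings=True)
untied = ToyConfig(hidden_size=2048, tie_word_embeddings=False)
main = ToyEmbedding("embed_tokens", num_embeddings=262144, embedding_dim=2048)
ple = ToyEmbedding("embed_tokens_per_layer", num_embeddings=262144, embedding_dim=8960)

for layer in (main, ple):
    print(layer.name, should_quantize_as_lm_head(layer, cfg))
```

Both toy layers share a vocab size, so a check on `num_embeddings` would not separate them, while the dimension check does. With `untied`, neither layer is treated as lm_head.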

gemma4.py: Cache self.dtype = vllm_config.model_config.dtype and replace self.embed_tokens.weight.dtype with self.dtype. This is the authoritative compute dtype and is stable under weight quantization.

gemma4_mm.py: Same embed_tokens.weight.dtype -> self.language_model.model.dtype replacement (4 instances).
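
To illustrate why caching the compute dtype is more robust than reading it from the weight, here is a toy sketch in plain Python. The classes are hypothetical and use strings in place of `torch.dtype` values, so this only models the pattern and is not the vLLM code.

```python
class ToyWeight:
    def __init__(self, dtype: str):
        self.dtype = dtype  # string stand-in for a torch.dtype


class ToyEmbedding:
    def __init__(self, dtype: str):
        self.weight = ToyWeight(dtype)

    def quantize_int8(self) -> None:
        # After quantization the stored weight is int8.
        self.weight = ToyWeight("int8")


class ToyModel:
    def __init__(self, compute_dtype: str):
        # Cache the compute dtype once, from configuration.
        self.dtype = compute_dtype
        self.embed_tokens = ToyEmbedding(compute_dtype)

    def buffer_dtype_fragile(self) -> str:
        # Pattern before the fix: follows whatever the weight currently is.
        return self.embed_tokens.weight.dtype

    def buffer_dtype_stable(self) -> str:
        # Pattern after the fix: independent of weight quantization.
        return self.dtype


model = ToyModel("float16")
model.embed_tokens.quantize_int8()
print(model.buffer_dtype_fragile(), model.buffer_dtype_stable())
```

In this sketch the fragile accessor reports the storage dtype of the quantized weight, while the cached attribute keeps reporting the dtype the model computes in.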

Testing

Tested locally on Strix Halo (gfx1151) with cyankiwi/gemma-4-E2B-it-AWQ-INT4:

| Configuration | TTFT (ms) | TPOT (ms) | Decode (tok/s) |
|---|---|---|---|
| Baseline (no dyn-lm-int8) | 596 | 11.31 | 88.4 |
| With --dynamic-lm-head-quantization int8 | 599 | 9.91 | 100.9 (+14%) |
  • 10/10 prompts completed, 0 failures
  • PLE correctly excluded (only main embedding gets int8)
  • Qwen3.5-4B/2B and Qwen3-VL-2B dyn-lm-int8 verified unaffected (CI run #2789651)
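
As a quick consistency check on the numbers above, the decode throughput figures match 1000 / TPOT for a single request stream. That relationship is an assumption about how the Decode column was derived, not something the PR states.

```python
# TPOT values (ms per output token) taken from the results above.
tpot_baseline_ms = 11.31
tpot_int8_ms = 9.91

# Assumed relationship for a single stream: tokens/s = 1000 ms / TPOT.
decode_baseline = 1000.0 / tpot_baseline_ms
decode_int8 = 1000.0 / tpot_int8_ms

# Relative gain in decode throughput, in percent.
gain_pct = (decode_int8 / decode_baseline - 1.0) * 100.0

print(round(decode_baseline, 1), round(decode_int8, 1), round(gain_pct))
```

TTFT is essentially unchanged (596 ms vs 599 ms), so the gain shows up in the decode phase only.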

Note

gemma3n.py has the same embed_tokens.weight.dtype pattern (6 instances) and would need the same fix for dynamic lm_head INT8 support. Not included in this PR since Gemma 3N is not in our regression matrix.

marcusr-amd force-pushed the marcusr/fix-gemma4-dyn-lm-int8 branch 4 times, most recently from d88994d to 7d553a1 (June 22, 2026 20:35)
marcusr-amd requested a review from mgehre-amd (June 22, 2026 21:12)
marcusr-amd marked this pull request as ready for review (June 22, 2026 21:14)
Comment on lines +251 to +256
return leaf in (
"embed_tokens",
"wte",
"word_embeddings",
"embedding",
"",

Have you checked that this covers all models that use int8 lm_head in our regression suite?

Author

@mgehre-amd, the logic was refactored to be more robust. Can you re-review?

Yes. Tested all 5 model families that use dyn-lm-int8 locally on Strix Halo — all pass (Qwen3.5-4B, Qwen3.5-2B, Qwen3-VL-2B, Gemma4 E4B, Gemma4 E2B). The remaining models (Qwen3.5-35B, Qwen3.6-35B, Qwen3-Omni) use the same architectures.

_is_main_embedding() checks embedding_dim == config.hidden_size to distinguish the primary embedding from auxiliary ones like Gemma4's PLE (same vocab size, different dim). Qwen3-Omni has tie_word_embeddings=False so the guard is never called.

Two issues prevented --dynamic-lm-head-quantization int8 on Gemma 4:

1. Gemma 4 Per-Layer Embeddings (embed_tokens_per_layer) are
   VocabParallelEmbedding instances that were incorrectly identified as
   lm_head candidates via _has_tied_embeddings(). Fix: add
   _is_main_embedding() guard that identifies the primary embedding by
   checking embedding_dim == config.hidden_size (a structural invariant
   of transformer LMs, not a naming convention).

2. Gemma 4 model code reads embed_tokens.weight.dtype for tensor
   creation (buffer allocation), which returns int8 after quantization.
   Fix: cache dtype from vllm_config.model_config.dtype in Gemma4Model
   and use self.dtype instead of reading from the weight tensor.
   Same fix applied to gemma4_mm.py.

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
marcusr-amd force-pushed the marcusr/fix-gemma4-dyn-lm-int8 branch from 7d553a1 to 0c2e322 (June 23, 2026 14:10)
marcusr-amd requested a review from mgehre-amd (June 23, 2026 18:09)
mgehre-amd merged commit 3709638 into gfx11 (Jun 23, 2026)
9 of 10 checks passed