[GLM] Rework MoE routing to avoid scatter_/gather lowerings on multi-row mesh#5435
Open
mvasiljevicTT wants to merge 2 commits into
Open
[GLM] Rework MoE routing to avoid scatter_/gather lowerings on multi-row mesh#5435mvasiljevicTT wants to merge 2 commits into
mvasiljevicTT wants to merge 2 commits into
Conversation
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #5435 +/- ##
=======================================
Coverage 33.80% 33.80%
=======================================
Files 37 37
Lines 4990 4990
=======================================
Hits 1687 1687
Misses 3303 3303 ☔ View full report in Codecov by Harness. |
route_tokens_to_experts used torch scatter_ (group mask) and gather (topk weights). On a sharded batch axis their StableHLO lowerings add a token row-iota / build a flat token*E+expert index that tt-mlir/Shardy mishandles (sharded iota loses its per-shard offset; gather->embedding flat index is computed at fp16 precision), silently zeroing/garbling the routed-MoE contribution for mesh-rows 1-3 (48/64 users). Replace both with the one_hot pattern already used for XLA compat (_topk_to_sparse_scores): group mask via one_hot+any, topk weights via einsum. Mathematically identical, but the arange is over experts/groups (not the batch axis) and the einsum keeps the expert index small and exact. GLM-4.7 (4,8) galaxy decode: all 4 mesh-rows now PCC 0.993 (was rows 1-3 0.856); full-batch decode PCC 0.9937. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
c72dabb to
48adda3
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Ticket
Fixes #5409
Problem description
GLM 4.7 had good PCC (0.99) only on the first 16/64 users — those on the first of the 4 devices the batch is sharded across — because the MoE router indexes a global tensor with a large index that is either mis-sharded or precision-truncated. Two bugs of this class, both in
route_tokens_to_experts(sparse_mlp.py):group_mask.scatter_(1, group_idx, 1)—scatter_adds a token row-iota that Shardy shards without the per-shard offset, so the group mask marks only mesh-row 0 and zeros routing for rows 1-3 (0.88 → 0.95).router_logits.gather(1, topk_indices)— lowers tottir.embeddingover the all-gathered[tokens·E]score table with a flat indextoken·E + expertcomputed at fp16-class precision. For rows 1-3 the largetoken·Eoffset rounds off the expert bits (e.g.2609 → 2610), gathering the wrong expert's score (0.95 → 0.99).What's changed
Replaces both ops with the
one_hotpattern already used for XLA compat (_topk_to_sparse_scores), so thearangeis over experts/groups (a model dim, not the sharded batch axis) and the index stays small + exact:scatter_→one_hot+any.gather→one_hot+einsum.Both rewrites are mathematically equivalent on CPU and just workarounds around buggy
scatterandgatherops to unblock glm pcc until they are fixed.Impact