Skip to content

[GLM] Rework MoE routing to avoid scatter_/gather lowerings on multi-row mesh#5435

Open
mvasiljevicTT wants to merge 2 commits into
mainfrom
mvasiljevic/glm-moe-routing-fix
Open

[GLM] Rework MoE routing to avoid scatter_/gather lowerings on multi-row mesh#5435
mvasiljevicTT wants to merge 2 commits into
mainfrom
mvasiljevic/glm-moe-routing-fix

Conversation

@mvasiljevicTT

Copy link
Copy Markdown
Contributor

Ticket

Fixes #5409

Problem description

GLM 4.7 had good PCC (0.99) only on the first 16/64 users — those on the first of the 4 devices the batch is sharded across — because the MoE router indexes a global tensor with a large index that is either mis-sharded or precision-truncated. Two bugs of this class, both in route_tokens_to_experts (sparse_mlp.py):

  1. group_mask.scatter_(1, group_idx, 1)scatter_ adds a token row-iota that Shardy shards without the per-shard offset, so the group mask marks only mesh-row 0 and zeros routing for rows 1-3 (0.88 → 0.95).
  2. router_logits.gather(1, topk_indices) — lowers to ttir.embedding over the all-gathered [tokens·E] score table with a flat index token·E + expert computed at fp16-class precision. For rows 1-3 the large token·E offset rounds off the expert bits (e.g. 2609 → 2610), gathering the wrong expert's score (0.95 → 0.99).

What's changed

Replaces both ops with the one_hot pattern already used for XLA compat (_topk_to_sparse_scores), so the arange is over experts/groups (a model dim, not the sharded batch axis) and the index stays small + exact:

  • Group mask: scatter_one_hot + any.
  • Top-k weights: gatherone_hot + einsum.

Both rewrites are mathematically equivalent on CPU and just workarounds around buggy scatter and gather ops to unblock glm pcc until they are fixed.

Impact

@codecov-commenter

codecov-commenter commented Jun 30, 2026

Copy link
Copy Markdown

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 33.80%. Comparing base (1bd6e3b) to head (48adda3).

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #5435   +/-   ##
=======================================
  Coverage   33.80%   33.80%           
=======================================
  Files          37       37           
  Lines        4990     4990           
=======================================
  Hits         1687     1687           
  Misses       3303     3303           

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

mvasiljevicTT and others added 2 commits July 1, 2026 10:58
route_tokens_to_experts used torch scatter_ (group mask) and gather (topk
weights). On a sharded batch axis their StableHLO lowerings add a token
row-iota / build a flat token*E+expert index that tt-mlir/Shardy mishandles
(sharded iota loses its per-shard offset; gather->embedding flat index is
computed at fp16 precision), silently zeroing/garbling the routed-MoE
contribution for mesh-rows 1-3 (48/64 users).

Replace both with the one_hot pattern already used for XLA compat
(_topk_to_sparse_scores): group mask via one_hot+any, topk weights via
einsum. Mathematically identical, but the arange is over experts/groups
(not the batch axis) and the einsum keeps the expert index small and exact.

GLM-4.7 (4,8) galaxy decode: all 4 mesh-rows now PCC 0.993 (was rows 1-3
0.856); full-batch decode PCC 0.9937.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@mvasiljevicTT mvasiljevicTT force-pushed the mvasiljevic/glm-moe-routing-fix branch from c72dabb to 48adda3 Compare July 1, 2026 08:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Benchmark] GLM 4.7 bad pcc for users on non-zero device

2 participants