Summary
scatter_(1, idx, v) and gather(1, idx) both build a 2-D [row, col] index where row is an iota over the indexed (token/batch) axis. When tt-mlir splits the program into per-device shapes, that sharded iota is shrunk to its local size without adding the per-shard offset — every shard emits 0..local-1 instead of shard_id*local + 0..local-1. So only mesh-row 0 is correct:
- scatter: writes only into local rows 0..local-1 → rows 1-3 zeroed.
- gather: table is all-gathered to full size, but the local iota reads rows 0..local-1 → rows 1-3 read the wrong rows.
This is the root cause behind GLM 4.7 being correct only for the first 16/64 users (#5409). Same class previously hit gpt-oss.
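The gather half of the failure can be modelled in plain Python without tt-mlir. This is only an illustrative sketch: `NUM_SHARDS`, `LOCAL`, `table`, `row_index` and `sharded_gather_rows` are hypothetical names, and the table is reduced to one value per row so that a wrong read is easy to spot.

```python
NUM_SHARDS = 4                 # mesh rows the batch axis is split across
LOCAL = 16                     # users per shard
GLOBAL = NUM_SHARDS * LOCAL    # 64 users in total

# Full table after the all-gather: row r simply holds the value r.
table = list(range(GLOBAL))


def row_index(shard_id, add_offset):
    """Row half of the [row, col] index as seen by one shard."""
    base = shard_id * LOCAL if add_offset else 0
    return [base + i for i in range(LOCAL)]


def sharded_gather_rows(add_offset):
    """Values every shard reads from the gathered table, in shard order."""
    out = []
    for shard_id in range(NUM_SHARDS):
        out.extend(table[r] for r in row_index(shard_id, add_offset))
    return out


buggy = sharded_gather_rows(add_offset=False)   # every shard reads rows 0..15
fixed = sharded_gather_rows(add_offset=True)    # shard s reads rows 16*s..16*s+15

correct_users = sum(b == t for b, t in zip(buggy, table))
print(correct_users, "of", GLOBAL, "users correct without the offset")
print("with offset matches the table:", fixed == table)
```

In this model only the first shard's users come out right when the offset is missing, which is the same pattern as the 16/64 symptom described above.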
Where it goes wrong
lib/Dialect/StableHLO/Transforms/UpdateGlobalToLocalShapes.cpp, createNewOperationState() has special cases for ConstantOp/SliceOp/GatherOp but none for IotaOp, so a sharded iota only gets its shape shrunk. Verified localized gather IR (4×8 mesh, batch sharded on _axis_0):
%5 = stablehlo.iota dim = 0 : tensor<16xui32> // LOCAL 0..15 on every shard — no offset
%10 = "stablehlo.all_gather"(%2) ... -> tensor<64x256xbf16> // table gathered to all 64 rows
%11 = "stablehlo.gather"(%10, %9) ... // reads rows 0..15 -> wrong for shards 1-3
Repro
Branch: mvasiljevic/scatter-gather-sharded-index-repro
Test: tests/torch/graphs/test_moe_router_sharded_index.py
pytest -rA tests/torch/graphs/test_moe_router_sharded_index.py
For each of scatter_/gather it runs a *_buggy graph and its one_hot bypass, indices passed as deterministic inputs (topk on random bf16 disagrees TT-vs-CPU and would mask the bug). Verified on Galaxy, batch sharded 4 ways (16 users/row):
| op | single-chip | replicated | batch-sharded | one_hot bypass (sharded) |
| --- | --- | --- | --- | --- |
| scatter | PASS | PASS | FAIL (pcc 0.38) | PASS |
| gather | PASS | PASS | FAIL (pcc 0.24) | PASS |
single-chip and replicated pass, so only sharding the indexed axis triggers it — pinning the bug to the sharded-iota localization. The *_buggy tests are marked xfail.
Fix
Add per-shard offset for a sharded iota in UpdateGlobalToLocalShapes (local_iota + device_index_along_axis * local_size). This needs a per-device index lowerable to TTNN (e.g. lower stablehlo.partition_id), which doesn't exist today — so it's a real feature, not a patch. (Gather alone could instead be routed to native ttnn.gather.)
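As a rough sketch of the arithmetic such a fix would need, the snippet below derives the index along the sharded mesh axis from a linear device id and applies `local_iota + device_index_along_axis * local_size`. It is not tt-mlir code: `index_along_axis` and `offset_iota` are hypothetical helpers, and the row-major numbering of devices on the 4×8 mesh is an assumption that the real runtime may not share.

```python
MESH_SHAPE = (4, 8)   # 4x8 mesh; batch is sharded along axis 0


def index_along_axis(device_id, mesh_shape, axis):
    """Coordinate of a device along one mesh axis.

    Assumption: device ids are a row-major linearisation of the mesh.
    """
    coords = []
    for dim in reversed(mesh_shape):
        coords.append(device_id % dim)
        device_id //= dim
    coords.reverse()
    return coords[axis]


def offset_iota(device_id, local_size, mesh_shape=MESH_SHAPE, axis=0):
    """local_iota + device_index_along_axis * local_size."""
    start = index_along_axis(device_id, mesh_shape, axis) * local_size
    return [start + i for i in range(local_size)]


# Devices 0..7 sit in mesh row 0, devices 8..15 in mesh row 1, and so on.
print(offset_iota(0, 16)[:3])    # starts at 0
print(offset_iota(8, 16)[:3])    # starts at 16
print(offset_iota(31, 16)[-1])   # last global row, 63
```

Devices that share a mesh row get the same offset, because the batch axis is only split along axis 0 and replicated along the other axis.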
Current bypass (model-level)
The MoE router (sparse_mlp.py::route_tokens_to_experts) replaces the ops so no iota indexes the sharded axis: scatter_→one_hot+any, gather→one_hot+einsum, with the arange over the unsharded expert/group dim. Restores GLM 4.7 PCC 0.88→0.993 on all 4 rows. Shipped in #5435.
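The equivalence behind the bypass can be checked with a dependency-free model of the two ops. This is a sketch, not the code in `sparse_mlp.py`: all function names are hypothetical, nested lists stand in for tensors, and the scatter case assumes the scattered value is a constant `True` (a mask), which is what `one_hot` + `any` can reproduce.

```python
def one_hot(index, num_classes):
    """1 at position `index`, 0 elsewhere; ranges over the expert dim only."""
    return [1 if c == index else 0 for c in range(num_classes)]


def gather_dim1(x, idx):
    """Reference for x.gather(1, idx): out[t][k] = x[t][idx[t][k]]."""
    return [[row[j] for j in row_idx] for row, row_idx in zip(x, idx)]


def gather_via_one_hot(x, idx):
    """Same result as an einsum 'tke,te->tk' over one_hot(idx)."""
    num_e = len(x[0])
    return [
        [sum(h * v for h, v in zip(one_hot(j, num_e), row)) for j in row_idx]
        for row, row_idx in zip(x, idx)
    ]


def scatter_true_dim1(num_e, idx):
    """Reference for zeros.scatter_(1, idx, True) on a boolean mask."""
    out = [[False] * num_e for _ in idx]
    for t, row_idx in enumerate(idx):
        for j in row_idx:
            out[t][j] = True
    return out


def scatter_via_one_hot(num_e, idx):
    """one_hot(idx) reduced with `any` over the top-k dimension."""
    return [
        [any(one_hot(j, num_e)[e] for j in row_idx) for e in range(num_e)]
        for row_idx in idx
    ]


x = [[10, 11, 12, 13], [20, 21, 22, 23], [30, 31, 32, 33]]   # [tokens, experts]
idx = [[0, 3], [1, 2], [3, 3]]                               # [tokens, top-k]

print(gather_via_one_hot(x, idx) == gather_dim1(x, idx))
print(scatter_via_one_hot(4, idx) == scatter_true_dim1(4, idx))
```

In the `one_hot` variants each token's output depends only on that token's own row of `x` and `idx`, and the only range that is enumerated is the expert dimension, so no row index over the sharded axis is needed.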
Related: #5409, #5435