Sharded stablehlo.iota localized without per-shard offset (breaks scatter & gather on non-zero mesh-rows) #5469

Description

@mvasiljevicTT

Summary

scatter_(1, idx, v) and gather(1, idx) both build a 2-D [row, col] index where row is an iota over the indexed (token/batch) axis. When tt-mlir splits the program into per-device shapes, that sharded iota is shrunk to its local size without adding the per-shard offset — every shard emits 0..local-1 instead of shard_id*local + 0..local-1. So only mesh-row 0 is correct:

  • scatter: every shard writes only into rows 0..local-1 → the output for mesh-rows 1-3 stays zeroed.
  • gather: the table is all-gathered to full size, but the local iota still indexes rows 0..local-1 → mesh-rows 1-3 read mesh-row 0's rows instead of their own.
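The gather case can be reproduced without any hardware. Below is a minimal pure-Python sketch, assuming the 4-way batch split and 16 users per mesh-row described in this issue; the helper names (`local_iota`, `gather_rows`, `shard_is_correct`) are hypothetical and are not tt-mlir code.

```python
MESH_ROWS = 4                 # shards along the batch axis
LOCAL = 16                    # users per shard (64 / 4)
GLOBAL = MESH_ROWS * LOCAL

# Stand-in for the all-gathered table: row r simply holds the value r.
table = list(range(GLOBAL))

def local_iota(shard_id, with_offset):
    """Row indices a shard emits; with_offset=False models the bug."""
    base = shard_id * LOCAL if with_offset else 0
    return [base + i for i in range(LOCAL)]

def gather_rows(shard_id, with_offset):
    """Rows the shard actually reads from the full-size table."""
    return [table[r] for r in local_iota(shard_id, with_offset)]

def shard_is_correct(shard_id, with_offset):
    """A shard is correct if it reads its own slice of the table."""
    expected = table[shard_id * LOCAL:(shard_id + 1) * LOCAL]
    return gather_rows(shard_id, with_offset) == expected

buggy = [shard_is_correct(s, with_offset=False) for s in range(MESH_ROWS)]
fixed = [shard_is_correct(s, with_offset=True) for s in range(MESH_ROWS)]
print("without offset:", buggy)
print("with offset:   ", fixed)
```

Without the offset only shard 0 reads its own rows, which matches the "first 16/64 users" symptom.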

This is the root cause behind GLM 4.7 being correct only for the first 16/64 users (#5409). Same class previously hit gpt-oss.

Where it goes wrong

lib/Dialect/StableHLO/Transforms/UpdateGlobalToLocalShapes.cpp, createNewOperationState() has special cases for ConstantOp/SliceOp/GatherOp but none for IotaOp, so a sharded iota only gets its shape shrunk. Verified localized gather IR (4×8 mesh, batch sharded on _axis_0):

%5  = stablehlo.iota dim = 0 : tensor<16xui32>              // LOCAL 0..15 on every shard — no offset
%10 = "stablehlo.all_gather"(%2) ... -> tensor<64x256xbf16> // table gathered to all 64 rows
%11 = "stablehlo.gather"(%10, %9) ...                       // reads rows 0..15 -> wrong for shards 1-3

Repro

Branch: mvasiljevic/scatter-gather-sharded-index-repro
Test: tests/torch/graphs/test_moe_router_sharded_index.py

pytest -rA tests/torch/graphs/test_moe_router_sharded_index.py

For each of scatter_/gather it runs a *_buggy graph and its one_hot bypass, indices passed as deterministic inputs (topk on random bf16 disagrees TT-vs-CPU and would mask the bug). Verified on Galaxy, batch sharded 4 ways (16 users/row):

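| op      | single-chip | replicated | batch-sharded   | one_hot bypass (sharded) |
|---------|-------------|------------|-----------------|--------------------------|
| scatter | PASS        | PASS       | FAIL (pcc 0.38) | PASS                     |
| gather  | PASS        | PASS       | FAIL (pcc 0.24) | PASS                     |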

Single-chip and replicated both pass, so only sharding the indexed axis triggers the failure, which pins the bug to the sharded-iota localization. The *_buggy tests are marked xfail.

Fix

Add per-shard offset for a sharded iota in UpdateGlobalToLocalShapes (local_iota + device_index_along_axis * local_size). This needs a per-device index lowerable to TTNN (e.g. lower stablehlo.partition_id), which doesn't exist today — so it's a real feature, not a patch. (Gather alone could instead be routed to native ttnn.gather.)
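The arithmetic of the proposed offset can be sketched in plain Python. This is only an illustration: it assumes devices are numbered row-major over the mesh (so on a 4×8 mesh the index along axis 0 is `partition_id // 8`), which this issue does not confirm, and both function names are hypothetical.

```python
def device_index_along_axis(partition_id, mesh_shape, axis):
    """Unravel a linear device id into mesh coordinates and pick one axis.

    Assumption: row-major device numbering over mesh_shape.
    """
    coords = []
    for dim in reversed(mesh_shape):
        coords.append(partition_id % dim)
        partition_id //= dim
    coords.reverse()
    return coords[axis]

def offset_iota(partition_id, mesh_shape, axis, local_size):
    """local_iota + device_index_along_axis * local_size, as proposed above."""
    base = device_index_along_axis(partition_id, mesh_shape, axis) * local_size
    return [base + i for i in range(local_size)]

MESH = (4, 8)   # 4 mesh-rows x 8 mesh-columns
LOCAL = 16      # batch rows per mesh-row

# One representative device per mesh-row (first device in each row).
per_row = [offset_iota(row * MESH[1], MESH, axis=0, local_size=LOCAL)
           for row in range(MESH[0])]
full = [i for shard in per_row for i in shard]
print(full == list(range(64)))
```

Devices that share a mesh-row get the same base, so concatenating one shard per mesh-row recovers the global iota 0..63.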

Current bypass (model-level)

The MoE router (sparse_mlp.py::route_tokens_to_experts) replaces the ops so no iota indexes the sharded axis: scatter_ → one_hot+any, gather → one_hot+einsum, with the arange over the unsharded expert/group dim. Restores GLM 4.7 PCC 0.88→0.993 on all 4 rows. Shipped in #5435.
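Why the bypass is equivalent can be checked on small inputs. The sketch below is not the code from sparse_mlp.py; it is a pure-Python illustration that assumes the scatter is used to build a boolean "expert selected" mask. The only index range it builds runs over the expert dim, never over the batch rows.

```python
def one_hot(i, n):
    """1 at position i, 0 elsewhere; the range runs over the expert dim only."""
    return [1 if e == i else 0 for e in range(n)]

def gather_ref(scores, idx):
    """Reference gather(1, idx): out[b][k] = scores[b][idx[b][k]]."""
    return [[scores[b][i] for i in idx[b]] for b in range(len(idx))]

def gather_one_hot(scores, idx):
    """Same result as a one_hot-weighted sum over the expert dim (einsum-style)."""
    n = len(scores[0])
    return [[sum(h * s for h, s in zip(one_hot(i, n), scores[b]))
             for i in idx[b]]
            for b in range(len(idx))]

def scatter_mask_ref(idx, n):
    """Reference scatter_ of True into a zero mask along dim 1."""
    out = [[False] * n for _ in idx]
    for b, row in enumerate(idx):
        for i in row:
            out[b][i] = True
    return out

def scatter_mask_one_hot(idx, n):
    """Same mask as any() over the top-k axis of one_hot(idx)."""
    return [[any(one_hot(i, n)[e] for i in row) for e in range(n)]
            for row in idx]

scores = [[10 * b + e for e in range(4)] for b in range(3)]  # 3 tokens, 4 experts
idx = [[0, 2], [3, 1], [1, 1]]                               # top-2 expert ids per token

print(gather_one_hot(scores, idx) == gather_ref(scores, idx))
print(scatter_mask_one_hot(idx, 4) == scatter_mask_ref(idx, 4))
```

Because each token row is processed independently, the result does not depend on which shard holds the row, so nothing needs a per-shard offset.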

Related: #5409, #5435
