[diffusiongemma] Pass decoder_input_ids to skip the model's randint #790
    for key in list(inputs):
        value = inputs[key].repeat_interleave(batch_size, dim=0)
        inputs[key] = cast_input_to_type(value, dtype_override)
    # Workaround: pass decoder_input_ids to skip the model's randint (its
Why not push for an issue fix instead of doing a workaround at the input level?
Also, is the model passing with this workaround enabled?
> Why not push for an issue fix instead of doing a workaround at the input level?

Agreed, the proper fix belongs in tt-metal: randint's lowering hits an unsupported `remainder` on UINT32 (tracked in tenstorrent/tt-metal#27621, and noted in the linked ticket). But even with a tt-metal PR we would still have to wait for the next uplift, and since this is a priority model I added this input-level workaround to unblock it in the meantime; it is removable once the fix lands. UINT32 `remainder` is listed (still unchecked) in that tracker, so to avoid duplicating effort I'll sync with the issue owner to confirm there's no ongoing work, and raise the PR today if it's not already being picked up.

> Also, is the model passing with this workaround enabled?

Not yet. With this workaround the model compiles but fails at runtime with PCC=0.96, as mentioned in the description. So this clears the crash. Log: jun30_diffgemma_after_work_around.log.zip
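For readers following the thread, the control flow behind the input-level workaround can be sketched without any framework. This is a dependency-free illustration only: `ToyModel`, its `forward` signature, and this `load_inputs` are hypothetical stand-ins, and Python's `random` module replaces `torch.randint`. It is not the repository's actual code.

```python
import random


class ToyModel:
    """Hypothetical stand-in for a model that samples decoder ids itself
    when the caller does not provide them."""

    def __init__(self, vocab_size):
        self.vocab_size = vocab_size
        self.internal_randint_calls = 0  # counts uses of the fallback path

    def forward(self, input_ids, decoder_input_ids=None):
        if decoder_input_ids is None:
            # In the real model this is the device-side randint, i.e. the
            # path that lowers to rng_bit_generator + remainder.
            self.internal_randint_calls += 1
            decoder_input_ids = [
                random.randrange(self.vocab_size) for _ in input_ids
            ]
        return decoder_input_ids


def load_inputs(input_ids, vocab_size, seed=0):
    """Build the inputs on the host, including decoder_input_ids, so the
    model's fallback sampling path is never taken (assumed helper)."""
    host_rng = random.Random(seed)
    decoder_input_ids = [host_rng.randrange(vocab_size) for _ in input_ids]
    return {"input_ids": input_ids, "decoder_input_ids": decoder_input_ids}
```

Calling `ToyModel(100).forward(**load_inputs([1, 2, 3], 100))` leaves `internal_randint_calls` at zero, which mirrors why no `remainder` op is emitted once the ids arrive as an input.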
Ok, instead of here, can we add an MLIR workaround with a link to the metal issue?
- Update on the proper tt-metal fix: synced with @KalaivaniMCW; @mcw-anasuya is now working on UINT32 `remainder` support (Improved UINT32 support for eltwise OPs tt-metal#27621).
- This particular workaround skips the model's `randint` entirely, which can only happen at the model-input level: by the time it reaches tt-mlir, `randint` is already lowered to `rng_bit_generator` + `remainder`, so MLIR can't un-call it.
- A tt-mlir-side workaround would instead have to handle the uint32 `remainder` op itself (e.g. rewrite modulo-by-power-of-2 via shift/subtract).
- Since Anasuya is working on proper uint32 `remainder` support in tt-metal (#27621), I'd lean toward keeping this input-level bridge (removable as soon as that lands) over a temporary MLIR rewrite. wdyt?
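To make the two ideas in this comment concrete, here is a small plain-Python sketch of (a) `randint` expressed as random bits followed by an unsigned remainder, and (b) the shift/subtract rewrite of modulo by a power of two. The function names and the 32-bit mask are assumptions made for illustration; this is not tt-mlir or tt-metal code.

```python
MASK32 = 0xFFFFFFFF  # emulate uint32 arithmetic with Python ints


def randint_via_remainder(bits, low, high):
    """Map raw 32-bit random bits into [low, high) as low + bits % range.

    This mirrors the rng_bit_generator + remainder shape described above.
    """
    span = high - low
    return low + (bits & MASK32) % span


def remainder_pow2(bits, k):
    """Compute bits % 2**k on a uint32 value using only shifts and a subtract.

    (bits >> k) << k clears the low k bits; subtracting that from bits
    leaves exactly the low k bits, which is the remainder.
    """
    bits &= MASK32
    return bits - ((bits >> k) << k)
```

For example, `remainder_pow2(0xFFFFFFFF, 8)` and `0xFFFFFFFF % 256` both give 255. The rewrite only applies when the range is a power of two; any other range would still need a real `remainder`.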
Ok, sounds good. Let's also track the issue link above the workaround. Approving the PR :))
Ticket
Problem description
The model calls `torch.randint` when `decoder_input_ids is None`. On device, `randint` lowers to `rng_bit_generator` + an unsigned `remainder` (bits % range), and tt-metal doesn't support `remainder` on `UINT32`, so the model crashes with `Unsupported data type for remainder DataType::UINT32`.
What's changed
`load_inputs` now passes `decoder_input_ids` (host-side `torch.randint`, same shape/range/dtype as the model's own init), so the model skips its internal device-side `randint` and no `remainder` op is emitted. The model now compiles and runs end-to-end on 8 chips. This is just a quick workaround for the tt-metal gap (UINT32 `remainder`); removable once supported.
PCC=0.96
Checklist
Logs