Skip to content

[diffusiongemma] Pass decoder_input_ids to skip the model's randint#790

Merged
kamalrajkannan78 merged 1 commit into
mainfrom
kkannan/jun30_diffgemma_unit32_workaround
Jul 1, 2026
Merged

[diffusiongemma] Pass decoder_input_ids to skip the model's randint#790
kamalrajkannan78 merged 1 commit into
mainfrom
kkannan/jun30_diffgemma_unit32_workaround

Conversation

@kamalrajkannan78

@kamalrajkannan78 kamalrajkannan78 commented Jun 30, 2026

Copy link
Copy Markdown
Contributor

Ticket

Problem description

  • DiffusionGemma's forward initializes the decoder canvas with torch.randint when decoder_input_ids is None. On device, randint lowers to rng_bit_generator + an unsigned remainder (bits % range), and tt-metal doesn't support remainder on UINT32 → the model crashes with Unsupported data type for remainder DataType::UINT32.

What's changed

  • load_inputs now passes decoder_input_ids (host-side torch.randint, same shape/range/dtype as the model's own init), so the model skips its internal device-side randint and no remainder op is emitted. The model now compiles and runs end-to-end on 8 chips. This is just a quick workaround for the tt-metal gap (UINT32 remainder); removable once supported.
  • With this fix and tt-xla#5424, models now fails at runtime with PCC=0.96

Checklist

  • Verify the changes through local testing in WH

Logs

@kamalrajkannan78 kamalrajkannan78 force-pushed the kkannan/jun30_diffgemma_unit32_workaround branch from d55e6c6 to c8f2adb Compare June 30, 2026 17:13
@kamalrajkannan78 kamalrajkannan78 marked this pull request as ready for review June 30, 2026 17:31
for key in list(inputs):
value = inputs[key].repeat_interleave(batch_size, dim=0)
inputs[key] = cast_input_to_type(value, dtype_override)
# Workaround: pass decoder_input_ids to skip the model's randint (its

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not push for an issue fix instead of doing a workaround at the input level?

Also, is the model passing with this workaround enabled?

@kamalrajkannan78 kamalrajkannan78 Jul 1, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not push for an issue fix instead of doing a workaround at the input level?

Agreed the proper fix belongs in tt-metal — randint's lowering hits an unsupported remainder on UINT32 (tracked in tenstorrent/tt-metal#27621, and noted in the linked ticket). But even with a tt-metal PR, we'd still have to wait for the next uplift, and since this is a priority model I added this input-level workaround to unblock it in the meantime; it's removable once the fix lands. Since it's listed (still unchecked) in that tracker, to avoid duplicating effort I'll sync with the issue owner to confirm there's no ongoing work, and raise the PR today if it's not already being picked up.

Also, is the model passing with this workaround enabled?

Not yet. With this workaround the model compiles and fails at runtime with a PCC=0.96, as mentioned in the description. So this clears the crash - jun30_diffgemma_after_work_around.log.zip

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, instead of here, can we add MLIR workaround with link to the metal issue?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Update on the proper tt-metal fix: synced with @KalaivaniMCW@mcw-anasuya is now working on UINT32 remainder support (Improved UINT32 support for eltwise OPs tt-metal#27621).
  • This particular workaround skips the model's randint entirely, which can only happen at the model-input level — by the time it reaches tt-mlir, randint is already lowered to rng_bit_generator + remainder, so MLIR can't un-call it.
  • A tt-mlir-side workaround would instead have to handle the uint32 remainder op itself (e.g. rewrite modulo-by-power-of-2 via shift/subtract)
  • Since Anasuya is working on proper uint32 remainder support in tt-metal (#27621), I'd lean toward keeping this input-level bridge (removable as soon as that lands) over a temporary MLIR rewrite. wdyt?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, sounds good. Let's also track the issue link above the workaround. Approving the PR :))

@kamalrajkannan78 kamalrajkannan78 force-pushed the kkannan/jun30_diffgemma_unit32_workaround branch from e4cc7c9 to e81cfae Compare July 1, 2026 11:36
@kamalrajkannan78 kamalrajkannan78 merged commit 96ac61a into main Jul 1, 2026
2 checks passed
@kamalrajkannan78 kamalrajkannan78 deleted the kkannan/jun30_diffgemma_unit32_workaround branch July 1, 2026 11:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants