[Pallas] Fix attention example VMEM regression by making LSE 3D #2743

Open: norx1991 wants to merge 1 commit into main from yifeixu/attention-3d-lse

norx1991 (Contributor) commented Jun 10, 2026

Summary

  • Change attention example's lse from 2D [B*H, S] to 3D [B*H, S, 1].
  • External API unchanged: caller still sees lse as [B, H, S] fp32 (reshape at return). Backward kernel already flattens via lse_in.reshape(-1), so it is unaffected.

Why

On Pallas/TPU, a 2D output whose leading dim is the outer tile forces an 8-element TPU sublane alignment on the corresponding block_size. Helion's adjust_block_size_constraints takes the max alignment across all tile-indexed tensors sharing that block_id, so the requirement inflates block_b for Q/K/V/out as well. At B=8 H=32 S=8192 D=256 bf16 with block_sizes=[2, 512, 2048] unroll pb=False, this expanded XLA's K and V input windows from bf16[2,8192,256] (16 MB each, double-buffered) to bf16[8,8192,256] (64 MB each), pushing total scoped VMEM past the 64 MiB cap.
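The window sizes quoted above can be checked with a few lines of arithmetic. This is only a back-of-the-envelope sketch: `window_bytes` is a hypothetical helper, not part of Helion or Pallas, and it assumes 2 bytes per bf16 element and two buffers per window (double-buffering).

```python
from math import prod

MIB = 1024 ** 2
VMEM_CAP_BYTES = 64 * MIB  # scoped VMEM cap quoted in the description


def window_bytes(shape, itemsize=2, buffers=2):
    """Bytes for one input window (hypothetical helper).

    itemsize=2 models bf16; buffers=2 models double-buffering.
    """
    return prod(shape) * itemsize * buffers


kv_small = window_bytes((2, 8192, 256))  # block_b = 2
kv_large = window_bytes((8, 8192, 256))  # block_b inflated to 8

print(kv_small // MIB, kv_large // MIB)  # 16 64
# K and V together already exceed the cap once block_b is inflated.
print(2 * kv_large > VMEM_CAP_BYTES)     # True
```

With the inflated block size, the K and V windows alone need 128 MiB, so the kernel cannot fit under the cap regardless of the other buffers.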

Putting the alignment-sensitive dim on a trailing size-1 dim sidesteps the inflation. Verified by extracting _block_spec_info for both variants:

  • 2D LSE: _BLOCK_SIZE_0 = 8, K/V block_shape (8, None, None) → 238 MB total
  • 3D LSE: _BLOCK_SIZE_0 = 2, K/V block_shape (2, None, None) → 34 MB total

The XLA allocator reports a matching reduction.
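A toy model of the max-propagation described above may make the effect easier to see. It is not Helion's `adjust_block_size_constraints`; the function names are hypothetical, and the alignment rule is a deliberate simplification that assumes only a tiled second-to-last axis carries the 8-element sublane requirement (lane alignment on the last axis is ignored).

```python
SUBLANE = 8  # TPU sublane alignment quoted in the description


def tile_alignment(ndim, tiled_axis):
    """Simplified rule (assumption): a tiled second-to-last axis needs
    sublane alignment; every other axis is modelled as alignment 1."""
    return SUBLANE if tiled_axis == ndim - 2 else 1


def effective_block_size(requested, tensors):
    """Round `requested` up to the max alignment over all tensors that
    share the block_id. `tensors` maps name -> (ndim, tiled_axis)."""
    align = max(tile_alignment(ndim, axis) for ndim, axis in tensors.values())
    return -(-requested // align) * align  # ceiling to a multiple of align


# Q/K/V/out are 3D and tiled on their leading axis by block_b.
qkv_out = {name: (3, 0) for name in ("q", "k", "v", "out")}

with_2d_lse = effective_block_size(2, {**qkv_out, "lse": (2, 0)})
with_3d_lse = effective_block_size(2, {**qkv_out, "lse": (3, 0)})

print(with_2d_lse, with_3d_lse)  # 8 2
```

In this model the 2D `lse` is the only tensor whose leading axis lands on the sublane dimension, yet its requirement is applied to every tensor tiled by the same block. Adding the trailing size-1 dim moves the leading axis off the sublane position, so the requested block size of 2 survives.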

Impact (B=8 H=32 S=8192 D=256 bf16, TPU v7x)

| Config | Before this PR | After this PR |
| --- | --- | --- |
| [2,512,2048] unroll pb=False | fails to compile (238 MB needed, 64 MB cap) | 24.72 ms median |

The underlying compiler-level limitation (adjust_block_size_constraints max-propagating per-tensor TPU sublane alignment across independent tile-indexed tensors sharing a block_id) can be addressed in a follow-up.

meta-cla bot added the CLA Signed label Jun 10, 2026
norx1991 force-pushed the yifeixu/attention-3d-lse branch from 6d6769f to 2ffe79e June 10, 2026 20:39
Change `lse` from 2D `[B*H, S]` to 3D `[B*H, S, 1]`. External API is
unchanged: caller still sees `lse` as `[B, H, S]` fp32 via a reshape
at return, and the backward kernel flattens via `lse_in.reshape(-1)`,
so it is unaffected.

On Pallas/TPU, a 2D output whose leading dim is the outer tile forces
an 8-element TPU sublane alignment on the corresponding block_size.
Helion's adjust_block_size_constraints takes the max alignment across
all tile-indexed tensors sharing that block_id, so the requirement
inflates block_b for Q/K/V/out as well. At B=8 H=32 S=8192 D=256 bf16
with [2, 512, 2048] unroll pb=False, this expanded XLA's K and V
input windows from bf16[2,8192,256] (16 MB each, double-buffered) to
bf16[8,8192,256] (64 MB each), pushing total scoped VMEM past the
64 MiB cap.

Putting the alignment-sensitive dim on a trailing size-1 dim sidesteps
the inflation. The TODO in the new code points to the underlying
compiler-level limitation, which can be addressed in a follow-up.
norx1991 force-pushed the yifeixu/attention-3d-lse branch from 2ffe79e to 9f301a3 June 10, 2026 23:22
norx1991 requested review from AmesingFlank and ethche June 11, 2026 20:47
norx1991 marked this pull request as ready for review June 11, 2026 20:47