[Pallas] Fix attention example VMEM regression by making LSE 3D#2743
Open
norx1991 wants to merge 1 commit into
Open
[Pallas] Fix attention example VMEM regression by making LSE 3D#2743norx1991 wants to merge 1 commit into
norx1991 wants to merge 1 commit into
Conversation
6d6769f to
2ffe79e
Compare
Change `lse` from 2D `[B*H, S]` to 3D `[B*H, S, 1]`. External API is unchanged: caller still sees `lse` as `[B, H, S]` fp32 via a reshape at return, and the backward kernel flattens via `lse_in.reshape(-1)`, so it is unaffected. On Pallas/TPU, a 2D output whose leading dim is the outer tile forces an 8-element TPU sublane alignment on the corresponding block_size. Helion's adjust_block_size_constraints takes the max alignment across all tile-indexed tensors sharing that block_id, so the requirement inflates block_b for Q/K/V/out as well. At B=8 H=32 S=8192 D=256 bf16 with [2, 512, 2048] unroll pb=False, this expanded XLA's K and V input windows from bf16[2,8192,256] (16 MB each, double-buffered) to bf16[8,8192,256] (64 MB each), pushing total scoped VMEM past the 64 MiB cap. Putting the alignment-sensitive dim on a trailing size-1 dim sidesteps the inflation. The TODO in the new code points to the underlying compiler-level limitation, which can be addressed in a follow-up.
2ffe79e to
9f301a3
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
lsefrom 2D[B*H, S]to 3D[B*H, S, 1].lseas[B, H, S]fp32 (reshape at return). Backward kernel already flattens vialse_in.reshape(-1), so it is unaffected.Why
On Pallas/TPU, a 2D output whose leading dim is the outer tile forces an 8-element TPU sublane alignment on the corresponding block_size. Helion's
adjust_block_size_constraintstakes the max alignment across all tile-indexed tensors sharing that block_id, so the requirement inflatesblock_bfor Q/K/V/out as well. AtB=8 H=32 S=8192 D=256bf16 withblock_sizes=[2, 512, 2048]unrollpb=False, this expanded XLA's K and V input windows frombf16[2,8192,256](16 MB each, double-buffered) tobf16[8,8192,256](64 MB each), pushing total scoped VMEM past the 64 MiB cap.Putting the alignment-sensitive dim on a trailing size-1 dim sidesteps the inflation. Verified by extracting
_block_spec_infofor both variants:_BLOCK_SIZE_0 = 8, K/V block_shape(8, None, None)→ 238 MB total_BLOCK_SIZE_0 = 2, K/V block_shape(2, None, None)→ 34 MB totalXLA allocator confirms the reduction matches.
Impact (B=8 H=32 S=8192 D=256 bf16, TPU v7x)
[2,512,2048] unroll pb=FalseThe underlying compiler-level limitation (
adjust_block_size_constraintsmax-propagating per-tensor TPU sublane alignment across independent tile-indexed tensors sharing a block_id) can be addressed in a follow-up.