Skip to content

[Pallas] Defer tile load masks past transposes onto the sublane axis#2812

Merged
ethche merged 2 commits into
mainfrom
pallas-defer-load-mask
Jun 19, 2026
Merged

[Pallas] Defer tile load masks past transposes onto the sublane axis#2812
ethche merged 2 commits into
mainfrom
pallas-defer-load-mask

Conversation

@ethche

@ethche ethche commented Jun 18, 2026

Copy link
Copy Markdown
Contributor

A Pallas tiled load masks out-of-bounds lanes multiplicatively in the load's own layout (ref[idx] * mask). When the value is only relayouted by an axis permutation (transpose / .T) before a dot or reduction, that consumer already inserts _mask_to(x, 0), which re-materializes the mask in the consumer layout.

In general, we see that masking on the sublane axis is much cheaper, while masking on a major axis adds a lot of vector load overhead. This informs our optimization: we defer loads only when the masked axis is a major/outer dim at the load (eager mask is expensive there) and the permutation moves it into the last-two (sublane/lane) dims at the consumer (mask is cheap there). We don't defer for other types of ops (view reshape etc). If the permutation moves the mask from sublane/lane to major axis, deferring can make things slower, so this gate is important.

Before:

  mask_1 = offset < tile_start + tile_extent
  load_4 = q[pl.ds(0, BLOCK), :, :] * mask_1.astype(jnp.bfloat16)[:, None, None]
  qbk    = jnp.transpose(load_4, [1, 0, 2])
  scores = lax.dot_general(qbk, k_t, ...)

After:

  mask_1  = offset < tile_start + tile_extent
  load_4  = q[pl.ds(0, BLOCK), :, :]
  qbk     = jnp.transpose(load_4, [1, 0, 2])                      
  _mask_to = jnp.where(mask_1.astype(jnp.float32)[None, :, None], qbk, 0)   # mask in consumer layout
  scores  = lax.dot_general(_mask_to, k_t, ...)

Performance on jagged gdpa dense kv (bf16). same shapes as in:#2782

shape dist block baseline TFLOP/s with fix TFLOP/s speedup
A (B=256) uniform q4096 200 257 1.29×
A ramp q1024 142 158 1.11×
A random q4096 150 189 1.27×
B (B=768) uniform q4096 213 281 1.32×
B ramp q1024 156 176 1.13×
B random q4096 159 210 1.32×

New FX graph pass defer_pallas_load_masks (Pallas-only, run just before remove_unnecessary_masking). For each load whose masked tile dim is provably re-masked downstream, it records the deferred block ids and marks the load's masked value unknown so the _mask_to survives; load codegen then skips the eager mask for those dims and the mask is materialized after the transpose.

Correctness rests on a dataflow proof: every use of the load reaches a mask_to(, 0) crossing only pure axis-permutation ops (aten.permute.default), with the masked dim surviving as a unique tile dim at each step. Only gated to pallas + major to sublane/lane permutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 18, 2026
@ethche ethche requested review from AmesingFlank, cota and thcmbs June 18, 2026 16:30
@meta-codesync

meta-codesync Bot commented Jun 18, 2026

Copy link
Copy Markdown

@ethche has imported this pull request. If you are a Meta employee, you can view this in D109034842.

@ethche ethche changed the title Pallas] Defer tile load masks past transposes onto the sublane axis [Pallas] Defer tile load masks past transposes onto the sublane axis Jun 18, 2026
@ethche ethche force-pushed the pallas-defer-load-mask branch from bb2c328 to 8e1034e Compare June 18, 2026 18:02
A Pallas tiled load masks out-of-bounds lanes multiplicatively in the
load's own layout (ref[idx] * mask). When the value is only relayouted by
an axis permutation (transpose / .T) before a dot or reduction, that
consumer already inserts _mask_to(x, 0), which re-materializes the mask in
the consumer layout. Masking early, in the pre-permute layout, defeats
Mosaic's elision of all-true dynamic masks and is measurably slower
(e.g. a [Q, H, D] -> [H, Q, D] attention load).

The root cause is that a load reports masked_value == 0 unconditionally,
so the downstream _mask_to looks redundant and is dropped, leaving only
the slow eager mask. Triton is unaffected: tl.load(..., other=0) zeroes
real data, so a relayout just permutes already-zero lanes.

Add defer_pallas_load_masks (FX graph pass, Pallas-only): for each load
whose masked tile dim is provably re-masked downstream -- every use reaches
a _mask_to(_, 0) through pure axis-permutation ops with the dim surviving
as a unique tile dim -- record the deferred block ids and mark the load's
masked value unknown so remove_unnecessary_masking keeps the _mask_to.
Load codegen then skips the eager mask for those dims.

Correctness: the crossed ops are restricted to aten.permute.default (pure
axis permutation; view/reshape excluded since they can regroup the masked
dim's elements). Pallas-gated at the call site (marking masked_value on a
Triton load would disable block_ptr lowering).

Profitability is a positional gate on top of that proof. Profiling on TPU
(standard tiled matmul) shows the win is positional: masking a tile axis is
cheap when that axis is in the last-two (sublane/lane vreg-tile) dims and
expensive (~+70% load/scalar traffic, MXU stalls) when it is a major/outer
dim. A transpose relocates the axis, so we defer only when it moves the
masked axis from major (at the load, where eager masking is expensive) to
last-two (at the consumer, where masking is cheap). The reverse direction
would make deferral ~1.5x slower, so it is explicitly not deferred. This
gate also subsumes the 'must cross >=1 relayout' check, given a
rank-preserving relayout set.

Measured: dense-KV GDPA ~1.3x faster (B=768 q4096: 7.06 -> 5.31 ms,
213 -> 283 TFLOP/s); fully-jagged neutral; accuracy unchanged (rel-L2
~0.0023 = bf16 precision). Validated on TPU: full test_pallas.py and the
compact_worklist suite pass.
@ethche ethche force-pushed the pallas-defer-load-mask branch from 8e1034e to 37127e8 Compare June 18, 2026 18:12

@AmesingFlank AmesingFlank left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm. Altho, do you expect this rewrite to always be profitable? Or do you think there is value is making this tunable?

@ethche

ethche commented Jun 19, 2026

Copy link
Copy Markdown
Contributor Author

lgtm. Altho, do you expect this rewrite to always be profitable? Or do you think there is value is making this tunable?

The gating is pretty conservative so I think it would only fire when it's reasonable. But I think it would be good for us to keep an eye out on whether there's could be part of a more general optimization pass (e.g. re-ordering operations).

@ethche ethche merged commit e036705 into main Jun 19, 2026
24 of 25 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants