[Pallas] Add pallas_loop_type = 'outer_pipeline'#2744
Conversation
|
@AmesingFlank this is building off of your approach. let me know if addresses your concerns about "overfitting" to the gdpa case. |
|
Thanks for working on this! When I played with |
|
EDIT: revised to do a better distinction. @AmesingFlank Thanks that's very helpful. I guess my attempt here was to have a pipeline option that would work well across jagged gdpa, jagged flash attention, grouped gemm etc. The current gists are somewhat specialized in that they only work when the whole sequence fits into vmem. To have something general we need to handle: for group in hl.grid(num_groups)
begin = offsets[group]
end = offsets[group + 1]
for tile_m in hl.tile(begin, end):
acc = 0
for tile_k in hl.tile(k_begin,k_end)
acc += ...The pallas reference kernels don't have an inner accumulator loop, so we'll have to do something different. My approach in this PR was to incorporate all of these into a single emit_pipeline(... dimension_semantics = 'parallel', 'parallel', 'arbitrary')It gets decent perf (up to 2.0 ms compared to 1.86 ms of the reference kernel), but I'm not sure if this is the best way. After thinking about it, I do think the right approach would be some extension of the gist you shared but maybe more like: grid = (B * max_q_tiles,) @thcmbs what do you think? |
This work builds on #2687 to add a Pallas
pallas_loop_type = 'outer_pipeline'that folds a top-levelhl.gridaxis and one nestedhl.tileaxis into a singlepltpu.emit_pipelinelaunch. It modifies the approach by first gathering all the context from the loops before constructing an emit_pipeline object instead of rewriting the grid code.We target a pattern with outer work partition expressed with
hl.grid, followed by one or more tiled inner dimensions expressed withhl.tile. The outer grid chooses an independent problem instance, group, batch, segment, or ragged row range. The nested tile loop chooses the block of work within that instance. This is relevant for jagged attention, grouped gemm and jagged bmm type kernels.The current PR implements the stateless/parallel version of this pattern:
With
pallas_loop_type = 'outer_pipeline'we get large improvements in performance (jagged gpda dense kv):As alluded to above, we need a (slightly) new language feature, a 'max_extent' for hl.tile:
A follow up PR and the final goal is to handle the a stateful/ordered version of the pattern:
This is what's needed for fully jagged gpda, jagged flash attention, and grouped_gemm.
The key abstraction behind this is a
PipelineAxismodel. Instead of treating the original root grid and the folded inner tile as separate lowering concepts, outer pipeline lowering normalizes both into a list of axes that describe the finalemit_pipelinegrid.The key classification is whether an axis is
parallelorordered:parallelmeans each coordinate of that axis is independent from the others. There is no loop-carried state across the axis, and stores are disjoint per tile. These axes map todimension_semantics="parallel"and can be freely scheduled byemit_pipeline. Examples are segment/group axes, batch axes, independent M/N output tiles, or a jagged row/query tile where each tile writes its own output slice.orderedmeans the axis has a recurrence or finalization dependency across tiles. The common shape is an accumulator such asacc = update(acc, tile)followed by a store after the last tile. These axes map todimension_semantics="arbitrary"because the pipeline must preserve the logical order for that axis. Examples are K tiles in grouped GEMM, reduction tiles in jagged BMM, KV tiles in fully jagged GDPA, and the streaming/reduction axis in flash-style attention.This PR only supports
parallelbut a follow up PR will enableordered.Here's the lowering flow:
The user selects
pallas_loop_type="outer_pipeline".When the top-level
hl.gridis lowered,TileStrategy.codegen_grid()dispatches to_codegen_outer_pipeline_grid()instead of emitting ordinary Pallas program IDs.This creates a
PipelineContextinCompileEnvironment.outer_pipeline_context.For each root grid dimension, it creates a
PipelineAxiscontaining:kind="parallel"_o0It also records prologue statements that reconstruct the original
pid,offset,index, and mask variables from_pipeline_indices[...]inside the eventual pipeline body.Normal lowering then continues through the grid body.
Statements before the folded
hl.tileloop are captured as pipeline prologue. Simple scalar assignments are recorded so they can be replayed in the pipeline body and also inlined into BlockSpec lambdas. This is what lets bounds likestart = offsets[seq]be used safely inside generated BlockSpecs.When the nested
hl.tilereaches_codegen_emit_pipeline(), the existing emit-pipeline lowering switches into outer-pipeline mode becauseenv.outer_pipeline_contextis set.At this point the lowering:
max_tilesfor the folded tile from the loop extent orhl.tile(..., max_extent=...)PipelineAxisfor the folded tile, usually with lambda parameter_jThe resulting pipeline grid is just:
so for example
The folded loop body is validated before codegen.
The first PR keeps this conservative. It rejects:
Tensor accesses are classified into emit-pipeline inputs and outputs.
For each pipelined tensor,
_make_block_spec()builds a PallasBlockSpecby asking thePipelineContextwhichPipelineAxisowns each tensor dimension.The important part is that BlockSpec generation no longer has separate "outer" and "inner" special cases. For an axis-backed dimension it uses:
Then it chooses either:
pl.BoundedSlice(...)plus clampedpl.ds(start, extent)for jagged, partial, or overlaunched dimensionsLambda rendering goes through
PipelineContext.resolve_for_lambda(), which substitutes pid/offset variables with lambda-scope expressions and inlines captured scalar prologue assignments.The pipeline body function is generated.
The body takes
_pipeline_indicesplus the VMEM refs produced byemit_pipeline.Its prologue reconstructs the original grid/tile variables, emits masks, replays captured scalar prologue statements, and remaps any supported prologue tensor reads to VMEM refs. Then the original folded loop graph is emitted against those VMEM refs through the existing
EmitPipelineLoopStatemachinery.For overlaunched tiles, the real body statements are wrapped in a validity guard:
Loads also keep conservative masking for invalid tail lanes. Reductions are rejected for now because masking only the load is not enough for arbitrary lane mixing.
Finally
_codegen_emit_pipeline()emits one generated Pallas call: