Skip to content

[Pallas] Add pallas_loop_type = 'outer_pipeline'#2744

Draft
ethche wants to merge 5 commits into
mainfrom
outer-pipeline-stateful
Draft

[Pallas] Add pallas_loop_type = 'outer_pipeline'#2744
ethche wants to merge 5 commits into
mainfrom
outer-pipeline-stateful

Conversation

@ethche

@ethche ethche commented Jun 10, 2026

Copy link
Copy Markdown
Contributor

This work builds on #2687 to add a Pallas pallas_loop_type = 'outer_pipeline' that folds a top-level hl.grid axis and one nested hl.tile axis into a single pltpu.emit_pipeline launch. It modifies the approach by first gathering all the context from the loops before constructing an emit_pipeline object instead of rewriting the grid code.

We target a pattern with outer work partition expressed with hl.grid, followed by one or more tiled inner dimensions expressed with hl.tile. The outer grid chooses an independent problem instance, group, batch, segment, or ragged row range. The nested tile loop chooses the block of work within that instance. This is relevant for jagged attention, grouped gemm and jagged bmm type kernels.

The current PR implements the stateless/parallel version of this pattern:

for group in hl.grid(num_groups):
    begin = offsets[group]
    end = offsets[group + 1]
    for tile in hl.tile(begin, end, block_size=BLOCK, max_extent=MAX_BLOCK):
        out[group, tile, ...] = compute(inputs[group, tile, ...])

With pallas_loop_type = 'outer_pipeline' we get large improvements in performance (jagged gpda dense kv):

mode emit_pipeline outer_pipeline speedup max_err
uniform 6.682 ms 2.527 ms 2.64x 0.0
ramp 5.308 ms 2.616 ms 2.03x 0.0
half, min_len=8 5.026 ms 2.568 ms 1.96x 0.0

As alluded to above, we need a (slightly) new language feature, a 'max_extent' for hl.tile:

for tile in hl.tile(begin, end, block_size=BLOCK, max_extent=MAX_BLOCK):

A follow up PR and the final goal is to handle the a stateful/ordered version of the pattern:

for group in hl.grid(num_groups):
    m_begin = m_offsets[group]
    m_end = m_offsets[group + 1]
    k_begin = k_offsets[group]
    k_end = k_offsets[group + 1]
    for m_tile in hl.tile(m_begin, m_end, block_size=BLOCK_M, max_extent=MAX_M):
        acc = init_accumulator(...)
        for k_tile in hl.tile(k_begin, k_end, block_size=BLOCK_K, max_extent=MAX_K):
            acc = update(acc, lhs[group, m_tile, k_tile], rhs[group, k_tile, ...])
        out[group, m_tile, ...] = finalize(acc)

This is what's needed for fully jagged gpda, jagged flash attention, and grouped_gemm.

The key abstraction behind this is a PipelineAxis model. Instead of treating the original root grid and the folded inner tile as separate lowering concepts, outer pipeline lowering normalizes both into a list of axes that describe the final emit_pipeline grid.

The key classification is whether an axis is parallel or ordered:

  • parallel means each coordinate of that axis is independent from the others. There is no loop-carried state across the axis, and stores are disjoint per tile. These axes map to dimension_semantics="parallel" and can be freely scheduled by emit_pipeline. Examples are segment/group axes, batch axes, independent M/N output tiles, or a jagged row/query tile where each tile writes its own output slice.
  • ordered means the axis has a recurrence or finalization dependency across tiles. The common shape is an accumulator such as acc = update(acc, tile) followed by a store after the last tile. These axes map to dimension_semantics="arbitrary" because the pipeline must preserve the logical order for that axis. Examples are K tiles in grouped GEMM, reduction tiles in jagged BMM, KV tiles in fully jagged GDPA, and the streaming/reduction axis in flash-style attention.

This PR only supports parallel but a follow up PR will enable ordered.

Here's the lowering flow:

  1. The user selects pallas_loop_type="outer_pipeline".

  2. When the top-level hl.grid is lowered, TileStrategy.codegen_grid() dispatches to _codegen_outer_pipeline_grid() instead of emitting ordinary Pallas program IDs.

    This creates a PipelineContext in CompileEnvironment.outer_pipeline_context.

    For each root grid dimension, it creates a PipelineAxis containing:

    • the Helion block id
    • kind="parallel"
    • begin/end/step/block extent
    • static pipeline-grid extent
    • the BlockSpec lambda parameter, e.g. _o0
    • the original pid/offset/index variable names

    It also records prologue statements that reconstruct the original pid, offset, index, and mask variables from _pipeline_indices[...] inside the eventual pipeline body.

  3. Normal lowering then continues through the grid body.

    Statements before the folded hl.tile loop are captured as pipeline prologue. Simple scalar assignments are recorded so they can be replayed in the pipeline body and also inlined into BlockSpec lambdas. This is what lets bounds like start = offsets[seq] be used safely inside generated BlockSpecs.

  4. When the nested hl.tile reaches _codegen_emit_pipeline(), the existing emit-pipeline lowering switches into outer-pipeline mode because env.outer_pipeline_context is set.

    At this point the lowering:

    • stops prologue capture
    • rejects a second folded tile loop
    • rejects loop-carried state for this first PR (TODO: capture this as an ordered axis)
    • computes static max_tiles for the folded tile from the loop extent or hl.tile(..., max_extent=...)
    • adds another PipelineAxis for the folded tile, usually with lambda parameter _j

    The resulting pipeline grid is just:

    outer_context.grid_parts

    so for example

    grid = (seq_tiles, q_tiles)
    dimension_semantics = ("parallel", "parallel")
  5. The folded loop body is validated before codegen.

    The first PR keeps this conservative. It rejects:

    • atomics
    • HBM read-modify-write
    • reductions inside the folded body (TODO: allow structured ordered-axis reductions)
    • loop-carried state (TODO: allow VMEM scratch carry for ordered axes)
    • stores that do not reference the folded output axis
    • unsupported prologue tensor reads
    • non-pipelined folded-loop tensor accesses
  6. Tensor accesses are classified into emit-pipeline inputs and outputs.

    For each pipelined tensor, _make_block_spec() builds a Pallas BlockSpec by asking the PipelineContext which PipelineAxis owns each tensor dimension.

    The important part is that BlockSpec generation no longer has separate "outer" and "inner" special cases. For an axis-backed dimension it uses:

    raw_start = axis.begin + axis.lambda_param * axis.step

    Then it chooses either:

    • compact block-index form for static full, divisible dimensions
    • pl.BoundedSlice(...) plus clamped pl.ds(start, extent) for jagged, partial, or overlaunched dimensions

    Lambda rendering goes through PipelineContext.resolve_for_lambda(), which substitutes pid/offset variables with lambda-scope expressions and inlines captured scalar prologue assignments.

  7. The pipeline body function is generated.

    The body takes _pipeline_indices plus the VMEM refs produced by emit_pipeline.

    Its prologue reconstructs the original grid/tile variables, emits masks, replays captured scalar prologue statements, and remaps any supported prologue tensor reads to VMEM refs. Then the original folded loop graph is emitted against those VMEM refs through the existing EmitPipelineLoopState machinery.

    For overlaunched tiles, the real body statements are wrapped in a validity guard:

    lax.cond(valid_tile, real_body, empty_body)

    Loads also keep conservative masking for invalid tail lanes. Reductions are rejected for now because masking only the load is not enough for arbitrary lane mixing.

  8. Finally _codegen_emit_pipeline() emits one generated Pallas call:

    pltpu.emit_pipeline(
        body_fn,
        grid=(...outer axes..., ...folded tile axes...),
        in_specs=[...],
        out_specs=[...],
        _explicit_indices=True,
        dimension_semantics=outer_context.dimension_semantics,
    )(...)

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 10, 2026
@ethche

ethche commented Jun 10, 2026

Copy link
Copy Markdown
Contributor Author

@AmesingFlank this is building off of your approach. let me know if addresses your concerns about "overfitting" to the gdpa case.

@AmesingFlank

Copy link
Copy Markdown
Contributor

Thanks for working on this! When I played with outer_pipeline previously, I was modeling after this gist, but later on I created this gist which doesn't use a outer size-1 grid followed by an emit_pipeline, and instead just uses a single pallas_call grid. I do think the 2nd approach is more similar to how we currently generate code (pallas_call grid size == num seqeunces), and it does have slightly better TFLOPs, so I was thinking perhaps its preferable to use that approach. Wdyt?

@ethche

ethche commented Jun 10, 2026

Copy link
Copy Markdown
Contributor Author

EDIT: revised to do a better distinction.

@AmesingFlank Thanks that's very helpful. I guess my attempt here was to have a pipeline option that would work well across jagged gdpa, jagged flash attention, grouped gemm etc. The current gists are somewhat specialized in that they only work when the whole sequence fits into vmem.

To have something general we need to handle:

for group in hl.grid(num_groups)
	begin = offsets[group]
	end = offsets[group + 1]
	for tile_m in hl.tile(begin, end):
		acc = 0
		for tile_k in hl.tile(k_begin,k_end)
			acc += ...

The pallas reference kernels don't have an inner accumulator loop, so we'll have to do something different.

My approach in this PR was to incorporate all of these into a single emit_pipeline over num_groups, tile_m, tile_k:

emit_pipeline(... dimension_semantics = 'parallel', 'parallel', 'arbitrary')

It gets decent perf (up to 2.0 ms compared to 1.86 ms of the reference kernel), but I'm not sure if this is the best way. After thinking about it, I do think the right approach would be some extension of the gist you shared but maybe more like:

grid = (B * max_q_tiles,)
emit_pipeline for kv ("arbitrary")

@thcmbs what do you think?

@ethche ethche marked this pull request as draft June 11, 2026 13:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants