autotune key hygiene: apply mask/dtype-placeholder fix to fp8.py and fused_head.py #101

Description

@h-aurelien-lac

Follow-up from #100 (Tony's review).

#100 made the dense forward/backward autotune key invariant to mask/argmax presence in two ways: (a) it dropped has_q_mask/has_d_mask from the named key, and (b) it switched to dtype-matched placeholders for absent optional args. Part (b) matters because Triton's autotuner appends str(arg.dtype) of every tensor arg to its cache key, so a bf16 Q stand-in and a real int8 mask produce different keys and split the cache.
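To make the dtype split concrete, here is a minimal pure-Python model of the key construction described above. It is a sketch, not Triton's code: `FakeTensor`, `model_cache_key`, the argument names, and the `Lq` named key are all hypothetical stand-ins chosen for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a tensor; only the dtype attribute matters."""
    dtype: str


def model_cache_key(named_key, args):
    """Approximate the key: named key values, then str(dtype) of every arg that has a dtype."""
    key = [args[name] for name in named_key if name in args]
    for arg in args.values():
        if hasattr(arg, "dtype"):
            key.append(str(arg.dtype))
    return tuple(key)


Q = FakeTensor("torch.bfloat16")
D = FakeTensor("torch.bfloat16")
real_mask = FakeTensor("torch.int8")


def call_args(q_mask):
    # Assumed argument layout; the real kernels take more arguments.
    return {"Q": Q, "D": D, "q_mask": q_mask, "Lq": 32}


named = ["Lq"]  # assumed named key, with the has_*_mask toggles already dropped
key_with_mask = model_cache_key(named, call_args(real_mask))
key_old_placeholder = model_cache_key(named, call_args(Q))  # old pattern: "else Q"
key_new_placeholder = model_cache_key(named, call_args(FakeTensor("torch.int8")))

print(key_with_mask == key_old_placeholder)  # False: the bf16 stand-in splits the cache
print(key_with_mask == key_new_placeholder)  # True: the int8 placeholder shares the entry
```

In this model, dropping the named toggles alone is not enough: the placeholder's dtype still leaks into the key.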

Two other kernels keep the old pattern:

  • fp8.py: the _maxsim_fp8_fwd_kernel key (l.74) keeps has_q_mask/has_d_mask; placeholders are else Q / else D (l.344-345, bf16 vs int8).
  • fused_head.py: the _fused_head_fwd_kernel key (l.40) keys on has_bias/has_d_mask/normalize/save_argmax; placeholders are else Q / else H_d / else scores (l.214-217). This is a training path.

Neither routes through _bucket_lq, so mask presence doesn't flip mid-run and there is no autotune-spike scenario like the docvqa one that #100 fixed. This is cache-cardinality hygiene, not a perf regression. It was scoped out of the 0.4.1 patch because each kernel needs its own parity validation.

Work:

  • Drop the constexpr mask/bias/normalize/save_argmax toggles from both keys (they change codegen, not the winning tile — same argument as the dense forward).
  • Use autotune_placeholder(ref, dtype) for absent optional args (int8 masks, int32 argmax, matching-dtype bias).
  • Add compile-cache regression tests for both kernels.
  • Validate parity on H100.
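The compile-cache regression tests could take roughly the following shape: enumerate every presence combination and assert that the number of distinct cache keys stays at one. This is a pure-Python sketch against a model of the key, not the real kernels. The helper names, the `Lq`/`Ld` named keys, the dtypes (float32 bias in particular), and the mapping of each optional arg to its old fallback are assumptions for illustration.

```python
from itertools import product
from types import SimpleNamespace


def tensor(dtype):
    """Hypothetical tensor stand-in exposing only a dtype attribute."""
    return SimpleNamespace(dtype=dtype)


def cache_key(named_key, args):
    """Model of the key: named key values plus str(dtype) of each tensor-like arg."""
    key = [args[name] for name in named_key if name in args]
    key += [str(a.dtype) for a in args.values() if hasattr(a, "dtype")]
    return tuple(key)


def launch_args(has_bias, has_d_mask, normalize, save_argmax, matched):
    Q, H_d, scores = tensor("bfloat16"), tensor("bfloat16"), tensor("float32")

    def optional(present, real_dtype, fallback):
        # matched=True models a dtype-matched placeholder for an absent arg;
        # matched=False models the old "else <other tensor>" fallback.
        return tensor(real_dtype) if (present or matched) else fallback

    return {
        "Q": Q, "H_d": H_d, "scores": scores,
        "bias": optional(has_bias, "float32", Q),
        "d_mask": optional(has_d_mask, "int8", H_d),
        "argmax": optional(save_argmax, "int32", scores),
        "Lq": 32, "Ld": 256,
        "has_bias": has_bias, "has_d_mask": has_d_mask,
        "normalize": normalize, "save_argmax": save_argmax,
    }


def count_keys(named_key, matched):
    """Number of distinct cache keys across all 16 toggle combinations."""
    combos = product([False, True], repeat=4)
    return len({cache_key(named_key, launch_args(*c, matched)) for c in combos})


old_named = ["Lq", "Ld", "has_bias", "has_d_mask", "normalize", "save_argmax"]
new_named = ["Lq", "Ld"]

print(count_keys(old_named, matched=False))  # 16: toggles in the named key
print(count_keys(new_named, matched=False))  # 8: toggles dropped, dtypes still split
print(count_keys(new_named, matched=True))   # 1: both fixes applied
```

A real test would launch the kernel under each combination and inspect the autotuner's cache size instead of this model, but the assertion is the same: one entry per shape bucket regardless of which optional args are present.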

plaid.py:898 already uses the dtype-matched placeholder, for reference.
