Skip to content

Move measure("Kernel.bind") off the cache-hit dispatch path#2751

Draft
yushangdi wants to merge 5 commits into
yushangdi/stack/29from
yushangdi/stack/32
Draft

Move measure("Kernel.bind") off the cache-hit dispatch path#2751
yushangdi wants to merge 5 commits into
yushangdi/stack/29from
yushangdi/stack/32

Conversation

@yushangdi

@yushangdi yushangdi commented Jun 11, 2026

Copy link
Copy Markdown
Contributor

Stacked PRs:


Move measure("Kernel.bind") off the cache-hit dispatch path

measure() wrapped the entire body of Kernel.bind, which runs on every
kernel call. When HELION_MEASURE_COMPILE_TIME is unset (the default)
measure() returns a shared no-op context manager, but entering and
exiting it still costs ~0.4-0.5us/call: two extra frames plus a module
global read, on the steady-state cache-hit path.

Compile-time tracking only has meaningful data on the cache-miss
branch, where bind actually compiles and specializes a new
BoundKernel. The hit path is a specialization-key computation plus a
dict lookup -- nanoseconds of work that the tracker was never meant to
account for. Move the with measure("Kernel.bind"): block to wrap
only the miss branch; the hit path returns directly with no context
manager.

No behavior change: the same code runs under the same measurement
scope on the miss path, and HELION_MEASURE_COMPILE_TIME=1 still
attributes all compile/specialize time to "Kernel.bind". Verified by
test_misc, test_config_api, test_cache, test_specialize.

Benchmark (B200, end-to-end wall time per call, add-style kernel,
N=4096 fp32, HELION_AUTOTUNE_EFFORT=none, steady state, 20k iters):

  n_args | baseline | prev commit | this commit
  -------+----------+-------------+------------
       2 | 24.14 us |    16.41 us |    16.30 us
       8 | 33.05 us |    24.78 us |    24.39 us
      16 | 43.56 us |    35.24 us |    34.84 us

(The per-call win is ~0.4us isolated to bind(); at the end-to-end
scale here it is partly absorbed into run-to-run noise.)

Co-Authored-By: Claude Opus 4.8 (1M context) noreply@anthropic.com

measure("Kernel.bind") wraps the whole body of Kernel.bind, which runs
on every kernel call. Even when HELION_MEASURE_COMPILE_TIME is unset
(the default) and measure() returns a shared no-op nullcontext,
entering and exiting the `with` block still costs ~115ns/call -- the
context-manager protocol itself, not the measurement -- on the
steady-state cache-hit dispatch path.

Gate it: bind() checks _compile_time.is_enabled() (a plain module-flag
read, ~5ns) and only enters the `with measure(...)` block when
measurement is actually on; otherwise it calls the extracted
_bind_impl directly. The measured region is unchanged -- when enabled,
the entire bind body (key computation, cache lookup, and any compile)
runs under the same "Kernel.bind" scope as before, and every bind()
call is still counted, including cache hits. Only the
nothing-to-measure fast path skips the context manager.

No behavior change: with HELION_MEASURE_COMPILE_TIME=1 the timing
report is identical to before (same scope, same call counts); with it
unset there was never any data recorded, only the protocol overhead
that is now avoided. Verified by test_compile_time (metric still
records when enabled), plus test_misc / test_config_api / test_cache /
test_specialize.

Benchmark (B200, end-to-end wall time per call, add-style kernel,
N=4096 fp32, HELION_AUTOTUNE_EFFORT=none, steady state, 20k iters):

```
  n_args | baseline | prev commit | this commit
  -------+----------+-------------+------------
       2 | 24.14 us |    18.04 us |    18.01 us
       8 | 33.05 us |    26.11 us |    25.57 us
      16 | 43.56 us |    36.53 us |    36.41 us
```

(The per-call win is ~0.4us isolated to bind(); at the end-to-end
scale here it is within run-to-run noise.)

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

stack-info: PR: #2752, branch: yushangdi/stack/33
_tensor_key wraps every size and stride element in _hashable_dim to
normalize torch.SymInt into a hashable (shape_env_id, expr) pair. That
wrap exists only for symbolic shapes, which appear on FakeTensors
during tracing -- yet concrete tensors paid for it on every kernel
call: two Python-level loops (sizes + strides) with an isinstance
check per dimension, plus rebuilding each as a fresh tuple. For the
default static_shapes=True path this was the bulk of per-tensor key
extraction (~0.6us each, x number of tensor args, every call).

Specialization-extractor dispatch is by exact type
(_specialization_extractors.get(type(obj))), so the torch.Tensor and
torch.nn.Parameter entries are only ever hit by objects whose
sizes/strides are guaranteed concrete ints. Route those to a new
_concrete_tensor_key that uses obj.size() (a torch.Size) and
obj.stride() directly as key components. Both are tuple subclasses:
they hash and compare identically to plain int tuples, so keys
produced by the fast path and the SymInt path for the same concrete
shape are interchangeable -- no cache invalidation, and the
specialization axes (dtype, shape, stride, _dynamo_static_indices,
int32/int64 index width, dynamic-shape buckets) are bit-identical.

Anything that can carry SymInts keeps the safe path:
 * FakeTensor has its own dispatch entry -> _tensor_key (unchanged)
 * torch.Tensor *subclasses* (e.g. the JAX-export adapter) miss the
   exact-type dict and hit the isinstance fallback in
   _specialization_key, which now routes to _tensor_key explicitly

(This is the same optimization as the Python-only part of upstream
PR #2611, independently arrived at from profiling.)

New test/test_tensor_key_fast_path.py pins down the key equivalence
(fast vs wrapped key hash/compare equal under static and dynamic
shapes, incl. strided tensors), the dispatch routing (concrete ->
_concrete_tensor_key; FakeTensor and the subclass fallback ->
_tensor_key), that a torch.Tensor subclass takes the SymInt-safe path
and still shares a BoundKernel with a plain tensor of the same shape,
and that bind() caching/distinguishing is unchanged. Also verified
against test_misc, test_specialize, test_torch_compile (the
torch.compile suite exercises the FakeTensor/SymInt path).

Benchmark (B200, end-to-end wall time per call, add-style kernel,
N=4096 fp32, HELION_AUTOTUNE_EFFORT=none, steady state, 20k iters):

```
  n_args | baseline | prev commit | this commit
  -------+----------+-------------+------------
       2 | 24.14 us |    18.01 us |    17.25 us
       8 | 33.05 us |    25.57 us |    24.50 us
      16 | 43.56 us |    36.41 us |    32.16 us
```

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

stack-info: PR: #2748, branch: yushangdi/stack/30
Every Helion kernel launch went through Triton's full JITFunction.run
pipeline (~9.3us): per-call device + stream proxy resolution, the
argument binder, compute_cache_key, kernel-cache lookup, the
used_global_vals walk, launch_metadata construction (even with no
profiler attached), and the kwargs-dict munging around all of it. For
a Helion kernel almost all of that is redundant: BoundKernel has
already specialized on dtype/shape/stride/device, so the only
Triton-level specialization left at launch time is pointer alignment
and binary-affecting knob state.

This ports the _FastLauncher design from upstream PR #2565 (plus the
set_config function-clone fix from PR #2635) onto main:

 * helion/runtime/_fast_launcher.py -- default_launcher moves here
   unchanged (still re-exported from helion.runtime), joined by
   _FastLauncher: a multi-spec launcher primed on first call. The hot
   path computes a tiny spec key inline -- an alignment bitmask over
   the tensor args (data_ptr() & 15) plus debug/instrumentation_mode/
   stages-hook knob state -- dict-looks-up the compiled binary for
   that spec, and jumps straight into Triton's C launcher
   (CompiledKernel.run). Spec misses compile through Triton's full
   pipeline once and are cached, so call sites alternating aligned/
   unaligned tensors stay on the fast path for both.

 * BoundKernel.set_config clones the PyCodeCache'd host function
   (PyCodeCache keys on source hash, so two BoundKernels can share one
   function object) and re-points its _launcher kwdefault at a
   _FastLauncher with the config's num_warps/num_stages/etc. baked in.
   Explicit _launcher= callers (the autotune trial harness) override
   the kwdefault naturally.

 * TritonBackend.launcher_runtime_kwargs factors the runtime kwarg
   values out of launcher_keyword_args so codegen strings and the
   launcher closure share one source of truth.

Safety / correctness guards, each pinned by a test in
test/test_fast_launcher.py:
 * Alignment is part of the spec key, so an unaligned tensor after an
   aligned prime gets its own correctly-compiled binary -- never the
   vectorized aligned binary (which would fault), and never a clone
   (which would silently drop writes to output args).
 * used_global_vals snapshot per spec entry; any mutation falls back
   to JITFunction.run so Triton's own RuntimeError surfaces instead of
   silently launching a stale binary.
 * torch.compile tracing routes through default_launcher so Dynamo's
   triton_kernel_wrapper_mutation HOP rules apply.
 * Multi-device guard: a current-device change after priming falls
   back to Triton's per-device dispatch.
 * launch_enter/exit hooks are re-read per call (a profiler attached
   after priming still fires; launch_metadata is built only when a
   hook will consume it); pre_run_hooks fire inline; flipping
   knobs.runtime.debug lands on a new spec entry and recompiles.
 * Any priming/compile failure, and HELION_SKIP_FAST_LAUNCHER=1, fall
   back to default_launcher permanently.

Verified: test_fast_launcher + test_misc (51 passed), and full runs of
test_torch_compile (244), test_examples (96), test_autotuner (122),
test_indexing, test_loops, test_grid, test_ref_eager, test_specialize,
test_config_api, test_cache earlier on this branch.

Benchmark (B200, end-to-end wall time per call, add-style kernel,
N=4096 fp32, HELION_AUTOTUNE_EFFORT=none, steady state, 20k iters):

```
  n_args | baseline | prev commit | this commit | total
  -------+----------+-------------+-------------+------
       2 | 24.14 us |    17.25 us |    13.63 us |  -44%
       8 | 33.05 us |    24.50 us |    18.99 us |  -43%
      16 | 43.56 us |    32.16 us |    25.73 us |  -41%
```

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

stack-info: PR: #2749, branch: yushangdi/stack/31
Every Helion kernel call recomputes the bound-kernel cache key, which
includes the device's compute capability via target_device_capability().
That function called torch.cuda.is_available() plus
torch.cuda.get_device_capability() on every single kernel launch,
costing ~2.8us/call of pure Python/CUDA-runtime overhead -- the single
largest line item in Helion-side dispatch.

A physical device's compute capability cannot change within a process,
so cache it per device index in a module-level dict, and cache the
is_available() result once it returns True (an unavailable runtime has
no GPU launches to speed up, so False keeps re-querying).

Safety: tests simulate other architectures by patching
torch.cuda.get_device_capability / torch.cuda.is_available (e.g.
test_config_api.py's sm90-vs-sm100 cache-key test). The fast path
engages only while both functions are identity-equal to the originals
captured at import time; patched functions route to the original
uncached path, so the cache can never be poisoned by, nor serve stale
values to, arch-simulation tests.

An index-less torch.device("cuda") refers to the *current* device,
which can change between calls -- it is resolved per call via
torch.cuda.current_device() (cheap) before hitting the per-index cache.

Benchmark (B200, end-to-end wall time per call, add-style kernel,
N=4096 fp32, HELION_AUTOTUNE_EFFORT=none, steady state, 20k iters):

```
  n_args |   before |    after |  delta
  -------+----------+----------+-------
       2 | 24.14 us | 18.78 us |   -22%
       8 | 33.05 us | 27.77 us |   -16%
      16 | 43.56 us | 36.90 us |   -15%
```

Benchmark script (used for this and the follow-up launch-overhead
commits; wall-clock around the Python call deliberately captures
CPU-side dispatch cost, which CUDA-event timing excludes -- the kernel
is tiny, so the loop is CPU-bound and measures per-call issue cost):

```python
import os, sys
os.environ.setdefault("HELION_AUTOTUNE_EFFORT", "none")
import time, torch, helion, helion.language as hl

mod_src = ["import torch, helion, helion.language as hl\n"]
for n_args in (2, 8, 16):
    src_args = ", ".join(f"t{i}" for i in range(n_args))
    mod_src.append(f'''
@helion.kernel(config=helion.Config(block_sizes=[128], num_warps=4, num_stages=2))
def k{n_args}({src_args}):
    out = torch.empty_like(t0)
    for tile in hl.tile(t0.size(0)):
        out[tile] = {' + '.join(f't{i}[tile]' for i in range(n_args))}
    return out
''')
open("/tmp/bench_kernels.py", "w").write("\n".join(mod_src))
sys.path.insert(0, "/tmp")
import bench_kernels

def bench(n_args, n=20000):
    k = getattr(bench_kernels, f"k{n_args}")
    ts = [torch.randn(4096, device="cuda") for _ in range(n_args)]
    for _ in range(50):  # warmup: compile + prime launcher
        k(*ts)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(n):
        k(*ts)
    dt = (time.perf_counter() - t0) / n * 1e6
    torch.cuda.synchronize()
    return dt

for n_args in (2, 8, 16):
    print(f"n_args={n_args:3d}: {bench(n_args):6.2f} us/call")
```

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

stack-info: PR: #2746, branch: yushangdi/stack/28
measure() wrapped the entire body of Kernel.bind, which runs on every
kernel call. When HELION_MEASURE_COMPILE_TIME is unset (the default)
measure() returns a shared no-op context manager, but entering and
exiting it still costs ~0.4-0.5us/call: two extra frames plus a module
global read, on the steady-state cache-hit path.

Compile-time tracking only has meaningful data on the cache-*miss*
branch, where bind actually compiles and specializes a new
BoundKernel. The hit path is a specialization-key computation plus a
dict lookup -- nanoseconds of work that the tracker was never meant to
account for. Move the `with measure("Kernel.bind"):` block to wrap
only the miss branch; the hit path returns directly with no context
manager.

No behavior change: the same code runs under the same measurement
scope on the miss path, and HELION_MEASURE_COMPILE_TIME=1 still
attributes all compile/specialize time to "Kernel.bind". Verified by
test_misc, test_config_api, test_cache, test_specialize.

Benchmark (B200, end-to-end wall time per call, add-style kernel,
N=4096 fp32, HELION_AUTOTUNE_EFFORT=none, steady state, 20k iters):

```
  n_args | baseline | prev commit | this commit
  -------+----------+-------------+------------
       2 | 24.14 us |    16.41 us |    16.30 us
       8 | 33.05 us |    24.78 us |    24.39 us
      16 | 43.56 us |    35.24 us |    34.84 us
```

(The per-call win is ~0.4us isolated to bind(); at the end-to-end
scale here it is partly absorbed into run-to-run noise.)

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

stack-info: PR: #2751, branch: yushangdi/stack/32
@yushangdi yushangdi force-pushed the yushangdi/stack/32 branch from 043b1a3 to 055364f Compare June 11, 2026 17:09
yushangdi added a commit that referenced this pull request Jun 11, 2026
measure() wrapped the entire body of Kernel.bind, which runs on every
kernel call. When HELION_MEASURE_COMPILE_TIME is unset (the default)
measure() returns a shared no-op context manager, but entering and
exiting it still costs ~0.4-0.5us/call: two extra frames plus a module
global read, on the steady-state cache-hit path.

Compile-time tracking only has meaningful data on the cache-*miss*
branch, where bind actually compiles and specializes a new
BoundKernel. The hit path is a specialization-key computation plus a
dict lookup -- nanoseconds of work that the tracker was never meant to
account for. Move the `with measure("Kernel.bind"):` block to wrap
only the miss branch; the hit path returns directly with no context
manager.

No behavior change: the same code runs under the same measurement
scope on the miss path, and HELION_MEASURE_COMPILE_TIME=1 still
attributes all compile/specialize time to "Kernel.bind". Verified by
test_misc, test_config_api, test_cache, test_specialize.

Benchmark (B200, end-to-end wall time per call, add-style kernel,
N=4096 fp32, HELION_AUTOTUNE_EFFORT=none, steady state, 20k iters):

```
  n_args | baseline | prev commit | this commit
  -------+----------+-------------+------------
       2 | 24.14 us |    16.41 us |    16.30 us
       8 | 33.05 us |    24.78 us |    24.39 us
      16 | 43.56 us |    35.24 us |    34.84 us
```

(The per-call win is ~0.4us isolated to bind(); at the end-to-end
scale here it is partly absorbed into run-to-run noise.)

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

stack-info: PR: #2751, branch: yushangdi/stack/32
yushangdi added a commit that referenced this pull request Jun 11, 2026
measure() wrapped the entire body of Kernel.bind, which runs on every
kernel call. When HELION_MEASURE_COMPILE_TIME is unset (the default)
measure() returns a shared no-op context manager, but entering and
exiting it still costs ~0.4-0.5us/call: two extra frames plus a module
global read, on the steady-state cache-hit path.

Compile-time tracking only has meaningful data on the cache-*miss*
branch, where bind actually compiles and specializes a new
BoundKernel. The hit path is a specialization-key computation plus a
dict lookup -- nanoseconds of work that the tracker was never meant to
account for. Move the `with measure("Kernel.bind"):` block to wrap
only the miss branch; the hit path returns directly with no context
manager.

No behavior change: the same code runs under the same measurement
scope on the miss path, and HELION_MEASURE_COMPILE_TIME=1 still
attributes all compile/specialize time to "Kernel.bind". Verified by
test_misc, test_config_api, test_cache, test_specialize.

Benchmark (B200, end-to-end wall time per call, add-style kernel,
N=4096 fp32, HELION_AUTOTUNE_EFFORT=none, steady state, 20k iters):

```
  n_args | baseline | prev commit | this commit
  -------+----------+-------------+------------
       2 | 24.14 us |    16.41 us |    16.30 us
       8 | 33.05 us |    24.78 us |    24.39 us
      16 | 43.56 us |    35.24 us |    34.84 us
```

(The per-call win is ~0.4us isolated to bind(); at the end-to-end
scale here it is partly absorbed into run-to-run noise.)

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

stack-info: PR: #2751, branch: yushangdi/stack/32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant