Skip to content

Commit 09aea3f

Browse files
committed
Use plain tuple keys for the in-memory bound-kernel cache
Kernel.bind runs on every kernel call and the cache *hit* is the steady state, so the per-call lookup key should be as cheap to build as possible. _get_bound_kernel_cache_key constructed a frozen-dataclass BoundKernelInMemoryCacheKey on every call: a lazy `from ..autotuner.base_cache import ...` import, a dataclass __init__, two frozen-field object.__setattr__ overrides, and a generated __hash__ that re-walks the fields. The in-memory _bound_kernels dict only needs *some* hashable key, so the per-call path now uses the equivalent plain (signature, extra_results) tuple. In isolation this drops _get_bound_kernel_cache_key from ~0.93us to ~0.22us per call. The dataclass form is still produced by _create_bound_kernel_cache_key for the autotuner caches (LocalAutotuneCache / AOTAutotuneCache subclass it into LooseAutotuneCacheKey); only the in-memory dict switches to tuple keys. On the compile (cache-miss) path the dataclass key is built once and unpacked into its (specialization_key, extra_results) tuple form so the in-memory dict and the autotuner caches stay keyed on the same value. Also drops one extra-results tuple allocation when a kernel has no hl.specialize() extras (the common case): the empty extra_fns list short-circuits to a shared () literal instead of tuple([]). Safety: cache-key *contents* are unchanged -- same signature tuple, same extra results, same specialization axes (dtype, shape bucket, device type+capability, ConstExpr values, key= fn, hl.specialize extras); the tuple is just the dataclass's two fields in order, with identical hash/equality. Verified by test_misc, test_config_api, test_cache, test_specialize (132 passed), plus dtype/shape/ConstExpr/specialize rebinding spot-checks. Benchmark (B200, end-to-end wall time per call, add-style kernel, N=4096 fp32, HELION_AUTOTUNE_EFFORT=none, steady state, 20k iters): ``` n_args | baseline | prev commit | this commit -------+----------+-------------+------------ 2 | 24.14 us | 18.78 us | 16.41 us 8 | 33.05 us | 27.77 us | 24.78 us 16 | 43.56 us | 36.90 us | 35.24 us ``` Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> stack-info: PR: #2747, branch: yushangdi/stack/29
1 parent 26cd6c6 commit 09aea3f

1 file changed

Lines changed: 13 additions & 7 deletions

File tree

helion/runtime/kernel.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,10 @@ def __init__(
185185
Config(**config) if isinstance(config, dict) else config
186186
for config in configs or []
187187
]
188-
self._bound_kernels: dict[BoundKernelInMemoryCacheKey, BoundKernel] = {}
188+
# Keyed by ``(signature, extra_results)`` tuples — the tuple form of
189+
# ``BoundKernelInMemoryCacheKey``, kept plain for cheap per-call
190+
# construction in ``bind``.
191+
self._bound_kernels: dict[Hashable, BoundKernel] = {}
189192
self._specialize_extra: dict[
190193
Hashable, list[Callable[[Sequence[object]], Hashable]]
191194
] = {}
@@ -226,13 +229,15 @@ def __init__(
226229

227230
def _get_bound_kernel_cache_key(
228231
self, args: tuple[object, ...], signature: tuple[Hashable, ...]
229-
) -> BoundKernelInMemoryCacheKey | None:
230-
from ..autotuner.base_cache import BoundKernelInMemoryCacheKey
231-
232+
) -> tuple[Hashable, ...] | None:
233+
# Plain tuples keep the per-call dispatch path free of dataclass
234+
# construction; `_create_bound_kernel_cache_key` provides the
235+
# `BoundKernelInMemoryCacheKey` form for the autotuner caches.
232236
extra_fns = self._specialize_extra.get(signature)
233237
if extra_fns is not None:
234-
extra_results: tuple[Hashable, ...] = tuple([s(args) for s in extra_fns])
235-
return BoundKernelInMemoryCacheKey(signature, extra_results)
238+
if extra_fns:
239+
return (signature, tuple([s(args) for s in extra_fns]))
240+
return (signature, ())
236241
return None
237242

238243
def _create_bound_kernel_cache_key(
@@ -278,9 +283,10 @@ def bind(self, args: tuple[object, ...]) -> BoundKernel[_R]:
278283
else:
279284
bound_kernel = BoundKernel(self, args)
280285
if cache_key is None:
281-
cache_key = self._create_bound_kernel_cache_key(
286+
full_key = self._create_bound_kernel_cache_key(
282287
bound_kernel, args, signature
283288
)
289+
cache_key = (full_key.specialization_key, full_key.extra_results)
284290
self._bound_kernels[cache_key] = bound_kernel
285291
return bound_kernel
286292

0 commit comments

Comments
 (0)