Commit 26cd6c6

Cache CUDA device capability lookups on the kernel dispatch hot path
Every Helion kernel call recomputes the bound-kernel cache key, which includes the device's compute capability via target_device_capability(). That function called torch.cuda.is_available() plus torch.cuda.get_device_capability() on every single kernel launch, costing ~2.8us/call of pure Python/CUDA-runtime overhead -- the single largest line item in Helion-side dispatch.

A physical device's compute capability cannot change within a process, so cache it per device index in a module-level dict, and cache the is_available() result once it returns True (an unavailable runtime has no GPU launches to speed up, so False keeps re-querying).

Safety: tests simulate other architectures by patching torch.cuda.get_device_capability / torch.cuda.is_available (e.g. test_config_api.py's sm90-vs-sm100 cache-key test). The fast path engages only while both functions are identity-equal to the originals captured at import time; patched functions route to the original uncached path, so the cache can never be poisoned by, nor serve stale values to, arch-simulation tests.

An index-less torch.device("cuda") refers to the *current* device, which can change between calls -- it is resolved per call via torch.cuda.current_device() (cheap) before hitting the per-index cache.
Benchmark (B200, end-to-end wall time per call, add-style kernel, N=4096 fp32, HELION_AUTOTUNE_EFFORT=none, steady state, 20k iters):

```
n_args | before   | after    | delta
-------+----------+----------+-------
     2 | 24.14 us | 18.78 us | -22%
     8 | 33.05 us | 27.77 us | -16%
    16 | 43.56 us | 36.90 us | -15%
```

Benchmark script (used for this and the follow-up launch-overhead commits; wall-clock around the Python call deliberately captures CPU-side dispatch cost, which CUDA-event timing excludes -- the kernel is tiny, so the loop is CPU-bound and measures per-call issue cost):

```python
import os, sys
os.environ.setdefault("HELION_AUTOTUNE_EFFORT", "none")
import time, torch, helion, helion.language as hl

# Write kernels to a real module file: helion needs
# inspect.getsource, so exec()'d kernels don't work.
mod_src = ["import torch, helion, helion.language as hl\n"]
for n_args in (2, 8, 16):
    src_args = ", ".join(f"t{i}" for i in range(n_args))
    mod_src.append(f'''
@helion.kernel(config=helion.Config(block_sizes=[128], num_warps=4, num_stages=2))
def k{n_args}({src_args}):
    out = torch.empty_like(t0)
    for tile in hl.tile(t0.size(0)):
        out[tile] = {' + '.join(f't{i}[tile]' for i in range(n_args))}
    return out
''')
open("/tmp/bench_kernels.py", "w").write("\n".join(mod_src))
sys.path.insert(0, "/tmp")
import bench_kernels

def bench(n_args, n=20000):
    k = getattr(bench_kernels, f"k{n_args}")
    ts = [torch.randn(4096, device="cuda") for _ in range(n_args)]
    for _ in range(50):  # warmup: compile + prime launcher
        k(*ts)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(n):
        k(*ts)
    dt = (time.perf_counter() - t0) / n * 1e6
    torch.cuda.synchronize()
    return dt

for n_args in (2, 8, 16):
    print(f"n_args={n_args:3d}: {bench(n_args):6.2f} us/call")
```

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
stack-info: PR: #2746, branch: yushangdi/stack/28
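As a quick cross-check of the delta column, the percentages can be recomputed from the before/after timings in the benchmark table above. The snippet is only an arithmetic sketch; `rows` and `summarize` are illustrative names, and the numbers are copied from the table rather than re-measured.

```python
# (n_args, before_us, after_us) copied from the benchmark table.
rows = [(2, 24.14, 18.78), (8, 33.05, 27.77), (16, 43.56, 36.90)]


def summarize(before_us: float, after_us: float) -> tuple[float, float]:
    """Return (absolute saving in us/call, relative change in percent)."""
    saved_us = before_us - after_us
    delta_pct = (after_us - before_us) / before_us * 100.0
    return saved_us, delta_pct


results = {n: summarize(before, after) for n, before, after in rows}
for n, (saved_us, delta_pct) in results.items():
    print(f"n_args={n:2d}: saved {saved_us:.2f} us/call ({delta_pct:+.0f}%)")
```

Rounded to whole percent, the recomputed deltas match the -22%, -16%, and -15% reported in the table.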
1 parent 037efd0 commit 26cd6c6

1 file changed

Lines changed: 52 additions & 0 deletions

helion/_compat.py

@@ -347,12 +347,64 @@ def supports_tensor_descriptor() -> bool:
     return _supports_tensor_descriptor()
 
 
+# Per-device-index cache for ``target_device_capability``. A physical
+# device's compute capability cannot change within a process (torch itself
+# caches ``device_count`` behind ``is_available``), but the query costs
+# ~2.5us and sits on the per-call kernel dispatch path via
+# ``_device_specialization_key``. Entries are only read/written while
+# ``torch.cuda.is_available`` / ``torch.cuda.get_device_capability`` are
+# the original functions — tests patch them to simulate other
+# architectures, and the identity check below routes patched calls to the
+# uncached path.
+#
+# NB: a plain ``@functools.cache`` on this function is wrong on two counts,
+# which is why the cache is hand-rolled:
+#   1. It keys on the ``device`` argument, not on whether the torch
+#      functions are patched, so it cannot fall back to a live query.
+#      Tests that patch ``get_device_capability`` to return (9, 0) then
+#      (10, 0) for the same device (e.g. test_config_api's sm90-vs-sm100
+#      cache-key test) would get the first cached value both times.
+#   2. ``device=None`` means "the current device", which can change via
+#      ``torch.cuda.set_device``. functools.cache would freeze the first
+#      answer under the ``None`` key forever; here ``None`` is resolved to
+#      a concrete index per call before the cache lookup.
+_REAL_CUDA_IS_AVAILABLE: Callable[[], bool] = torch.cuda.is_available
+_REAL_CUDA_GET_CAPABILITY: Callable[..., tuple[int, int]] = (
+    torch.cuda.get_device_capability
+)
+_CUDA_CAPABILITY_CACHE: dict[int, tuple[int, int]] = {}
+_CUDA_AVAILABLE: bool | None = None
+
+
 def target_device_capability(
     device: torch.device | None = None,
 ) -> tuple[int, int] | None:
     """Return CUDA compute capability, or None for non-CUDA/unavailable targets."""
     if device is not None and device.type != "cuda":
         return None
+    if (
+        torch.cuda.is_available is _REAL_CUDA_IS_AVAILABLE
+        and torch.cuda.get_device_capability is _REAL_CUDA_GET_CAPABILITY
+    ):
+        global _CUDA_AVAILABLE
+        if _CUDA_AVAILABLE is not True:
+            # Only cache the True result; an unavailable runtime has no GPU
+            # launches to speed up, so keep re-querying in that case.
+            if not _REAL_CUDA_IS_AVAILABLE():
+                return None
+            _CUDA_AVAILABLE = True
+        index = device.index if device is not None else None
+        if index is None:
+            # An index-less device refers to the *current* device, which can
+            # change between calls — resolve it each time (cheap), then hit
+            # the per-index cache.
+            index = torch.cuda.current_device()
+        capability = _CUDA_CAPABILITY_CACHE.get(index)
+        if capability is None:
+            _CUDA_CAPABILITY_CACHE[index] = capability = _REAL_CUDA_GET_CAPABILITY(
+                index
+            )
+        return capability
     if not torch.cuda.is_available():
         return None
     if device is None:
