Commit 26cd6c6
Cache CUDA device capability lookups on the kernel dispatch hot path
Every Helion kernel call recomputes the bound-kernel cache key, which
includes the device's compute capability via target_device_capability().
That function called torch.cuda.is_available() plus
torch.cuda.get_device_capability() on every single kernel launch,
costing ~2.8us/call of pure Python/CUDA-runtime overhead -- the single
largest line item in Helion-side dispatch.
A physical device's compute capability cannot change within a process,
so cache it per device index in a module-level dict, and cache the
is_available() result once it returns True (an unavailable runtime has
no GPU launches to speed up, so False keeps re-querying).
Safety: tests simulate other architectures by patching
torch.cuda.get_device_capability / torch.cuda.is_available (e.g.
test_config_api.py's sm90-vs-sm100 cache-key test). The fast path
engages only while both functions are identity-equal to the originals
captured at import time; patched functions route to the original
uncached path, so the cache can never be poisoned by, nor serve stale
values to, arch-simulation tests.
An index-less torch.device("cuda") refers to the *current* device,
which can change between calls -- it is resolved per call via
torch.cuda.current_device() (cheap) before hitting the per-index cache.
Benchmark (B200, end-to-end wall time per call, add-style kernel,
N=4096 fp32, HELION_AUTOTUNE_EFFORT=none, steady state, 20k iters):
```
n_args | before   | after    | delta
-------+----------+----------+-------
     2 | 24.14 us | 18.78 us | -22%
     8 | 33.05 us | 27.77 us | -16%
    16 | 43.56 us | 36.90 us | -15%
```
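For reference, the delta column can be re-derived from the before/after timings. The snippet below only redoes the arithmetic on the numbers in the table; the variable names are illustrative.

```python
# (n_args, before_us, after_us) copied from the table above.
rows = [(2, 24.14, 18.78), (8, 33.05, 27.77), (16, 43.56, 36.90)]

saved_us = {}
delta_pct = {}
for n_args, before, after in rows:
    saved_us[n_args] = round(before - after, 2)                    # absolute saving
    delta_pct[n_args] = round((after - before) / before * 100)     # relative change

for n_args in saved_us:
    print(f"n_args={n_args:2d}: saved {saved_us[n_args]:.2f} us ({delta_pct[n_args]:+d}%)")
```

The absolute saving stays in the same few-microsecond range as n_args grows, while the percentage shrinks because the per-call baseline rises with the argument count.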
Benchmark script (used for this and the follow-up launch-overhead
commits; wall-clock around the Python call deliberately captures
CPU-side dispatch cost, which CUDA-event timing excludes -- the kernel
is tiny, so the loop is CPU-bound and measures per-call issue cost):
```python
import os, sys
os.environ.setdefault("HELION_AUTOTUNE_EFFORT", "none")
import time, torch, helion, helion.language as hl

# Write kernels to a real module file: helion needs
# inspect.getsource, so exec()'d kernels don't work.
mod_src = ["import torch, helion, helion.language as hl\n"]
for n_args in (2, 8, 16):
    src_args = ", ".join(f"t{i}" for i in range(n_args))
    mod_src.append(f'''
@helion.kernel(config=helion.Config(block_sizes=[128], num_warps=4, num_stages=2))
def k{n_args}({src_args}):
    out = torch.empty_like(t0)
    for tile in hl.tile(t0.size(0)):
        out[tile] = {' + '.join(f't{i}[tile]' for i in range(n_args))}
    return out
''')
open("/tmp/bench_kernels.py", "w").write("\n".join(mod_src))
sys.path.insert(0, "/tmp")
import bench_kernels

def bench(n_args, n=20000):
    k = getattr(bench_kernels, f"k{n_args}")
    ts = [torch.randn(4096, device="cuda") for _ in range(n_args)]
    for _ in range(50):  # warmup: compile + prime launcher
        k(*ts)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(n):
        k(*ts)
    dt = (time.perf_counter() - t0) / n * 1e6
    torch.cuda.synchronize()
    return dt

for n_args in (2, 8, 16):
    print(f"n_args={n_args:3d}: {bench(n_args):6.2f} us/call")
```
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
stack-info: PR: #2746, branch: yushangdi/stack/281
parent 037efd0, commit 26cd6c6
1 file changed
Lines changed: 52 additions & 0 deletions
Diff body not captured; the additions are new-file lines 350-378 and 385-407.