Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 56 additions & 29 deletions helion/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,42 +347,69 @@ def supports_tensor_descriptor() -> bool:
return _supports_tensor_descriptor()


# Per-device-index cache for ``target_device_capability``. A physical
# device's compute capability cannot change within a process (torch itself
# caches ``device_count`` behind ``is_available``), but the query costs
# ~2.5us and sits on the per-call kernel dispatch path via
# ``_device_specialization_key``. Entries are only read/written while
# ``torch.cuda.is_available`` / ``torch.cuda.get_device_capability`` are
# the original functions — tests patch them to simulate other
# architectures, and the identity check below routes patched calls to the
# uncached path.
#
# NB: a plain ``@functools.cache`` on this function is wrong on two counts,
# which is why the cache is hand-rolled:
# 1. It keys on the ``device`` argument, not on whether the torch
# functions are patched, so it cannot fall back to a live query.
# Tests that patch ``get_device_capability`` to return (9, 0) then
# (10, 0) for the same device (e.g. test_config_api's sm90-vs-sm100
# cache-key test) would get the first cached value both times.
# 2. ``device=None`` means "the current device", which can change via
# ``torch.cuda.set_device``. functools.cache would freeze the first
# answer under the ``None`` key forever; here ``None`` is resolved to
# a concrete index per call before the cache lookup.
_REAL_CUDA_IS_AVAILABLE: Callable[[], bool] = torch.cuda.is_available
_REAL_CUDA_GET_CAPABILITY: Callable[..., tuple[int, int]] = (
torch.cuda.get_device_capability
)
_CUDA_CAPABILITY_CACHE: dict[int, tuple[int, int]] = {}
_CUDA_AVAILABLE: bool | None = None


def target_device_capability(
device: torch.device | None = None,
) -> tuple[int, int] | None:
"""Return CUDA compute capability, or None for non-CUDA/unavailable targets.

This sits on the per-call kernel dispatch path via
``_device_specialization_key`` and its torch.cuda queries cost ~2.7us,
so the result is memoized per device index in
``_target_device_capability``. Like ``is_hip`` / ``_is_hip``, the cache
lives in the private helper and this public wrapper is the seam tests
patch to simulate other architectures.

A concrete device index goes straight to the cached helper, so the
steady-state hot path makes no torch.cuda calls at all. ``device=None``
means the *current* device, which moves with ``torch.cuda.set_device``;
it is resolved to a concrete index per call (after an availability
check) so it is never frozen under a single cache key.
"""
"""Return CUDA compute capability, or None for non-CUDA/unavailable targets."""
if device is not None and device.type != "cuda":
return None
if device is not None and device.index is not None:
return _target_device_capability(device.index)
if not torch.cuda.is_available():
return None
return _target_device_capability(torch.cuda.current_device())


@functools.cache
def _target_device_capability(index: int) -> tuple[int, int] | None:
# A physical device's compute capability cannot change within a process,
# so memoize per device index; the availability check is inside the
# cache (like ``_is_hip``) so the hot path skips it once warm. Patched
# in tests via the public ``target_device_capability`` wrapper above.
if (
torch.cuda.is_available is _REAL_CUDA_IS_AVAILABLE
and torch.cuda.get_device_capability is _REAL_CUDA_GET_CAPABILITY
):
global _CUDA_AVAILABLE
if _CUDA_AVAILABLE is not True:
# Only cache the True result; an unavailable runtime has no GPU
# launches to speed up, so keep re-querying in that case.
if not _REAL_CUDA_IS_AVAILABLE():
return None
_CUDA_AVAILABLE = True
index = device.index if device is not None else None
if index is None:
# An index-less device refers to the *current* device, which can
# change between calls — resolve it each time (cheap), then hit
# the per-index cache.
index = torch.cuda.current_device()
capability = _CUDA_CAPABILITY_CACHE.get(index)
if capability is None:
_CUDA_CAPABILITY_CACHE[index] = capability = _REAL_CUDA_GET_CAPABILITY(
index
)
return capability
if not torch.cuda.is_available():
return None
return torch.cuda.get_device_capability(index)
if device is None:
return torch.cuda.get_device_capability(torch.cuda.current_device())
return torch.cuda.get_device_capability(device)


def min_dot_size(
Expand Down
11 changes: 11 additions & 0 deletions helion/_compile_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,17 @@ def measure(name: str) -> contextlib.AbstractContextManager[None]:
return _MeasureContext(name, get_tracker())


def is_enabled() -> bool:
"""Whether compile-time measurement is active.

Lets hot paths skip even entering a (disabled) ``measure()`` context
manager — the ``with`` protocol alone costs ~115ns/call, which is
significant on the per-call kernel dispatch path. ``enable()`` flips
this at runtime, so read it fresh rather than caching.
"""
return _enabled


def enable() -> None:
"""Enable compile-time measurement after this module has been imported."""
global _enabled
Expand Down
54 changes: 39 additions & 15 deletions helion/_compiler/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1316,7 +1316,17 @@ def full_expr(
f"tl.full([{', '.join(shape_dims)}], {value_expr}, {self.dtype_str(dtype)})"
)

def launcher_keyword_args(self, config: Config, *, has_barrier: bool) -> list[str]:
def launcher_runtime_kwargs(
self, config: Config, *, has_barrier: bool
) -> dict[str, object]:
"""Return the launcher kwargs as a ``dict`` of runtime values.

Used by both :meth:`launcher_keyword_args` (which formats the dict
as source strings for codegen) and
:func:`helion.runtime.build_fast_launcher` (which bakes the dict
into a closure at :meth:`BoundKernel.set_config` time so the hot
launch path doesn't have to allocate a per-call kwargs dict).
"""
from .._compat import supports_maxnreg

# Workaround for triton bug: warp_specialize requires at least 4 warps
Expand All @@ -1325,31 +1335,45 @@ def launcher_keyword_args(self, config: Config, *, has_barrier: bool) -> list[st
if any(config.range_warp_specializes):
num_warps = max(4, num_warps)

args = [
f"num_warps={num_warps}",
f"num_stages={config.num_stages}",
*(["launch_cooperative_grid=True"] if has_barrier else []),
] + [
f"{x.removeprefix('_triton_config_')}={config[x]}"
for x in config
if x.startswith("_triton_config_")
]
kwargs: dict[str, object] = {
"num_warps": num_warps,
"num_stages": config.num_stages,
}
if has_barrier:
kwargs["launch_cooperative_grid"] = True
for x in config:
if x.startswith("_triton_config_"):
kwargs[x.removeprefix("_triton_config_")] = config[x]

from ..autotuner.config_spec import _get_backend_tunable_keys

for key in _get_backend_tunable_keys():
if key in config:
args.append(f"{key}={config[key]!r}")
kwargs[key] = config[key]

if "maxnreg" in config and config["maxnreg"] is not None and supports_maxnreg():
args.append(f"maxnreg={config['maxnreg']}")
kwargs["maxnreg"] = config["maxnreg"]

advanced_controls_file = config.advanced_controls_file
if advanced_controls_file:
ptx_option = f"--apply-controls {advanced_controls_file}"
args.append(f"ptx_options={ptx_option!r}")
kwargs["ptx_options"] = f"--apply-controls {advanced_controls_file}"

return kwargs

return args
def launcher_keyword_args(self, config: Config, *, has_barrier: bool) -> list[str]:
from ..autotuner.config_spec import _get_backend_tunable_keys

backend_tunable_keys = _get_backend_tunable_keys()
kwargs = self.launcher_runtime_kwargs(config, has_barrier=has_barrier)
# Backend tunable keys (typically strings) and ptx_options use repr;
# everything else (num_warps, num_stages, ints, bools) renders plain.
out: list[str] = []
for k, v in kwargs.items():
if k in backend_tunable_keys or k == "ptx_options":
out.append(f"{k}={v!r}")
else:
out.append(f"{k}={v}")
return out

def grid_barrier_stmt(self, sem_arg: str) -> str:
return f"triton_helpers.x_grid_barrier({sem_arg})"
Expand Down
37 changes: 3 additions & 34 deletions helion/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
from .._compiler.cute.strategies import tcgen05_explicit_d_store_tile_expr
from .._compiler.cute.strategies import tcgen05_smem_layout_expr
from .._utils import triton_is_available
from ._fast_launcher import _FastLauncher as _FastLauncher
from ._fast_launcher import build_fast_launcher as build_fast_launcher
from ._fast_launcher import default_launcher as default_launcher
from .config import Config as Config
from .kernel import Kernel as Kernel
from .kernel import kernel as kernel
Expand Down Expand Up @@ -159,40 +162,6 @@ def get_num_sm(device: torch.device, *, reserved_sms: int = 0) -> int:
return max(available_sms - reserved_sms, 1)


def default_launcher(
triton_kernel: object,
grid: tuple[int, ...],
*args: object,
num_warps: int,
num_stages: int,
ptx_options: str | None = None,
launch_cooperative_grid: bool = False,
**kwargs: dict,
) -> object:
"""Default launcher function that executes the kernel immediately."""
# For both CUDA and MTIA, use the same kernel execution
run_kwargs: dict = {
"grid": grid,
"warmup": False,
"num_warps": num_warps,
"num_stages": num_stages,
"launch_cooperative_grid": launch_cooperative_grid,
**kwargs,
}
if ptx_options is not None:
run_kwargs["ptx_options"] = ptx_options
try:
return triton_kernel.run( # type: ignore[union-attr]
*args,
**run_kwargs,
)
except Exception as error:
message = str(error)
if "Cannot make_shape_compatible: incompatible dimensions" in message:
raise exc.ShapeMismatch("kernel operands", message) from error
raise


def _pallas_make_block_spec(
pl: object,
jnp: object,
Expand Down
Loading
Loading