Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions helion/runtime/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,10 @@ def __init__(
Config(**config) if isinstance(config, dict) else config
for config in configs or []
]
self._bound_kernels: dict[BoundKernelInMemoryCacheKey, BoundKernel] = {}
# Keyed by ``(signature, extra_results)`` tuples — the tuple form of
# ``BoundKernelInMemoryCacheKey``, kept plain for cheap per-call
# construction in ``bind``.
self._bound_kernels: dict[Hashable, BoundKernel] = {}
self._specialize_extra: dict[
Hashable, list[Callable[[Sequence[object]], Hashable]]
] = {}
Expand Down Expand Up @@ -226,13 +229,15 @@ def __init__(

def _get_bound_kernel_cache_key(
self, args: tuple[object, ...], signature: tuple[Hashable, ...]
) -> BoundKernelInMemoryCacheKey | None:
from ..autotuner.base_cache import BoundKernelInMemoryCacheKey

) -> tuple[Hashable, ...] | None:
# Plain tuples keep the per-call dispatch path free of dataclass
# construction; `_create_bound_kernel_cache_key` provides the
# `BoundKernelInMemoryCacheKey` form for the autotuner caches.
extra_fns = self._specialize_extra.get(signature)
if extra_fns is not None:
extra_results: tuple[Hashable, ...] = tuple([s(args) for s in extra_fns])
return BoundKernelInMemoryCacheKey(signature, extra_results)
if extra_fns:
return (signature, tuple([s(args) for s in extra_fns]))
return (signature, ())
return None

def _create_bound_kernel_cache_key(
Expand Down Expand Up @@ -278,9 +283,10 @@ def bind(self, args: tuple[object, ...]) -> BoundKernel[_R]:
else:
bound_kernel = BoundKernel(self, args)
if cache_key is None:
cache_key = self._create_bound_kernel_cache_key(
full_key = self._create_bound_kernel_cache_key(
bound_kernel, args, signature
)
cache_key = (full_key.specialization_key, full_key.extra_results)
self._bound_kernels[cache_key] = bound_kernel
return bound_kernel

Expand Down
Loading