Add a SymInt-free tensor specialization key for exact torch.Tensor args #2748
Merged
This was referenced Jun 10, 2026
yushangdi added a commit that referenced this pull request Jun 11, 2026
_tensor_key wraps every size and stride element in _hashable_dim to normalize torch.SymInt into a hashable (shape_env_id, expr) pair. That wrap exists only for symbolic shapes, which appear on FakeTensors during tracing -- yet concrete tensors paid for it on every kernel call: two Python-level loops (sizes + strides) with an isinstance check per dimension, plus rebuilding each as a fresh tuple. For the default static_shapes=True path this was the bulk of per-tensor key extraction (~0.6us each, x number of tensor args, every call).

Specialization-extractor dispatch is by exact type (_specialization_extractors.get(type(obj))), so the torch.Tensor and torch.nn.Parameter entries are only ever hit by objects whose sizes/strides are guaranteed concrete ints. Route those to a new _concrete_tensor_key that uses obj.size() (a torch.Size) and obj.stride() directly as key components. Both are tuple subclasses: they hash and compare identically to plain int tuples, so keys produced by the fast path and the SymInt path for the same concrete shape are interchangeable -- no cache invalidation, and the specialization axes (dtype, shape, stride, _dynamo_static_indices, int32/int64 index width, dynamic-shape buckets) are bit-identical.

Anything that can carry SymInts keeps the safe path:

* FakeTensor has its own dispatch entry -> _tensor_key (unchanged)
* torch.Tensor *subclasses* (e.g. the JAX-export adapter) miss the exact-type dict and hit the isinstance fallback in _specialization_key, which now routes to _tensor_key explicitly

(This is the same optimization as the Python-only part of upstream PR #2611, independently arrived at from profiling.)

New test/test_tensor_key_fast_path.py pins down the key equivalence (fast vs wrapped key hash/compare equal under static and dynamic shapes, incl. strided tensors), the dispatch routing (concrete -> _concrete_tensor_key; FakeTensor and the subclass fallback -> _tensor_key), that a torch.Tensor subclass takes the SymInt-safe path and still shares a BoundKernel with a plain tensor of the same shape, and that bind() caching/distinguishing is unchanged. Also verified against test_misc, test_specialize, test_torch_compile (the torch.compile suite exercises the FakeTensor/SymInt path).

Benchmark (B200, end-to-end wall time per call, add-style kernel, N=4096 fp32, HELION_AUTOTUNE_EFFORT=none, steady state, 20k iters):

```
n_args | baseline | prev commit | this commit
-------+----------+-------------+------------
     2 | 24.14 us |    18.01 us |    17.25 us
     8 | 33.05 us |    25.57 us |    24.50 us
    16 | 43.56 us |    36.41 us |    32.16 us
```

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
stack-info: PR: #2748, branch: yushangdi/stack/30
Stacked PRs:
Add a SymInt-free tensor specialization key for exact torch.Tensor args
_tensor_key wraps every size and stride element in _hashable_dim to
normalize torch.SymInt into a hashable (shape_env_id, expr) pair. That
wrap exists only for symbolic shapes, which appear on FakeTensors
during tracing -- yet concrete tensors paid for it on every kernel
call: two Python-level loops (sizes + strides) with an isinstance
check per dimension, plus rebuilding each as a fresh tuple. For the
default static_shapes=True path this was the bulk of per-tensor key
extraction (~0.6 us per tensor, multiplied by the number of tensor
args, on every call).
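
To make the cost described above concrete, here is a minimal pure-Python sketch of per-dimension normalization. It is not the repository's code: `StandInSymInt`, `hashable_dim`, and `wrapped_tensor_key` are hypothetical names that only mirror the behaviour the description attributes to `_hashable_dim` and `_tensor_key`, and the dtype is a plain string so the snippet runs without torch.

```python
class StandInSymInt:
    """Hypothetical stand-in for torch.SymInt (illustration only)."""

    def __init__(self, shape_env_id, expr):
        self.shape_env_id = shape_env_id
        self.expr = expr


def hashable_dim(dim):
    # Symbolic dims are normalized to a hashable (shape_env_id, expr) pair;
    # concrete ints pass through unchanged.
    if isinstance(dim, StandInSymInt):
        return (dim.shape_env_id, dim.expr)
    return dim


def wrapped_tensor_key(dtype, sizes, strides):
    # Two Python-level loops, one isinstance check per dimension,
    # and two freshly built tuples on every call.
    return (
        dtype,
        tuple(hashable_dim(s) for s in sizes),
        tuple(hashable_dim(s) for s in strides),
    )


# Concrete shapes gain nothing from the wrap: the key is just the inputs again.
concrete = wrapped_tensor_key("float32", (4096, 64), (64, 1))

# Only a symbolic dimension is actually changed by the normalization.
symbolic = wrapped_tensor_key("float32", (StandInSymInt(0, "s0"), 64), (64, 1))
```

For concrete inputs the loops rebuild exactly the tuples they were given, which is the work the fast path is meant to skip.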
Specialization-extractor dispatch is by exact type
(_specialization_extractors.get(type(obj))), so the torch.Tensor and
torch.nn.Parameter entries are only ever hit by objects whose
sizes/strides are guaranteed concrete ints. Route those to a new
_concrete_tensor_key that uses obj.size() (a torch.Size) and
obj.stride() directly as key components. Both are tuple subclasses:
they hash and compare identically to plain int tuples, so keys
produced by the fast path and the SymInt path for the same concrete
shape are interchangeable -- no cache invalidation, and the
specialization axes (dtype, shape, stride, _dynamo_static_indices,
int32/int64 index width, dynamic-shape buckets) are bit-identical.
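
The interchangeability claim above rests on ordinary tuple-subclass semantics, which can be checked without torch. In this sketch `StandInSize` is a hypothetical stand-in for `torch.Size` (assumed here only to be a tuple subclass that does not override hashing or equality), and the cache is a plain dict rather than the real kernel cache.

```python
class StandInSize(tuple):
    """Hypothetical stand-in for torch.Size: a tuple subclass with inherited hash/eq."""


# Key as the wrapped path would build it: plain tuples of ints.
plain_key = ("float32", (4096, 64), (64, 1))

# Key as the fast path would build it: the size object is used directly.
fast_key = ("float32", StandInSize((4096, 64)), (64, 1))

# A cache populated under one key form is hit by the other form.
cache = {plain_key: "bound kernel"}
hit = cache.get(fast_key)
```

Because the subclass inherits `__hash__` and `__eq__` from `tuple`, a lookup with either key form finds the same entry, so no existing cache entries need to be invalidated.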
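Anything that can carry SymInts keeps the safe path:
* FakeTensor has its own dispatch entry -> _tensor_key (unchanged)
* torch.Tensor *subclasses* (e.g. the JAX-export adapter) miss the
  exact-type dict and hit the isinstance fallback in
  _specialization_key, which now routes to _tensor_key explicitly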
(This is the same optimization as the Python-only part of upstream
PR #2611, independently arrived at from profiling.)
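
The routing described above can be sketched as an exact-type dict lookup followed by an isinstance fallback. All classes and functions here are hypothetical stand-ins (no torch import); only the shape of the dispatch follows the description.

```python
class Tensor:
    """Stand-in for torch.Tensor."""


class Parameter(Tensor):
    """Stand-in for torch.nn.Parameter."""


class FakeTensor(Tensor):
    """Stand-in for FakeTensor, which may carry symbolic sizes."""


class AdapterTensor(Tensor):
    """Hypothetical Tensor subclass with no dispatch entry of its own."""


def concrete_key(obj):
    return ("concrete", type(obj).__name__)


def symint_safe_key(obj):
    return ("symint_safe", type(obj).__name__)


# Exact-type dispatch: only objects whose type is exactly a dict key match.
EXTRACTORS = {
    Tensor: concrete_key,
    Parameter: concrete_key,
    FakeTensor: symint_safe_key,
}


def specialization_key(obj):
    extractor = EXTRACTORS.get(type(obj))
    if extractor is not None:
        return extractor(obj)
    # Unregistered subclasses miss the dict and fall back to the safe path.
    if isinstance(obj, Tensor):
        return symint_safe_key(obj)
    raise TypeError(f"unsupported argument type: {type(obj).__name__}")


routes = {
    cls.__name__: specialization_key(cls())[0]
    for cls in (Tensor, Parameter, FakeTensor, AdapterTensor)
}
```

Looking up `type(obj)` rather than walking the MRO is what lets the fast path assume concrete ints: a subclass never reaches the `Tensor` entry by accident.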
New test/test_tensor_key_fast_path.py pins down the key equivalence
(fast vs wrapped key hash/compare equal under static and dynamic
shapes, incl. strided tensors), the dispatch routing (concrete ->
_concrete_tensor_key; FakeTensor and the subclass fallback ->
_tensor_key), that a torch.Tensor subclass takes the SymInt-safe path
and still shares a BoundKernel with a plain tensor of the same shape,
and that bind() caching/distinguishing is unchanged. Also verified
against test_misc, test_specialize, test_torch_compile (the
torch.compile suite exercises the FakeTensor/SymInt path).
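Benchmark (B200, end-to-end wall time per call, add-style kernel,
N=4096 fp32, HELION_AUTOTUNE_EFFORT=none, steady state, 20k iters):

| n_args | baseline | prev commit | this commit |
|-------:|---------:|------------:|------------:|
| 2 | 24.14 us | 18.01 us | 17.25 us |
| 8 | 33.05 us | 25.57 us | 24.50 us |
| 16 | 43.56 us | 36.41 us | 32.16 us |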
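
For readers who want to reproduce a per-call figure of this kind, the following is a minimal sketch of a steady-state timing loop. It is not the harness used for the numbers in this PR: `mean_us_per_call` and `build_key` are hypothetical, the timed function is a trivial CPU-only stand-in, and a real GPU measurement would also need device synchronization around the timed region because kernel launches are asynchronous.

```python
import time


def mean_us_per_call(fn, iters=20_000, warmup=1_000):
    # Warm up first so one-time costs (caches, lazy initialization)
    # do not leak into the steady-state measurement.
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    elapsed = time.perf_counter() - start
    # Convert total seconds to mean microseconds per call.
    return elapsed / iters * 1e6


def build_key():
    # Trivial stand-in workload: build a small specialization-style key.
    return ("float32", (4096,), (1,))


us = mean_us_per_call(build_key, iters=2_000, warmup=100)
```

The loop overhead itself is included in the result, so small differences are best compared between runs of the same harness rather than read as absolute costs.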
Co-Authored-By: Claude Opus 4.8 (1M context) noreply@anthropic.com