Add a SymInt-free tensor specialization key for exact torch.Tensor args#2748

Merged Jun 11, 2026: yushangdi merged 0 commits into yushangdi/stack/33 from yushangdi/stack/30

@yushangdi (Contributor) commented Jun 10, 2026

Stacked PRs:


Add a SymInt-free tensor specialization key for exact torch.Tensor args

_tensor_key wraps every size and stride element in _hashable_dim to
normalize torch.SymInt into a hashable (shape_env_id, expr) pair. That
wrap exists only for symbolic shapes, which appear on FakeTensors
during tracing -- yet concrete tensors paid for it on every kernel
call: two Python-level loops (sizes + strides) with an isinstance
check per dimension, plus rebuilding each as a fresh tuple. For the
default static_shapes=True path this was the bulk of per-tensor key
extraction (~0.6 us per tensor, times the number of tensor args, on every call).
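
The per-dimension cost described above can be sketched in plain Python. This is an illustration only: `StandInSymInt`, `hashable_dim`, and `wrapped_key` are hypothetical stand-ins, not the real `torch.SymInt`, `_hashable_dim`, or `_tensor_key`, and the `(shape_env_id, expr)` pair is modelled with two simple attributes.

```python
from dataclasses import dataclass


@dataclass(eq=False)
class StandInSymInt:
    """Hypothetical stand-in for a symbolic dimension (not torch.SymInt)."""

    shape_env_id: int
    expr: str


def hashable_dim(dim):
    # Symbolic dims are normalized to a hashable (shape_env_id, expr) pair.
    # Concrete ints pass through unchanged, but still pay for the check.
    if isinstance(dim, StandInSymInt):
        return (dim.shape_env_id, dim.expr)
    return dim


def wrapped_key(sizes, strides):
    # Two Python-level loops, each rebuilding a fresh tuple.
    return (
        tuple(hashable_dim(s) for s in sizes),
        tuple(hashable_dim(s) for s in strides),
    )


concrete = wrapped_key((4096, 64), (64, 1))
symbolic = wrapped_key((StandInSymInt(0, "s0"), 64), (64, 1))
```

For the all-int input the wrapped result is just a rebuilt copy of the tuples that went in, which is the work the fast path is meant to avoid.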

Specialization-extractor dispatch is by exact type
(_specialization_extractors.get(type(obj))), so the torch.Tensor and
torch.nn.Parameter entries are only ever hit by objects whose
sizes/strides are guaranteed concrete ints. Route those to a new
_concrete_tensor_key that uses obj.size() (a torch.Size) and
obj.stride() directly as key components. Both are tuple subclasses:
they hash and compare identically to plain int tuples, so keys
produced by the fast path and the SymInt path for the same concrete
shape are interchangeable -- no cache invalidation, and the
specialization axes (dtype, shape, stride, _dynamo_static_indices,
int32/int64 index width, dynamic-shape buckets) are bit-identical.
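
The interchangeability argument rests on ordinary Python behaviour: a tuple subclass that does not override `__eq__` or `__hash__` compares and hashes like a plain tuple. A minimal sketch, assuming a hypothetical `StandInSize` class in place of `torch.Size` and simplified two-component keys (the real keys carry more fields):

```python
class StandInSize(tuple):
    """Hypothetical stand-in for torch.Size: a tuple subclass with no overrides."""


def concrete_key(size, stride):
    # Fast path: use the tuple-like objects directly, no per-element loop.
    return (size, stride)


def wrapped_key(size, stride):
    # Slow path reduced to its effect on concrete ints: plain rebuilt tuples.
    return (tuple(size), tuple(stride))


fast = concrete_key(StandInSize((4096, 64)), (64, 1))
slow = wrapped_key(StandInSize((4096, 64)), (64, 1))

cache = {slow: "bound-kernel"}  # entry stored under a slow-path key
hit = cache.get(fast)           # looked up with a fast-path key
```

Because `fast` and `slow` are equal and hash the same, an entry stored under one is found by the other, which is why no cache invalidation is needed.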

Anything that can carry SymInts keeps the safe path:

  • FakeTensor has its own dispatch entry -> _tensor_key (unchanged)
  • torch.Tensor subclasses (e.g. the JAX-export adapter) miss the
    exact-type dict and hit the isinstance fallback in
    _specialization_key, which now routes to _tensor_key explicitly
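
The routing rules above can be sketched with an exact-type dictionary plus an `isinstance` fallback. All class and function names here are hypothetical stand-ins chosen to mirror the description; the real extractors return full keys rather than labels.

```python
class Tensor: ...
class Parameter(Tensor): ...
class FakeTensor(Tensor): ...
class AdapterTensor(Tensor): ...  # a subclass with no dispatch entry of its own


def concrete_tensor_key(obj):
    return "fast"  # stands in for the SymInt-free key


def tensor_key(obj):
    return "symint-safe"  # stands in for the wrapped key


EXTRACTORS = {
    Tensor: concrete_tensor_key,
    Parameter: concrete_tensor_key,
    FakeTensor: tensor_key,
}


def specialization_key(obj):
    # Exact-type lookup: subclasses do not match their parent's entry.
    extractor = EXTRACTORS.get(type(obj))
    if extractor is None and isinstance(obj, Tensor):
        # Fallback for unregistered subclasses keeps the SymInt-safe path.
        extractor = tensor_key
    if extractor is None:
        raise TypeError(f"no extractor for {type(obj).__name__}")
    return extractor(obj)


routes = {
    cls.__name__: specialization_key(cls())
    for cls in (Tensor, Parameter, FakeTensor, AdapterTensor)
}
```

Only objects whose type is exactly `Tensor` or `Parameter` reach the fast extractor; `FakeTensor` and the unregistered subclass both end up on the safe one.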

(This is the same optimization as the Python-only part of upstream
PR #2611, independently arrived at from profiling.)

New test/test_tensor_key_fast_path.py pins down the key equivalence
(fast vs wrapped key hash/compare equal under static and dynamic
shapes, incl. strided tensors), the dispatch routing (concrete ->
_concrete_tensor_key; FakeTensor and the subclass fallback ->
_tensor_key), that a torch.Tensor subclass takes the SymInt-safe path
and still shares a BoundKernel with a plain tensor of the same shape,
and that bind() caching/distinguishing is unchanged. Also verified
against test_misc, test_specialize, test_torch_compile (the
torch.compile suite exercises the FakeTensor/SymInt path).

Benchmark (B200, end-to-end wall time per call, add-style kernel,
N=4096 fp32, HELION_AUTOTUNE_EFFORT=none, steady state, 20k iters):

  n_args | baseline | prev commit | this commit
  -------+----------+-------------+------------
       2 | 24.14 us |    18.01 us |    17.25 us
       8 | 33.05 us |    25.57 us |    24.50 us
      16 | 43.56 us |    36.41 us |    32.16 us
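
To read the table as savings rather than absolute times, the arithmetic can be done with a few lines of Python. The numbers are copied from the table; the helper name `summarize` and the derived column names are made up for this sketch.

```python
# (n_args, baseline_us, prev_commit_us, this_commit_us), copied from the table
ROWS = [
    (2, 24.14, 18.01, 17.25),
    (8, 33.05, 25.57, 24.50),
    (16, 43.56, 36.41, 32.16),
]


def summarize(rows):
    out = []
    for n_args, baseline, prev, this in rows:
        saved = prev - this  # time removed by this commit alone
        out.append(
            {
                "n_args": n_args,
                "saved_us": round(saved, 2),
                "saved_per_arg_us": round(saved / n_args, 3),
                "vs_prev_pct": round(100 * saved / prev, 1),
                "vs_baseline_pct": round(100 * (baseline - this) / baseline, 1),
            }
        )
    return out


summary = summarize(ROWS)
for row in summary:
    print(row)
```

These are end-to-end wall times, so the per-argument figures include measurement noise and should not be read as a direct measurement of key-extraction cost.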

Co-Authored-By: Claude Opus 4.8 (1M context) noreply@anthropic.com

@yushangdi yushangdi force-pushed the yushangdi/stack/30 branch from 79b4c3b to c9e5d81 Compare June 10, 2026 23:39
@yushangdi yushangdi force-pushed the yushangdi/stack/29 branch from f72bff3 to 7703235 Compare June 10, 2026 23:39
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 10, 2026
@yushangdi yushangdi changed the base branch from yushangdi/stack/29 to main June 10, 2026 23:43
@yushangdi yushangdi force-pushed the yushangdi/stack/30 branch from c9e5d81 to 924e16c Compare June 10, 2026 23:43
@yushangdi yushangdi changed the base branch from main to yushangdi/stack/29 June 10, 2026 23:43
@yushangdi yushangdi changed the base branch from yushangdi/stack/29 to main June 11, 2026 00:37
@yushangdi yushangdi force-pushed the yushangdi/stack/30 branch 2 times, most recently from 9eef756 to 6d22101 Compare June 11, 2026 00:37
@yushangdi yushangdi changed the base branch from main to yushangdi/stack/32 June 11, 2026 00:37
@yushangdi yushangdi changed the base branch from yushangdi/stack/32 to main June 11, 2026 00:45
@yushangdi yushangdi force-pushed the yushangdi/stack/30 branch 2 times, most recently from bf014f2 to 9cd7217 Compare June 11, 2026 00:45
@yushangdi yushangdi changed the base branch from main to yushangdi/stack/33 June 11, 2026 00:45
@yushangdi yushangdi changed the base branch from yushangdi/stack/33 to main June 11, 2026 01:15
@yushangdi yushangdi force-pushed the yushangdi/stack/30 branch from 9cd7217 to ff520f2 Compare June 11, 2026 01:16
@yushangdi yushangdi changed the base branch from main to yushangdi/stack/33 June 11, 2026 01:16
@yushangdi yushangdi changed the base branch from yushangdi/stack/33 to main June 11, 2026 01:55
@yushangdi yushangdi force-pushed the yushangdi/stack/30 branch from ff520f2 to 78060ed Compare June 11, 2026 01:55
@yushangdi yushangdi changed the base branch from main to yushangdi/stack/33 June 11, 2026 01:55
@yushangdi yushangdi changed the base branch from yushangdi/stack/33 to main June 11, 2026 02:02
@yushangdi yushangdi force-pushed the yushangdi/stack/30 branch from 78060ed to f9c2756 Compare June 11, 2026 02:02
@yushangdi yushangdi changed the base branch from main to yushangdi/stack/33 June 11, 2026 02:02
@yushangdi yushangdi force-pushed the yushangdi/stack/30 branch from f9c2756 to 4a78314 Compare June 11, 2026 17:09
@yushangdi yushangdi merged commit 4a78314 into yushangdi/stack/33 Jun 11, 2026
1 check passed
@yushangdi yushangdi force-pushed the yushangdi/stack/33 branch from facf399 to 1693556 Compare June 11, 2026 17:10
yushangdi added a commit that referenced this pull request Jun 11, 2026
stack-info: PR: #2748, branch: yushangdi/stack/30