Add a SymInt-free tensor specialization key for exact torch.Tensor args #2759

Open
yushangdi wants to merge 1 commit into main from yushangdi/stack/37

@yushangdi yushangdi commented Jun 11, 2026

Contributor

Stacked PRs:


Add a SymInt-free tensor specialization key for exact torch.Tensor args

_tensor_key wraps every size and stride element in _hashable_dim to normalize torch.SymInt into a hashable (shape_env_id, expr) pair. That wrap is needed only for symbolic shapes, which appear on FakeTensors during tracing, yet concrete tensors paid for it on every kernel call: two Python-level loops (sizes and strides) with an isinstance check per dimension, plus rebuilding each as a fresh tuple. On the default static_shapes=True path this was the bulk of per-tensor key extraction (~0.6 us per tensor, multiplied by the number of tensor args, on every call).
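As a rough illustration of the per-dimension work described above, here is a minimal pure-Python sketch. `StandInSymInt`, `hashable_dim`, and `wrapped_dims` are hypothetical stand-ins written for this example, not the actual Helion helpers; the point is only that for concrete ints the loop rebuilds a tuple equal to its input.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StandInSymInt:
    """Hypothetical stand-in for a symbolic dimension (not torch.SymInt)."""

    shape_env_id: int
    expr: str


def hashable_dim(dim):
    # Symbolic dims are normalized to a (shape_env_id, expr) pair;
    # concrete ints pass through unchanged.
    if isinstance(dim, StandInSymInt):
        return (dim.shape_env_id, dim.expr)
    return dim


def wrapped_dims(dims):
    # One Python-level loop with an isinstance check per dimension,
    # plus a freshly allocated tuple on every call.
    return tuple(hashable_dim(d) for d in dims)


concrete = (4096, 128)
symbolic = (StandInSymInt(1, "s0"), 128)

wrapped_concrete = wrapped_dims(concrete)  # same values, new tuple
wrapped_symbolic = wrapped_dims(symbolic)  # symbolic dim replaced by a pair
```

For the concrete case the loop does nothing useful, which is the cost the fast path is meant to avoid.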

Specialization-extractor dispatch is by exact type (_specialization_extractors.get(type(obj))), so the torch.Tensor and torch.nn.Parameter entries are only ever hit by objects whose sizes/strides are guaranteed concrete ints. Route those to a new _concrete_tensor_key that uses obj.size() and obj.stride() directly as key components. obj.size() returns a torch.Size (a tuple subclass) and obj.stride() returns a tuple of ints; both hash and compare identically to plain int tuples, so keys produced by the fast path and the SymInt path for the same concrete shape are interchangeable: no cache invalidation, and the specialization axes (dtype, shape, stride, _dynamo_static_indices, int32/int64 index width, dynamic-shape buckets) are bit-identical.
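The interchangeability argument rests on ordinary tuple semantics, which can be checked without torch. In the sketch below, `SizeLike` is a hypothetical stand-in for torch.Size, and `concrete_key` / `wrapped_key` are simplified, assumed key shapes rather than the real key layout.

```python
class SizeLike(tuple):
    """Hypothetical stand-in for torch.Size: a tuple subclass that does not
    override __hash__ or __eq__."""


def concrete_key(dtype, size, stride):
    # Fast path: use the size/stride objects directly, no per-dim loop.
    return (dtype, size, stride)


def wrapped_key(dtype, size, stride):
    # Slow path: rebuild both as fresh plain tuples, element by element.
    return (dtype, tuple(d for d in size), tuple(s for s in stride))


size = SizeLike((4096, 128))
stride = (128, 1)

fast = concrete_key("float32", size, stride)
slow = wrapped_key("float32", size, stride)

# A cache populated through one path is hit through the other.
cache = {slow: "bound-kernel"}
hit = cache.get(fast)
```

Because the subclass inherits tuple's hashing and equality, `fast` and `slow` land on the same dictionary entry, which is why no cache invalidation is needed.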

Anything that can carry SymInts keeps the safe path:

  • FakeTensor has its own dispatch entry -> _tensor_key (unchanged)
  • torch.Tensor subclasses (e.g. the JAX-export adapter) miss the exact-type dict and hit the isinstance fallback in _specialization_key, which now routes to _tensor_key explicitly
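The routing above can be sketched with placeholder classes. Everything here is an assumption made for illustration: the classes stand in for torch.Tensor, torch.nn.Parameter, FakeTensor, and an arbitrary tensor subclass, and the extractor functions return labels instead of real keys.

```python
class Tensor:
    """Placeholder for torch.Tensor."""


class Parameter(Tensor):
    """Placeholder for torch.nn.Parameter."""


class FakeTensor(Tensor):
    """Placeholder for FakeTensor, which may carry symbolic dims."""


class AdapterTensor(Tensor):
    """Hypothetical Tensor subclass with no dispatch entry of its own."""


def tensor_key(obj):
    return "symint_safe"  # label only; the real function builds a key


def concrete_tensor_key(obj):
    return "concrete_fast"  # label only; the real function builds a key


# Exact-type table: subclasses without their own entry do not match.
EXTRACTORS = {
    Tensor: concrete_tensor_key,
    Parameter: concrete_tensor_key,
    FakeTensor: tensor_key,
}


def specialization_key(obj):
    extractor = EXTRACTORS.get(type(obj))
    if extractor is None and isinstance(obj, Tensor):
        # isinstance fallback: unknown subclasses keep the SymInt-safe path.
        extractor = tensor_key
    if extractor is None:
        raise TypeError(f"unsupported argument type: {type(obj).__name__}")
    return extractor(obj)


routes = {
    cls.__name__: specialization_key(cls())
    for cls in (Tensor, Parameter, FakeTensor, AdapterTensor)
}
```

Looking up `type(obj)` rather than walking the MRO is what keeps the fast path restricted to the two exact types.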

(This is the same optimization as the Python-only part of upstream PR #2611, independently arrived at from profiling.)

New test/test_tensor_key_fast_path.py pins down the key equivalence (fast vs wrapped key hash/compare equal under static and dynamic shapes, incl. strided tensors), the dispatch routing (concrete -> _concrete_tensor_key; FakeTensor and the subclass fallback -> _tensor_key), that a torch.Tensor subclass takes the SymInt-safe path and still shares a BoundKernel with a plain tensor of the same shape, and that bind() caching/distinguishing is unchanged.

Also verified against test_misc, test_specialize, and test_torch_compile (the torch.compile suite exercises the FakeTensor/SymInt path).

Benchmark (B200, end-to-end wall time per call, add-style kernel, N=4096 fp32, HELION_AUTOTUNE_EFFORT=none, steady state, 20k iters):

| n_args | baseline | prev commit | this commit |
|-------:|---------:|------------:|------------:|
|      2 | 24.14 us |    18.01 us |    17.25 us |
|      8 | 33.05 us |    25.57 us |    24.50 us |
|     16 | 43.56 us |    36.41 us |    32.16 us |
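
To make the deltas in the table easier to read, the short script below derives the per-call savings from the reported numbers. It only does arithmetic on the figures above; the `summarize` helper and its field names are invented for this example.

```python
# Wall time per call in microseconds, copied from the benchmark table:
# n_args -> (baseline, prev commit, this commit)
timings_us = {
    2: (24.14, 18.01, 17.25),
    8: (33.05, 25.57, 24.50),
    16: (43.56, 36.41, 32.16),
}


def summarize(baseline, prev, this):
    # Absolute savings in microseconds and relative savings in percent.
    return {
        "saved_vs_baseline_us": round(baseline - this, 2),
        "saved_vs_prev_us": round(prev - this, 2),
        "pct_vs_baseline": round(100 * (baseline - this) / baseline, 1),
    }


summary = {n: summarize(*row) for n, row in timings_us.items()}

for n_args, stats in summary.items():
    print(n_args, stats)
```

Relative to the previous commit, the saving grows from 0.76 us at 2 args to 4.25 us at 16 args, consistent with a cost paid once per tensor argument.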

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

@yushangdi yushangdi force-pushed the yushangdi/stack/37 branch from 952d60c to 3f452d0 June 11, 2026 17:52
@meta-cla meta-cla bot added the CLA Signed label Jun 11, 2026
@yushangdi yushangdi marked this pull request as draft June 11, 2026 18:47
@yushangdi yushangdi changed the base branch from yushangdi/stack/33 to main June 11, 2026 18:47
@yushangdi yushangdi force-pushed the yushangdi/stack/37 branch from 3f452d0 to d2b3a32 June 11, 2026 22:03
@yushangdi yushangdi marked this pull request as ready for review June 11, 2026 22:08
@yushangdi yushangdi requested review from choijon5, jansel and oulgen June 11, 2026 22:08