Future Projects for Torch-MLIR
I'd like to open this discussion to get some feedback on various project ideas that I have come across — to get a sense of whether these have impact for our current users and how they could be implemented. This is not a formal RFC. If you feel strongly about any of these and want to implement the feature, please open an RFC yourself, and feel free to link the discussion for reference.
Overview of Projects
This document outlines concrete improvement opportunities for torch-mlir,
ordered by estimated impact (highest first). Impact considers user-facing value,
how severely the current state blocks users, and what new capabilities are unlocked.
Note: This document was authored in a conversation with Claude Opus 4.6.
| # | Project | Impact | Effort | Notes |
|---|---------|--------|--------|-------|
| 1 | PyTorch Decompositions | Very High | Medium | Foundational — reduces maintenance across the codebase and enables training support |
| 2 | PT2E Quantization | Very High | Medium-High | Critical broken feature — quantized deployment is a primary use case |
| 3 | Training Support | High | Medium | Opens entirely new use case, surprisingly feasible given project 1 |
| 4 | Remove PT1 Legacy Code | High | Low-Medium | Best effort-to-impact ratio, removes 63K lines of dead code |
| 5 | Stride and Memory Format | Medium-High | High | Important for correctness but requires deep architectural changes |
| 6 | Shape Inference | Medium | Medium-High | Current system works for most models, improvements are incremental |
| 7 | ONNX Maintainability | Medium | Medium | Internal quality, no new user-facing capabilities |
1. Leverage PyTorch Decompositions to Reduce Op Surface Area
Impact: Very High | Effort: Medium (incremental)
This is the highest-impact project because it is foundational — it reduces
maintenance across the entire codebase, simplifies backend work, and directly
enables training support (section 3).
Problem
torch-mlir maintains ~230 decomposition patterns in DecomposeComplexOps.cpp
(~13,600 lines of C++). Meanwhile, PyTorch's torch._decomp.core_aten_decompositions()
provides ~1,000 decompositions that reduce complex ops down to ~250 "core aten"
primitives. torch-mlir currently uses only ~52 of these via a hand-curated list
in python/torch_mlir/extras/fx_decomp_util.py.
This means torch-mlir is re-implementing in C++ many decompositions that PyTorch
already provides in Python — and must maintain them as PyTorch evolves.
Architecture context
Today, both the FX and ONNX import paths produce "full" Torch dialect ops (e.g., aten.layer_norm, aten.softmax). DecomposeComplexOps is the single place
that reduces these to primitives, and LowerToBackendContract enforces that
everything was decomposed. Backends can exempt specific ops from decomposition
via the backendLegalOps mechanism (LowerToBackendContract.cpp:600-602). This
lets a backend say "keep addmm fused" even though a decomposition exists.
The multi-frontend complication
PyTorch-side decompositions via prog.run_decompositions() happen before MLIR
import, which creates asymmetry between frontends:
FX path: can apply PyTorch decompositions before import (the mechanism
already exists in fx.py)
ONNX path: TorchOnnxToTorch produces Torch ops directly and never goes
through torch.export, so PyTorch-side decompositions are unavailable
This means:
DecomposeComplexOps is still required for the ONNX path regardless
The backendLegalOps escape hatch cannot reach PyTorch-side decompositions —
if PyTorch decomposes addmm before import, a backend that wanted fused addmm has no way to prevent it
If the FX path pre-decomposes ops but the ONNX path doesn't, the two paths
produce different IR for the same model, complicating testing and backend
support
Possible approaches
A) Hybrid: PyTorch decompositions for "safe" ops only (recommended)
Use PyTorch-side decompositions on the FX path only for ops where no backend
would want the non-decomposed form — activations, loss functions, statistical
reductions, simple math. Keep DecomposeComplexOps as the canonical
decomposition layer for everything else.
Categories of MLIR-side patterns that could be eliminated this way:
Activation functions: celu, elu, glu, hardsigmoid, hardswish, leaky_relu, mish, prelu, relu6, selu, silu, softplus, softshrink, hardshrink, rrelu, etc.
Loss functions: binary_cross_entropy_with_logits, cross_entropy_loss, kl_div, l1_loss, mse_loss, poisson_nll_loss, soft_margin_loss
Math ops: addcdiv, addcmul, deg2rad, square, frac, lerp, logaddexp, copysign, etc.
Statistical: std, var, var_mean, mean
Data movement: channel_shuffle, pixel_shuffle, pixel_unshuffle, roll, rot90, tile, stack, column_stack
The FX decomposition table flips from opt-in to opt-out:
```python
from torch._decomp import core_aten_decompositions

def get_decomposition_table():
    decomps = core_aten_decompositions()
    # Keep ops where backends benefit from the non-decomposed form,
    # or where backendLegalOps flexibility is needed
    for op in KEEP_NON_DECOMPOSED:
        decomps.pop(op, None)
    return decomps
```
The corresponding DecomposeComplexOps patterns and markDecomposedOpsAsIllegal
entries can be deleted incrementally as PyTorch-side decomps prove stable. The
ONNX path continues to rely on DecomposeComplexOps for these same ops — the
C++ patterns are only removable once both paths are covered.
Pro: Incremental, preserves backendLegalOps for ops that need it, reduces C++
maintenance for the common case.
Con: Two decomposition sites to reason about; ONNX path doesn't benefit.
B) Define a "core" op set and have all frontends target it
Define an explicit set of ~250 primitive torch ops that backends must implement.
FX path: use core_aten_decompositions() to reach that set before import
ONNX path: have TorchOnnxToTorch lower ONNX ops directly to primitives,
skipping intermediate "full" torch ops
DecomposeComplexOps becomes a thin pass for MLIR-specific structural
transforms only
Pro: One op set, dramatically reduced C++ code, PyTorch maintains decompositions.
Con: Requires significant rework of TorchOnnxToTorch to target primitives
instead of the full torch op set. Loses backendLegalOps flexibility for
PyTorch-decomposed ops.
C) Keep DecomposeComplexOps as the single source of truth (status quo)
Both paths import into the "full" Torch op set. DecomposeComplexOps is the one
place that defines the decomposition contract. PyTorch-side decompositions are
used only for ops that torch-mlir doesn't handle at all (the current 52-op
list).
Pro: One decomposition contract, works for all frontends, backendLegalOps
works everywhere.
Con: Maintaining ~230 C++ decomposition patterns indefinitely.
What stays in DecomposeComplexOps regardless
~100-130 patterns that are structural to torch-mlir's type system and have no
PyTorch equivalent:
Value semantics conversions (AtenContiguousOp, AtenCopyOp, AtenExpandOp)
Shape/view operations specific to MLIR lowering
Prims* op decompositions
Patterns that decompose into ops the backends handle more efficiently than the
core aten decomposed form
Rollout plan (for approach A)
1. Script an overlap analysis: match DecomposeComplexOps patterns against core_aten_decompositions() entries
2. Identify "safe" ops — those where no backend uses backendLegalOps to keep them, and PyTorch's decomposed form targets ops backends already handle
3. Batch-migrate safe categories (activations, losses, math ops first)
4. Run the e2e test suite after each batch
5. Delete corresponding C++ patterns as PyTorch-side decomps prove stable
Risks
PyTorch decompositions may produce ops that are harder for backends to lower
than the original (e.g., decomposing addmm into mm + add loses fusion
opportunities) — mitigated by the opt-out list
The decomposed op surface changes with PyTorch versions — mostly good (free
improvements) but could introduce regressions
ONNX path still needs DecomposeComplexOps for all migrated ops, so C++
patterns can only be deleted once the ONNX path is also addressed
Significant test churn during migration
2. First-Class PT2E Quantization Support
Impact: Very High | Effort: Medium-High
Quantized inference is a primary deployment use case, and the modern PyTorch
quantization path (PT2E) is effectively broken in torch-mlir. The current shim
converts PT2E ops backwards into a legacy representation, losing information and
limiting coverage.
Background: What PT2E quantization actually produces
PT2E quantization (via torch.ao.quantization / torchao.quantization.pt2e)
exports graphs where quantization is expressed as explicit quantized_decomposed
ops surrounding normal float compute. There are no quantized compute ops —
the compiler is expected to fuse dequant → float_op → quant into integer
arithmetic.
A quantized linear+relu looks like this in the exported graph:
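Schematically (a hedged sketch — variable names and quantization parameters here are illustrative, not taken from a real export), the pattern is:

```
x_q  = quantized_decomposed.quantize_per_tensor(x, s_x, zp_x, -128, 127, torch.int8)
x_dq = quantized_decomposed.dequantize_per_tensor(x_q, s_x, zp_x, -128, 127, torch.int8)
w_dq = quantized_decomposed.dequantize_per_channel(w_q, s_w, zp_w, 0, -128, 127, torch.int8)
a    = torch.ops.aten.linear(x_dq, w_dq)
r    = torch.ops.aten.relu(a)
y_q  = quantized_decomposed.quantize_per_tensor(r, s_y, zp_y, -128, 127, torch.int8)
```

Note the plain float linear and relu in the middle: integer arithmetic only appears once a backend fuses the surrounding q/dq ops.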
The quantized_decomposed namespace defines 19 ops covering:
- quantize_per_tensor / dequantize_per_tensor (3 overloads each: scalar, tensor, tensor2 for scale/zp args)
- quantize_per_channel / dequantize_per_channel
- quantize_per_token / dequantize_per_token
- quantize_per_channel_group / dequantize_per_channel_group
- choose_qparams variants (per-tensor, symmetric, per-token)
- fake_quant_per_channel
- convert_element_type.no_fuse

Note: PT2E quantization is migrating from torch.ao.quantization to torchao
(torchao.quantization.pt2e), but the underlying quantized_decomposed op
definitions and graph representation remain the same.
Current state in torch-mlir
torch-mlir's quantization is built on a legacy representation: !torch.qint8
types, aten.quantize_per_tensor (which produces quantized-typed tensors), and aten._make_per_tensor_quantized_tensor + aten.dequantize pairs. A shim pass
(MatchQuantizedCustomOps) converts 3 of the 19 PT2E ops into this old
representation. This is backwards — it forces the modern representation through a
legacy bottleneck.
What first-class PT2E support looks like
The design principle: keep the PT2E representation as the canonical form
throughout the pipeline. Don't convert to old-style QDQ. The old-style
representation (!torch.qint8, aten.quantize_per_tensor, aten._make_per_tensor_quantized_tensor) should be treated as legacy/ONNX-only.
Stage 1: FX Import
The FX importer needs to handle quantized_decomposed ops natively. Today they
arrive as torch.operator strings because the ops aren't registered in the Torch
dialect. Two options:
Option A: Define first-class ODS ops. Add all 19 quantized_decomposed ops
to the Torch dialect with proper type constraints, shape inference, and
verifiers. This is the cleanest but requires maintaining ODS definitions in sync
with PyTorch.
Option B: Import as torch.operator but with proper type annotations. Keep
the generic torch.operator representation but ensure the FX importer sets
correct result types (integer tensors, not float). This is less work but provides
no static verification.
Option A is preferred for a first-class experience.
The is_quantized check in fx_importer.py:1220 should be handled, but in
practice PT2E graphs use plain integer tensor types between q/dq ops — they do not use PyTorch's quantized tensor subclass. So the existing check may not
even fire for PT2E models. The real issue is that quantized_decomposed ops
aren't recognized.
Stage 2: Torch dialect pattern matching
With the PT2E ops in the IR, the next step is pattern-matching dq → op → q
sequences and deciding what to do with them. This replaces the current FuseQuantizedOps pass with something more general.
The key design question: should we fuse into quantized compute ops, or lower
q/dq directly to backends?
Approach: Don't fuse in the Torch dialect. Instead, keep the dq → float_op → q pattern intact and let each backend handle fusion during lowering. This is
simpler and more flexible:
TorchToLinalg: Pattern-match dq → linalg.matmul → q and emit linalg.quantized_matmul or integer linalg.generic with scale/zp math
TorchToTosa: Map q/dq ops directly to TOSA's native quantization (integer
conv2d/matmul with explicit scale attributes)
TorchToStablehlo: Map to stablehlo.uniform_quantize / stablehlo.uniform_dequantize
The Torch dialect passes should focus on propagation and simplification, not
fusion:
Constant folding: If weights are constant float, fold dq(q(const)) into
a quantized constant directly
Dead q/dq elimination: Remove q → dq pairs that cancel out
Scale propagation: Track scale/zp through commuting ops (reshape,
transpose, pad, slice) so backends see them at the right point
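The dead q/dq elimination above can be pictured as a toy peephole over a flat op list (a hedged pure-Python sketch — eliminate_qdq and the tuple encoding are inventions for illustration; the real pass would work on MLIR patterns and must also prove the intermediate rounding is a no-op):

```python
def eliminate_qdq(ops):
    """Drop adjacent quantize -> dequantize pairs with identical
    (scale, zero_point) parameters from a def-use-ordered op list."""
    out = []
    for op in ops:
        if (out
                and op[0] == "dequantize_per_tensor"
                and out[-1][0] == "quantize_per_tensor"
                and op[1] == out[-1][1]):
            out.pop()  # q followed by dq with the same params cancels
            continue
        out.append(op)
    return out

ops = [("quantize_per_tensor", (0.1, 0)),
       ("dequantize_per_tensor", (0.1, 0)),
       ("relu", ())]
assert eliminate_qdq(ops) == [("relu", ())]
```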
Stage 3: Backend Lowering
Each backend has native quantization support with different representations.
The lowering should pattern-match the PT2E ops directly:
TorchToLinalg:
```
// Pattern: dq(input, s_in, zp_in) → aten.mm(_, dq(weight, s_w, zp_w)) → q(_, s_out, zp_out)
// Lowers to: integer linalg.generic with accumulation in i32,
// followed by rescale: out = (acc * s_in * s_w / s_out) + zp_out
```
For ops where integer fusion isn't available (or isn't profitable), the backend
can simply lower dequantize_per_tensor to (int_to_float(x) - zp) * scale
and quantize_per_tensor to clamp(round(x / scale) + zp, qmin, qmax) as
scalar arithmetic. This is always correct, just not as fast as fused integer
compute.
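Both formulas, plus the integer-rescale recipe from the pattern comment above, can be sanity-checked in plain Python (a hedged reference sketch — scalar math only, with names invented here; not torch-mlir code):

```python
QMIN, QMAX = -128, 127  # int8 range

def quantize(x, scale, zp):
    # quantize_per_tensor reference: clamp(round(x / scale) + zp, qmin, qmax)
    return max(QMIN, min(QMAX, round(x / scale) + zp))

def dequantize(q, scale, zp):
    # dequantize_per_tensor reference: (q - zp) * scale
    return (q - zp) * scale

# Round-trip error is bounded by half a quantization step.
s, zp = 0.1, 3
assert abs(dequantize(quantize(0.42, s, zp), s, zp) - 0.42) <= s / 2

# Fused integer path for one multiply-accumulate:
# out_q = round(acc * s_in * s_w / s_out) + zp_out
s_in, s_w, s_out = 0.02, 0.01, 0.05
x, w = 0.5, -0.3
acc = quantize(x, s_in, 0) * quantize(w, s_w, 0)      # i32 accumulation
out_q = max(QMIN, min(QMAX, round(acc * s_in * s_w / s_out)))
assert abs(dequantize(out_q, s_out, 0) - x * w) < 1e-12
```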
TorchToTosa:
TOSA natively supports quantized integer tensors with scale/zp as op attributes.
Lower dq → conv → q directly to tosa.conv2d with integer inputs and the
quantization parameters as attributes.
TorchToStablehlo:
StableHLO has uniform_quantize / uniform_dequantize ops and supports
quantized types via MLIR's quant dialect. Lower quantized_decomposed ops
directly to these.
Stage 4: Deprecate old-style QDQ
Once PT2E is the canonical path:
The ONNX path's QuantizeLinear/DequantizeLinear should lower to quantized_decomposed ops instead of old-style aten.quantize_per_tensor
Remove MatchQuantizedCustomOps (no longer needed — PT2E ops are first-class)
Remove old-style quantized types (!torch.qint8, !torch.quint8) and ops
(aten._make_per_tensor_quantized_tensor, aten.int_repr) from the required
backend support surface
Remove FuseQuantizedOps (fusion happens at backend lowering)
Per-tensor vs per-channel vs per-token
The design must handle all granularities from the start:
| Granularity | Scale shape | Use case |
|---|---|---|
| per-tensor | scalar | Activations, simple weight quant |
| per-channel | [out_channels] | Weight quantization (standard) |
| per-token | [batch, seq_len, 1] | Dynamic activation quant (LLMs) |
| per-channel-group | [channels, groups] | GPTQ/AWQ weight quant (LLMs) |
Per-channel and per-token are critical for LLM quantization (GPTQ, AWQ,
SmoothQuant). These should not be an afterthought.
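To make the scale shapes concrete, here is a hedged pure-Python sketch of per-channel weight quantization (nested lists stand in for tensors; quantize_per_channel here is an illustration, not PyTorch's op):

```python
def quantize_per_channel(w, scales, zps, qmin=-128, qmax=127):
    # One (scale, zero_point) per output channel (row of w) —
    # scale shape [out_channels], unlike per-tensor's single scalar.
    return [
        [max(qmin, min(qmax, round(v / s) + z)) for v in row]
        for row, s, z in zip(w, scales, zps)
    ]

w = [[0.5, -0.25, 0.1],   # channel 0: small dynamic range, scale 0.01
     [4.0, -2.0, 1.0]]    # channel 1: large dynamic range, scale 0.1
wq = quantize_per_channel(w, scales=[0.01, 0.1], zps=[0, 0])
assert wq == [[50, -25, 10], [40, -20, 10]]
```

A single per-tensor scale would have to cover both channels' ranges at once, wasting precision on channel 0 — which is why per-channel is the standard for weights.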
What this unlocks
End-to-end quantized model compilation through the modern PyTorch path
All quantization granularities: per-tensor, per-channel, per-token,
per-channel-group
Clean separation: quantization semantics in the graph, fusion decisions in
backends
Alignment with PyTorch/torchao direction — the old eager-mode quantization
APIs are deprecated
Foundation for future quantization schemes (e.g., float8, microscaling)
3. Training Support
Impact: High | Effort: Medium
Training opens an entirely new use case for torch-mlir. The key insight is that
PyTorch's decomposition machinery (section 1) eliminates almost all backward ops
before they reach torch-mlir, making training support far more feasible than
commonly assumed.
The surprising finding: training is closer than expected
A common assumption is that training requires supporting dozens of backward ops
and building autograd infrastructure in MLIR. In practice, PyTorch's own
decomposition machinery eliminates almost all backward ops before they would
reach torch-mlir.
Using make_fx + torch.autograd.grad + core_aten_decompositions(), a full
training step (forward + backward + gradient computation) can be traced into a flat FX graph of primitive ops. Most backward ops decompose entirely into
ordinary forward ops:
| Layer | Backward ops after decomposition |
|---|---|
| Linear / matmul | None — decomposes to mm, t, mul, etc. |
| Embedding | None — fully decomposed |
| LayerNorm | None — fully decomposed |
| BatchNorm | None — fully decomposed |
| Attention (SDPA) | None — fully decomposed |
| ReLU, GELU, SiLU, etc. | None — decomposes to where, le, mul |
| Conv2d | convolution_backward (irreducible) |
| MaxPool2d | max_pool2d_with_indices_backward (irreducible) |
| AvgPool2d | avg_pool2d_backward (irreducible) |
A transformer training step (attention + layernorm + loss + grads) produced a
graph with zero backward ops — just 23 ordinary ops like mm, bmm, view, mul, sub, where, native_layer_norm, _softmax. These are all ops
torch-mlir already lowers.
PyTorch provides 104 backward op decompositions in core_aten_decompositions(). Only a handful of backward ops are irreducible: convolution_backward, max_pool2d_with_indices_backward, avg_pool2d_backward, and a few others. torch-mlir already has lowerings for convolution_backward (TorchToLinalg) and max_pool2d_with_indices_backward
(TorchToTMTensor).
What a training step looks like
```python
# User writes this:
def train_step(params, x, target):
    out = model_forward(params, x)
    loss = loss_fn(out, target)
    grads = torch.autograd.grad(loss, params)
    return grads  # or updated params

# Traced with make_fx:
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import core_aten_decompositions

fx_g = make_fx(train_step, decomposition_table=core_aten_decompositions())(
    params, x, target
)
# fx_g is a flat graph of ~20-30 primitive ops, no backward ops
```
The resulting fx_g is a GraphModule containing only primitive forward ops
that torch-mlir already knows how to lower.
Blockers
1. No make_fx import path
The FX importer (fx_importer.py) expects torch.export.ExportedProgram. make_fx produces a torch.fx.GraphModule — a different format. Options:
Wrap GraphModule into ExportedProgram: PyTorch has internal utilities
for this, but they're not stable public API
Support GraphModule import directly: Add a parallel import path in the
FX importer that handles GraphModule without requiring ExportedProgram
metadata (graph signatures, range constraints, etc.)
Use aot_export_joint_simple: This functorch API traces joint
forward+backward and may produce output closer to what the importer expects
2. torch.export is inference-only
torch.export.export() does not capture backward graphs. There is an export_for_training() API, but it is deprecated and produces a forward-only
graph with autograd metadata — not a traced backward pass. The standard way to
get a joint forward+backward graph is make_fx or aot_autograd, which are
lower-level functorch APIs.
3. Parameter mutation / weight updates
A full training loop includes params = params - lr * grads. This is trivial
arithmetic, but the pipeline needs to handle it:
The training step function should take parameters as explicit inputs and return
updated parameters (functional style)
The caller manages parameter state between steps
This is already how make_fx traces work — all state is explicit
4. Missing backward op lowerings (small set)
Only 2-3 irreducible backward ops need lowering support that doesn't fully exist:
avg_pool2d_backward — no lowering currently
_scaled_dot_product_flash_attention_for_cpu_backward — no lowering (but SDPA
decomposes fully anyway with core_aten_decompositions)
convolution_backward — already has TorchToLinalg lowering and a
decomposition in DecomposeComplexOps.cpp
5. Dynamic shapes in training
Training graphs have more dynamic shape requirements than inference:
Variable batch sizes
Sequence length variation (padding, packing)
Dropout masks (random shapes)
Gradient accumulation across steps
Approach
Phase 1: Import make_fx training graphs
Add GraphModule import support to the FX importer (or a thin wrapper that
converts GraphModule → ExportedProgram)
Apply core_aten_decompositions() during tracing to eliminate backward ops
Verify that the resulting flat graph imports and lowers through existing
backend pipelines
Test with simple models: linear regression, MLP, small transformer
Phase 2: Fill remaining op gaps
Add avg_pool2d_backward lowering to TorchToLinalg
Verify convolution_backward works end-to-end (the lowering exists but may
have edge cases — some conv backward tests are marked as hanging in xfail
sets)
Handle SDPA backward if needed (likely unnecessary with decompositions)
Phase 3: Training loop integration
Define a functional training step pattern: (params, data) → (updated_params, metrics) where all state is explicit
Support optimizers as traced functions (SGD, Adam are simple arithmetic)
Handle gradient accumulation, gradient clipping, mixed precision as traced
operations
What this unlocks
Compiler-optimized training on MLIR backends (Linalg → CPU/GPU codegen)
Training on custom hardware via StableHLO/TOSA backends
Fused forward+backward kernels — since the full training graph is visible,
backends can fuse across the forward/backward boundary
Training-time quantization — QAT (quantization-aware training) with
fake_quant ops in the training graph
What this does NOT require
No autograd implementation in MLIR — PyTorch handles differentiation
No backward op lowering for 100+ ops — decompositions reduce to ~3 irreducible
backward ops
No mutable state handling — functional training steps make state explicit
No custom autograd Function support — users trace the full step including
custom backward logic
4. Remove PT1 Legacy Code
Impact: High | Effort: Low-Medium
The best effort-to-impact ratio of any project here. Removing 63K+ lines of
deprecated, unmaintained code immediately simplifies the build, CI, and
contributor experience.
Problem
The PT1 (PyTorch 1.x) infrastructure — TorchScript import, JIT IR importer, and
Lazy Tensor Core backend — is officially deprecated and unmaintained. The
CMakeLists.txt itself states:
> NOTE: The JIT_IR_IMPORTER paths have become unsupportable due to age and lack
> of maintainers. Turning this off disables the old TorchScript path, leaving FX
> based import as the current supported option.
Yet the code remains:
~10K lines of C++ across projects/pt1/, projects/ltc/, projects/jit_ir_common/
~53K lines of Python including the e2e test suite, JIT IR importer, LTC
backend
Build system complexity: TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS, TORCH_MLIR_ENABLE_JIT_IR_IMPORTER, TORCH_MLIR_ENABLE_LTC flags
Hard dependency on libtorch C++ SDK when extensions are enabled
5,000+ lines of xfail sets for legacy test configurations
Approach
Relocate shared infrastructure that lives under projects/pt1/ but is
still needed:
abstract_interp_lib_gen.py and ODS update scripts in projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/
E2e test framework (framework.py, registry.py) if used by non-PT1 paths
Simplify CI: remove PyTorch source build requirement for the JIT path
Clean up Python packaging: remove legacy torchscript.compile() API
What this unlocks
Dramatically simpler build and CI configuration
Moves toward a pure-Python core pipeline (FX importer already is)
Reduces cognitive overhead for new contributors (one import path, not three)
Eliminates maintenance burden for code with no active maintainer
5. Stride and Memory Format Handling
Impact: Medium-High | Effort: High
Important for correctness and performance of channels-last models, but requires
cross-cutting architectural changes to the type system, importer, and all
backend lowerings.
Problem
Tensor strides are completely dropped during FX import. The Torch dialect
type system (!torch.vtensor) only carries sizes, dtype, and sparsity —
there is no representation for memory layout or strides.
This means:
All tensors are implicitly assumed contiguous (C-order)
Channels-last convolutions cannot be represented or optimized
Strided GEMMs and transposed weight matrices lose layout information
View/reshape on non-contiguous tensors may produce incorrect results
A TODO at fx_importer.py:1069 acknowledges this:
```python
# Principally, this includes the FunctionType, but in the future,
# it should also return other annotations (input strides, etc) that
# affect compilation and should be included as arg attrs.
```
The data is available — torch.export.ExportedProgram captures full stride
information in FakeTensor metadata. The importer just doesn't use it.
Approach
Extend TorchTypes.td to add an optional layout/stride attribute to !torch.vtensor
Capture stride info from FakeTensor during FX import (the data is already
there in tensor.stride())
Insert layout conversion ops at boundaries where layout mismatches occur
(e.g., torch.prim.contiguous coercion)
Propagate layout through shape inference and into backend lowering
Extend symbolic shape handling to also track symbolic strides (the
infrastructure for torch.bind_symbolic_shape already exists)
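A hedged sketch of what is being dropped today: the same NCHW sizes admit different strides, and without stride metadata the two layouts are indistinguishable (pure Python, illustrative only):

```python
def contiguous_strides(sizes):
    # C-order: the innermost dimension has stride 1.
    strides, acc = [], 1
    for s in reversed(sizes):
        strides.append(acc)
        acc *= s
    return list(reversed(strides))

n, c, h, w = 2, 3, 4, 5
assert contiguous_strides([n, c, h, w]) == [60, 20, 5, 1]

# Channels-last: logically still NCHW, but C is innermost in memory,
# so its stride is 1 and the H/W strides carry the channel factor.
channels_last = [c * h * w, 1, w * c, c]
assert channels_last == [60, 1, 15, 3]
```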
What this unlocks
Correct lowering of channels-last models (critical for mobile/edge inference)
Layout-aware backend optimizations
Proper handling of transposed views and non-contiguous slices
Foundation for future memory format optimizations
6. Shape Inference Improvements
Impact: Medium | Effort: Medium-High
The current shape inference system works for most models but has known
limitations that affect edge cases and compilation efficiency.
Problem
The shape inference system has several compounding limitations:
Mandatory inlining (Passes.cpp:63-66):
```cpp
// Currently, our shape inference is not powerful enough to deal with
// calls, so inline everything.
// TODO: Improve shape inference.
pm.addPass(createInlinerPass());
```
All functions must be inlined before shape inference, bloating IR and preventing
modular compilation.
Data-dependent shapes use a hack (abstract_interp_lib_gen.py):
```python
def hacky_get_unknown_dimension_size():
    """Breaks the invariant that shape functions are executable code"""
    return id(DummyClassType())
```
Ops like nonzero, unique, masked_select, and bincount cannot have output
shapes inferred.
Meta tensor limitations: ~9 shape/dtype functions cannot run on the Meta
backend and require CPU workarounds. The codebase has TODOs noting this should be
fixed by migrating to FakeTensor.
Potential improvements (by impact)
Inter-procedural shape inference: Infer shapes across function call
boundaries without inlining — reduces IR size for models with many submodules
FakeTensor migration: Replace Meta tensor backend with FakeTensor for
shape/dtype inference (fixes 9+ known bugs)
Data-dependent shape bounds: Track upper bounds for data-dependent
dimensions rather than giving up entirely
Control flow shape tracking: Propagate shapes through prim::Loop and prim::If bodies
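The bounds idea can be illustrated with a tiny pure-Python sketch (names invented here; torch-mlir's shape functions are generated differently): aten.nonzero on a tensor of N elements returns a [num_nonzero, rank] result, so [0, numel] is a sound bound for the data-dependent dimension.

```python
from math import prod

def nonzero_shape_bounds(input_sizes):
    # Report each result dim as a (lower, upper) bound instead of "unknown".
    numel = prod(input_sizes)
    rank = len(input_sizes)
    return [(0, numel), (rank, rank)]

assert nonzero_shape_bounds([4, 5]) == [(0, 20), (2, 2)]
```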
7. TorchOnnxToTorch Maintainability
Impact: Medium | Effort: Medium
Internal code quality improvement. Important for long-term sustainability but
does not directly add new capabilities for users.
Problem
The ONNX-to-Torch conversion implements ~200 ops as ad-hoc lambda callbacks
across 4 monolithic files (the largest being ~5,000 lines). Key issues:
~30 TODOs for missing features (per-channel quantization, dynamic shapes, auto_pad in convolutions)
No reusable infrastructure — each op is a bespoke lambda with repeated
boilerplate
Only 2 hardcoded domains (default + com.microsoft) with no extension mechanism
Silent failures via notifyMatchFailure() make debugging difficult
reacted with thumbs up emoji reacted with thumbs down emoji reacted with laugh emoji reacted with hooray emoji reacted with confused emoji reacted with heart emoji reacted with rocket emoji reacted with eyes emoji
Uh oh!
There was an error while loading. Please reload this page.
-
Future Projects for Torch-MLIR
I'd like to open this discussion to get some feedback on various project ideas that I have come across- to get a sense of whether these have impact for our current users and how they could be implemented. This is not a formal RFC. If you feel strongly about any of these and want to implement the feature, please open an RFC yourself, and feel free to link the discussion for reference.
Overview of Projects
This document outlines concrete improvement opportunities for torch-mlir,
ordered by estimated impact (highest first). Impact considers user-facing value,
how blocking the current state is, and what new capabilities are unlocked.
Note: This document was authored in a conversation with Claude Opus 4.6.
1. Leverage PyTorch Decompositions to Reduce Op Surface Area
Impact: Very High | Effort: Medium (incremental)
This is the highest-impact project because it is foundational — it reduces
maintenance across the entire codebase, simplifies backend work, and directly
enables training support (section 3).
Problem
torch-mlir maintains ~230 decomposition patterns in
DecomposeComplexOps.cpp(~13,600 lines of C++). Meanwhile, PyTorch's
torch._decomp.core_aten_decompositions()provides ~1,000 decompositions that reduce complex ops down to ~250 "core aten"
primitives. torch-mlir currently uses only ~52 of these via a hand-curated list
in
python/torch_mlir/extras/fx_decomp_util.py.This means torch-mlir is re-implementing in C++ many decompositions that PyTorch
already provides in Python — and must maintain them as PyTorch evolves.
Architecture context
Today, both the FX and ONNX import paths produce "full" Torch dialect ops (e.g.,
aten.layer_norm,aten.softmax).DecomposeComplexOpsis the single placethat reduces these to primitives, and
LowerToBackendContractenforces thateverything was decomposed. Backends can exempt specific ops from decomposition
via the
backendLegalOpsmechanism (LowerToBackendContract.cpp:600-602):This lets a backend say "keep
addmmfused" even though a decomposition exists.The multi-frontend complication
PyTorch-side decompositions via
prog.run_decompositions()happen before MLIRimport, which creates asymmetry between frontends:
already exists in
fx.py)TorchOnnxToTorchproduces Torch ops directly and never goesthrough
torch.export, so PyTorch-side decompositions are unavailableThis means:
DecomposeComplexOpsis still required for the ONNX path regardlessbackendLegalOpsescape hatch cannot reach PyTorch-side decompositions —if PyTorch decomposes
addmmbefore import, a backend that wanted fusedaddmmhas no way to prevent itproduce different IR for the same model, complicating testing and backend
support
Possible approaches
A) Hybrid: PyTorch decompositions for "safe" ops only (recommended)
Use PyTorch-side decompositions on the FX path only for ops where no backend
would want the non-decomposed form — activations, loss functions, statistical
reductions, simple math. Keep
DecomposeComplexOpsas the canonicaldecomposition layer for everything else.
Categories of MLIR-side patterns that could be eliminated this way:
celu,elu,glu,hardsigmoid,hardswish,leaky_relu,mish,prelu,relu6,selu,silu,softplus,softshrink,hardshrink,rrelu, etc.binary_cross_entropy_with_logits,cross_entropy_loss,kl_div,l1_loss,mse_loss,poisson_nll_loss,soft_margin_lossaddcdiv,addcmul,deg2rad,square,frac,lerp,logaddexp,copysign, etc.std,var,var_mean,meanchannel_shuffle,pixel_shuffle,pixel_unshuffle,roll,rot90,tile,stack,column_stackThe FX decomposition table flips from opt-in to opt-out:
The corresponding
DecomposeComplexOpspatterns andmarkDecomposedOpsAsIllegalentries can be deleted incrementally as PyTorch-side decomps prove stable. The
ONNX path continues to rely on
DecomposeComplexOpsfor these same ops — theC++ patterns are only removable once both paths are covered.
Pro: Incremental, preserves
backendLegalOpsfor ops that need it, reduces C++maintenance for the common case.
Con: Two decomposition sites to reason about; ONNX path doesn't benefit.
B) Define a "core" op set and have all frontends target it
Define an explicit set of ~250 primitive torch ops that backends must implement.
core_aten_decompositions()to reach that set before importTorchOnnxToTorchlower ONNX ops directly to primitives,skipping intermediate "full" torch ops
DecomposeComplexOpsbecomes a thin pass for MLIR-specific structuraltransforms only
Pro: One op set, dramatically reduced C++ code, PyTorch maintains decompositions.
Con: Requires significant rework of
TorchOnnxToTorchto target primitivesinstead of the full torch op set. Loses
backendLegalOpsflexibility forPyTorch-decomposed ops.
C) Keep DecomposeComplexOps as the single source of truth (status quo)
Both paths import into the "full" Torch op set.
DecomposeComplexOpsis the oneplace that defines the decomposition contract. PyTorch-side decompositions are
used only for ops that torch-mlir doesn't handle at all (the current 52-op
list).
Pro: One decomposition contract, works for all frontends,
backendLegalOpsworks everywhere.
Con: Maintaining ~230 C++ decomposition patterns indefinitely.
What stays in DecomposeComplexOps regardless
~100-130 patterns that are structural to torch-mlir's type system and have no
PyTorch equivalent:
AtenContiguousOp,AtenCopyOp,AtenExpandOp)Prims*op decompositionscore aten decomposed form
Rollout plan (for approach A)

- Audit existing `DecomposeComplexOps` patterns against `core_aten_decompositions()` entries
- Migrate incrementally, starting with ops where no backend relies on `backendLegalOps` to keep them, and where PyTorch's decomposed form targets ops backends already handle
Risks

- PyTorch's decomposed form may be less backend-friendly than the original (e.g., decomposing `addmm` into `mm` + `add` loses fusion opportunities) — mitigated by the opt-out list
- PyTorch decompositions change across releases (usually improvements) but could introduce regressions
- The ONNX path still requires `DecomposeComplexOps` for all migrated ops, so C++ patterns can only be deleted once the ONNX path is also addressed
2. First-Class PT2E Quantization Support
Impact: Very High | Effort: Medium-High
Quantized inference is a primary deployment use case, and the modern PyTorch
quantization path (PT2E) is effectively broken in torch-mlir. The current shim
converts PT2E ops backwards into a legacy representation, losing information and
limiting coverage.
Background: What PT2E quantization actually produces
PT2E quantization (via `torch.ao.quantization` / `torchao.quantization.pt2e`) exports graphs where quantization is expressed as explicit `quantized_decomposed` ops surrounding normal float compute. There are no quantized compute ops — the compiler is expected to fuse `dequant → float_op → quant` into integer arithmetic.
A quantized linear+relu looks like this in the exported graph:
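An illustrative, hand-written excerpt of such a graph (op spellings follow the `quantized_decomposed` namespace; variable names and quantization parameters are made up, not actual exporter output):

```
x_q  = quantized_decomposed.quantize_per_tensor(x, 0.02, 0, -128, 127, torch.int8)
x_dq = quantized_decomposed.dequantize_per_tensor(x_q, 0.02, 0, -128, 127, torch.int8)
w_dq = quantized_decomposed.dequantize_per_channel(w_q, w_scales, w_zps, 0, -127, 127, torch.int8)
y    = aten.linear(x_dq, w_dq, bias)   # ordinary float compute
y    = aten.relu(y)
y_q  = quantized_decomposed.quantize_per_tensor(y, 0.05, 0, -128, 127, torch.int8)
```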
Key properties:

- Tensors between q/dq ops use plain integer dtypes — there are no quantized types in the graph
- Scales and zero-points appear as explicit values in the graph
The `quantized_decomposed` namespace defines 19 ops covering:

- `quantize_per_tensor` / `dequantize_per_tensor` (3 overloads each: scalar, `tensor`, and `tensor2` for scale/zp args)
- `quantize_per_channel` / `dequantize_per_channel`
- `quantize_per_token` / `dequantize_per_token`
- `quantize_per_channel_group` / `dequantize_per_channel_group`
- `choose_qparams` variants (per-tensor, symmetric, per-token)
- `fake_quant_per_channel`
- `convert_element_type.no_fuse`

Note: PT2E quantization is migrating from `torch.ao.quantization` to torchao (`torchao.quantization.pt2e`), but the underlying `quantized_decomposed` op definitions and graph representation remain the same.
Current state in torch-mlir
torch-mlir's quantization is built on a legacy representation: `!torch.qint8` types, `aten.quantize_per_tensor` (which produces quantized-typed tensors), and `aten._make_per_tensor_quantized_tensor` + `aten.dequantize` pairs. A shim pass (`MatchQuantizedCustomOps`) converts 3 of the 19 PT2E ops into this old representation. This is backwards — it forces the modern representation through a legacy bottleneck.
What first-class PT2E support looks like
The design principle: keep the PT2E representation as the canonical form throughout the pipeline. Don't convert to old-style QDQ. The old-style representation (`!torch.qint8`, `aten.quantize_per_tensor`, `aten._make_per_tensor_quantized_tensor`) should be treated as legacy/ONNX-only.

Stage 1: FX Import
The FX importer needs to handle `quantized_decomposed` ops natively. Today they arrive as `torch.operator` strings because the ops aren't registered in the Torch dialect. Two options:
Option A: Define first-class ODS ops. Add all 19 `quantized_decomposed` ops to the Torch dialect with proper type constraints, shape inference, and verifiers. This is the cleanest but requires maintaining ODS definitions in sync with PyTorch.
Option B: Import as `torch.operator` but with proper type annotations. Keep the generic `torch.operator` representation but ensure the FX importer sets correct result types (integer tensors, not float). This is less work but provides no static verification.
Option A is preferred for a first-class experience.
The `is_quantized` check in `fx_importer.py:1220` should be handled, but in practice PT2E graphs use plain integer tensor types between q/dq ops — they do not use PyTorch's quantized tensor subclass. So the existing check may not even fire for PT2E models. The real issue is that `quantized_decomposed` ops aren't recognized.
Stage 2: Torch Dialect — Quantization-Aware Optimization
With the PT2E ops in the IR, the next step is pattern-matching `dq → op → q` sequences and deciding what to do with them. This replaces the current `FuseQuantizedOps` pass with something more general.

The key design question: should we fuse into quantized compute ops, or lower q/dq directly to backends?
Approach: Don't fuse in the Torch dialect. Instead, keep the `dq → float_op → q` pattern intact and let each backend handle fusion during lowering. This is simpler and more flexible:

- TorchToLinalg can match `dq → linalg.matmul → q` and emit `linalg.quantized_matmul` or an integer `linalg.generic` with scale/zp math
- TorchToTosa can emit natively quantized TOSA ops (conv2d/matmul with explicit scale attributes)
- TorchToStablehlo can emit `stablehlo.uniform_quantize` / `stablehlo.uniform_dequantize`

The Torch dialect passes should focus on propagation and simplification, not fusion:

- Constant-fold `dq(q(const))` into a quantized constant directly
- Eliminate `q → dq` pairs that cancel out
- Propagate q/dq across shape ops (transpose, pad, slice) so backends see them at the right point
Stage 3: Backend Lowering
Each backend has native quantization support with different representations.
The lowering should pattern-match the PT2E ops directly:
TorchToLinalg:

For ops where integer fusion isn't available (or isn't profitable), the backend can simply lower `dequantize_per_tensor` to `(int_to_float(x) - zp) * scale` and `quantize_per_tensor` to `clamp(round(x / scale) + zp, qmin, qmax)` as scalar arithmetic. This is always correct, just not as fast as fused integer compute.
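A numeric check of those fallback formulas (pure-PyTorch sketch; the function names mirror the PT2E ops but are defined here for illustration):

```python
import torch

def dequantize_per_tensor(x_int, scale, zp):
    # (int_to_float(x) - zp) * scale
    return (x_int.to(torch.float32) - zp) * scale

def quantize_per_tensor(x, scale, zp, qmin=-128, qmax=127):
    # clamp(round(x / scale) + zp, qmin, qmax)
    return torch.clamp(torch.round(x / scale) + zp, qmin, qmax).to(torch.int8)

x = torch.linspace(-6.0, 6.0, 50)
roundtrip = dequantize_per_tensor(quantize_per_tensor(x, 0.05, 0), 0.05, 0)
# Within the representable range, round-trip error is at most scale / 2.
max_err = (x - roundtrip).abs().max().item()
```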
TorchToTosa:

TOSA natively supports quantized integer tensors with scale/zp as op attributes. Lower `dq → conv → q` directly to `tosa.conv2d` with integer inputs and the quantization parameters as attributes.
TorchToStablehlo:

StableHLO has `uniform_quantize` / `uniform_dequantize` ops and supports quantized types via MLIR's quant dialect. Lower `quantized_decomposed` ops directly to these.
Stage 4: Deprecate old-style QDQ
Once PT2E is the canonical path:

- The ONNX path's `QuantizeLinear` / `DequantizeLinear` should lower to `quantized_decomposed` ops instead of old-style `aten.quantize_per_tensor`
- Delete `MatchQuantizedCustomOps` (no longer needed — PT2E ops are first-class)
- Remove the quantized types (`!torch.qint8`, `!torch.quint8`) and ops (`aten._make_per_tensor_quantized_tensor`, `aten.int_repr`) from the required backend support surface
- Delete `FuseQuantizedOps` (fusion happens at backend lowering)

Per-tensor vs per-channel vs per-token
The design must handle all granularities from the start: per-tensor (scalar scale/zp), per-channel (scales of shape `[out_channels]`), per-token (scales of shape `[batch, seq_len, 1]`), and per-channel-group (scales of shape `[channels, groups]`).

Per-channel and per-token are critical for LLM quantization (GPTQ, AWQ, SmoothQuant). These should not be an afterthought.
What this unlocks

- Quantization schemes that need per-token and per-channel-group granularity
- Quantized deployment across all backends
- A supported path forward as the legacy quantization APIs are deprecated
3. Training Support
Impact: High | Effort: Medium
Training opens an entirely new use case for torch-mlir. The key insight is that
PyTorch's decomposition machinery (section 1) eliminates almost all backward ops
before they reach torch-mlir, making training support far more feasible than
commonly assumed.
The surprising finding: training is closer than expected
A common assumption is that training requires supporting dozens of backward ops
and building autograd infrastructure in MLIR. In practice, PyTorch's own
decomposition machinery eliminates almost all backward ops before they would
reach torch-mlir.
Using `make_fx` + `torch.autograd.grad` + `core_aten_decompositions()`, a full training step (forward + backward + gradient computation) can be traced into a flat FX graph of primitive ops. Most backward ops decompose entirely into ordinary forward ops:

- matmul/linear backward → `mm`, `t`, `mul`, etc.
- relu/threshold backward → `where`, `le`, `mul`
- `convolution_backward` (irreducible)
- `max_pool2d_with_indices_backward` (irreducible)
- `avg_pool2d_backward` (irreducible)

A transformer training step (attention + layernorm + loss + grads) produced a graph with zero backward ops — just 23 ordinary ops like `mm`, `bmm`, `view`, `mul`, `sub`, `where`, `native_layer_norm`, `_softmax`. These are all ops torch-mlir already lowers.
PyTorch provides 104 backward op decompositions in `core_aten_decompositions()`. Only a handful of backward ops are irreducible: `convolution_backward`, `max_pool2d_with_indices_backward`, `avg_pool2d_backward`, and a few others. torch-mlir already has lowerings for `convolution_backward` (TorchToLinalg) and `max_pool2d_with_indices_backward` (TorchToTMTensor).
What a training step looks like
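A minimal sketch of tracing one full training step (forward + backward + gradients) with `make_fx` and PyTorch's decompositions — the model here is a toy linear regression, chosen for brevity:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import core_aten_decompositions

def train_step(w, x, y):
    loss = ((x @ w) - y).pow(2).mean()
    (gw,) = torch.autograd.grad(loss, (w,))
    return loss, gw

w = torch.randn(4, 4, requires_grad=True)
x, y = torch.randn(8, 4), torch.randn(8, 4)

# Trace forward + backward into one flat graph of primitive forward ops.
fx_g = make_fx(train_step, decomposition_table=core_aten_decompositions())(w, x, y)
```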
The resulting `fx_g` is a `GraphModule` containing only primitive forward ops that torch-mlir already knows how to lower.
Blockers
1. No `make_fx` import path

The FX importer (`fx_importer.py`) expects `torch.export.ExportedProgram`. `make_fx` produces a `torch.fx.GraphModule` — a different format. Options:

- Convert the `GraphModule` into an `ExportedProgram`: PyTorch has internal utilities for this, but they're not stable public API
- Support `GraphModule` import directly: add a parallel import path in the FX importer that handles `GraphModule` without requiring `ExportedProgram` metadata (graph signatures, range constraints, etc.)
- Use `aot_export_joint_simple`: this functorch API traces the joint forward+backward graph and may produce output closer to what the importer expects
2. `torch.export` is inference-only

`torch.export.export()` does not capture backward graphs. There is an `export_for_training()` API, but it is deprecated and produces a forward-only graph with autograd metadata — not a traced backward pass. The standard way to get a joint forward+backward graph is `make_fx` or `aot_autograd`, which are lower-level functorch APIs.
3. Parameter mutation / weight updates

A full training loop includes `params = params - lr * grads`. This is trivial arithmetic, but the pipeline needs to handle it:

- The traced graph takes parameters as inputs and returns updated parameters (functional style)
- This is naturally how `make_fx` traces work — all state is explicit

4. Missing backward op lowerings (small set)
Only 2-3 irreducible backward ops need lowering support that doesn't fully exist:

- `avg_pool2d_backward` — no lowering currently
- `_scaled_dot_product_flash_attention_for_cpu_backward` — no lowering (but SDPA decomposes fully anyway with `core_aten_decompositions`)
- `convolution_backward` — already has a TorchToLinalg lowering and a decomposition in `DecomposeComplexOps.cpp`

5. Dynamic shapes in training
Training graphs have more dynamic shape requirements than inference (e.g., varying batch sizes and sequence lengths).
Approach
Phase 1: Import `make_fx` training graphs

- Add `GraphModule` import support to the FX importer (or a thin wrapper that converts `GraphModule` → `ExportedProgram`)
- Apply `core_aten_decompositions()` during tracing to eliminate backward ops
- Feed the resulting graphs through the existing backend pipelines
Phase 2: Fill remaining op gaps

- Add an `avg_pool2d_backward` lowering to TorchToLinalg
- Verify `convolution_backward` works end-to-end (the lowering exists but may have edge cases — some conv backward tests are marked as hanging in xfail sets)
Phase 3: Training loop integration

- Expose a functional training-step interface `(params, data) → (updated_params, metrics)` where all state is explicit
- Handle optimizer update operations
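A minimal sketch of that functional training-step contract — a hypothetical interface written in plain PyTorch, not an existing torch-mlir API:

```python
import torch

torch.manual_seed(0)

def train_step(params, batch, lr=0.1):
    # (params, data) -> (updated_params, metrics), with all state explicit.
    x, y = batch
    (w,) = params
    loss = ((x @ w) - y).pow(2).mean()
    (gw,) = torch.autograd.grad(loss, (w,))
    new_w = (w - lr * gw).detach().requires_grad_(True)  # functional update
    return (new_w,), {"loss": loss.item()}

w = torch.randn(4, 4, requires_grad=True)
batch = (torch.randn(16, 4), torch.randn(16, 4))
params, m1 = train_step((w,), batch)
params, m2 = train_step(params, batch)
```

Because nothing is mutated in place, the whole step is a pure function of its inputs — exactly the shape `make_fx` traces produce.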
What this unlocks

- Compiled training steps, where backends can fuse across the forward/backward boundary
- Quantization-aware training — PT2E inserts fake_quant ops into the training graph

What this does NOT require

- Hand-written lowerings for dozens of backward ops
- Autograd infrastructure in MLIR or custom backward logic
4. Remove PT1 Legacy Code
Impact: High | Effort: Low-Medium
The best effort-to-impact ratio of any project here. Removing 63K+ lines of
deprecated, unmaintained code immediately simplifies the build, CI, and
contributor experience.
Problem
The PT1 (PyTorch 1.x) infrastructure — TorchScript import, JIT IR importer, and Lazy Tensor Core backend — is officially deprecated and unmaintained. The CMakeLists.txt itself says as much. Yet the code remains:
- `projects/pt1/`, `projects/ltc/`, `projects/jit_ir_common/`
- The Lazy Tensor Core backend
- The `TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS`, `TORCH_MLIR_ENABLE_JIT_IR_IMPORTER`, and `TORCH_MLIR_ENABLE_LTC` build flags

Approach
First, audit what lives under `projects/pt1/` but is still needed:

- `abstract_interp_lib_gen.py` and ODS update scripts in `projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/`
- Shared infrastructure (`framework.py`, `registry.py`) if used by non-PT1 paths

Then delete:

- `projects/pt1/`, `projects/ltc/`, `projects/jit_ir_common/`
- The `TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS`, `TORCH_MLIR_ENABLE_JIT_IR_IMPORTER`, and `TORCH_MLIR_ENABLE_LTC` flags and associated CMake logic
- The `torchscript.compile()` API

What this unlocks
5. Stride and Memory Format Handling
Impact: Medium-High | Effort: High
Important for correctness and performance of channels-last models, but requires
cross-cutting architectural changes to the type system, importer, and all
backend lowerings.
Problem
Tensor strides are completely dropped during FX import. The Torch dialect type system (`!torch.vtensor`) only carries `sizes`, `dtype`, and `sparsity` — there is no representation for memory layout or strides.
This means channels-last models lose their memory-format information, and layout-sensitive lowerings cannot be generated.

A TODO at `fx_importer.py:1069` acknowledges this. The data is available — `torch.export.ExportedProgram` captures full stride information in FakeTensor metadata; the importer just doesn't use it.
Approach

- Extend `TorchTypes.td` with an optional layout/stride attribute on `!torch.vtensor`
- Populate it in the FX importer from FakeTensor metadata (the data is already there in `tensor.stride()`)
- Honor explicit memory-format ops during lowering (e.g., `torch.prim.contiguous` coercion)
- Support symbolic strides for dynamic shapes (infrastructure for `torch.bind_symbolic_shape` already exists)

What this unlocks
6. Shape Inference Improvements
Impact: Medium | Effort: Medium-High
The current shape inference system works for most models but has known
limitations that affect edge cases and compilation efficiency.
Problem
The shape inference system has several compounding limitations:
Mandatory inlining (`Passes.cpp:63-66`): all functions must be inlined before shape inference, bloating IR and preventing modular compilation.

Data-dependent shapes use a hack (`abstract_interp_lib_gen.py`): ops like `nonzero`, `unique`, `masked_select`, and `bincount` cannot have output shapes inferred.
Meta tensor limitations: ~9 shape/dtype functions cannot run on the Meta
backend and require CPU workarounds. The codebase has TODOs noting this should be
fixed by migrating to FakeTensor.
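A sketch of the FakeTensor-based approach that migration targets — shapes and dtypes propagate through real op implementations without allocating or computing any data:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Inside the mode, factory functions produce FakeTensors; ops on them
# run only the meta/shape logic, never touching real memory.
with FakeTensorMode():
    a = torch.empty(4, 8, dtype=torch.float32)
    b = torch.empty(8, 16, dtype=torch.float32)
    out = a @ b  # shape/dtype inference with no real computation
```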
Potential improvements (by impact)

- Run shape inference across function boundaries without inlining — reduces IR size for models with many submodules
- Migrate to FakeTensor-based shape/dtype inference (fixes 9+ known bugs)
- Infer bounds for data-dependent dimensions rather than giving up entirely
- Extend shape inference into `prim::Loop` and `prim::If` bodies

7. TorchOnnxToTorch Maintainability
Impact: Medium | Effort: Medium
Internal code quality improvement. Important for long-term sustainability but
does not directly add new capabilities for users.
Problem
The ONNX-to-Torch conversion implements ~200 ops as ad-hoc lambda callbacks across 4 monolithic files (the largest being ~5,000 lines). Key issues:

- Duplicated handling of common ONNX attributes (e.g., `auto_pad` in convolutions)
- Repeated type-extraction and conversion boilerplate
- Sparse `notifyMatchFailure()` messages make debugging difficult

Potential improvements
- Factor common patterns (padding, reduction, broadcasting) into reusable helpers
- Split the monolithic files by op category (as parts of the codebase already do)
- Organize conversion patterns per opset version