Skip to content

Commit 90fc215

Browse files
jtuyls and claude authored
[ScalarizeShapes] Fold select.int through cat in shape computations (#4513)
Extend `getListFromTensor` to recurse into the operands of `aten.cat`, and add a folding pattern for `aten.select.int`. Together these resolve individual shape elements from the concat-based shape tensors used by the `onnx.Reshape` lowering.

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 613c355 commit 90fc215

2 files changed

Lines changed: 171 additions & 4 deletions

File tree

lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp

Lines changed: 93 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,26 @@ LogicalResult getListFromTensor(Value value, SmallVector<OpFoldResult> &vals) {
109109
return success();
110110
}
111111

112+
// aten.cat of 1D tensors: recurse into each element.
113+
if (auto catOp = value.getDefiningOp<Torch::AtenCatOp>()) {
114+
int64_t catDim;
115+
if (matchPattern(catOp.getDim(), m_TorchConstantInt(&catDim)) &&
116+
catDim == 0) {
117+
SmallVector<Value> tensors;
118+
if (succeeded(getListOperands(catOp.getTensors(), tensors))) {
119+
SmallVector<OpFoldResult> catElements;
120+
if (llvm::all_of(tensors,
121+
[&](Value t) {
122+
return succeeded(getListFromTensor(t, catElements));
123+
}) &&
124+
(int64_t)catElements.size() <= kMaxFold) {
125+
vals.append(catElements.begin(), catElements.end());
126+
return success();
127+
}
128+
}
129+
}
130+
}
131+
112132
// Last supported case: ValueTensorLiteralOp
113133
auto literalOp = value.getDefiningOp<Torch::ValueTensorLiteralOp>();
114134
if (!literalOp)
@@ -357,6 +377,74 @@ class PropagateAtenIndexSelectPattern
357377
};
358378
} // namespace
359379

380+
namespace {
381+
// Fold `aten.select.int(1d_tensor, 0, const_idx)` by extracting the i-th
382+
// scalar element via getListFromTensor (which handles literals, unsqueeze,
383+
// NumToTensor, cat, etc.).
384+
class PropagateAtenSelectIntPattern : public OpRewritePattern<AtenSelectIntOp> {
385+
public:
386+
using OpRewritePattern<AtenSelectIntOp>::OpRewritePattern;
387+
LogicalResult matchAndRewrite(AtenSelectIntOp op,
388+
PatternRewriter &rewriter) const override {
389+
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
390+
391+
int64_t dim;
392+
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
393+
return rewriter.notifyMatchFailure(op, "requires a constant dim");
394+
395+
int64_t idx;
396+
if (!matchPattern(op.getIndex(), m_TorchConstantInt(&idx)))
397+
return rewriter.notifyMatchFailure(op, "requires a constant index");
398+
399+
auto selfTy = cast<BaseTensorType>(op.getSelf().getType());
400+
if (!selfTy.hasSizes() || selfTy.getSizes().size() != 1)
401+
return rewriter.notifyMatchFailure(op, "expected 1D input");
402+
403+
int64_t selfRank = selfTy.getSizes().size();
404+
dim = toPositiveDim(dim, selfRank);
405+
if (!isValidDim(dim, selfRank))
406+
return rewriter.notifyMatchFailure(op, "invalid dim");
407+
408+
int64_t dimLength = selfTy.getSizes()[dim];
409+
if (dimLength == kUnknownSize)
410+
return rewriter.notifyMatchFailure(op, "unknown dim length");
411+
412+
idx = toPositiveDim(idx, dimLength);
413+
if (!isValidDim(idx, dimLength))
414+
return rewriter.notifyMatchFailure(op, "invalid index");
415+
416+
SmallVector<OpFoldResult> elements;
417+
if (failed(getListFromTensor(op.getSelf(), elements)) ||
418+
idx >= (int64_t)elements.size())
419+
return rewriter.notifyMatchFailure(op, "cannot decompose source tensor");
420+
421+
SmallVector<Value, 1> materialized;
422+
SmallVector<OpFoldResult, 1> single = {elements[idx]};
423+
if (failed(materializeFolds(b, single, materialized)))
424+
return failure();
425+
426+
// `prim.NumToTensor.Scalar`'s shape function returns rank-0, so build it
427+
// with a rank-0 result type. If the original `aten.select.int` produced a
428+
// rank-1 `[1]` tensor (as in ONNX→Torch lowerings of `onnx.Gather`),
429+
// unsqueeze back to match. The existing `getListFromTensor` already folds
430+
// through `unsqueeze(NumToTensor(scalar))`, so downstream propagation
431+
// patterns still see straight through the replacement.
432+
auto resultTy = cast<ValueTensorType>(op.getType());
433+
auto rank0Ty = rewriter.getType<Torch::ValueTensorType>(
434+
ArrayRef<int64_t>({}), resultTy.getDtype());
435+
Value rank0 =
436+
PrimNumToTensorScalarOp::create(b, rank0Ty, materialized.front());
437+
Value result = rank0;
438+
if (!resultTy.hasSizes() || !resultTy.getSizes().empty()) {
439+
Value zero = Torch::ConstantIntOp::create(b, 0);
440+
result = AtenUnsqueezeOp::create(b, resultTy, rank0, zero);
441+
}
442+
rewriter.replaceOp(op, result);
443+
return success();
444+
}
445+
};
446+
} // namespace
447+
360448
namespace {
361449
// Conversion attempts to handle some common propagatable slice cases, namely
362450
// splatted values, no-op slices, known list of values, or any case where a
@@ -1507,10 +1595,11 @@ void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
15071595
// are positive so floor divide should be a sufficient scalar replacement.
15081596
patterns.insert<
15091597
PropagateAtenCatPattern, PropagateAtenIndexSelectPattern,
1510-
PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
1511-
PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern,
1512-
PropagateAtenWhereSelfPattern, PropagateAtenBroadcastToPattern,
1513-
PropagateAtenTransposeIntPattern, PropagateAtenToDtypePattern,
1598+
PropagateAtenSelectIntPattern, PropagateAtenItemPattern,
1599+
PropagateAtenShapeToTensorPattern, PropagateAtenSliceTensorPattern,
1600+
PropagateAtenEqTensorPattern, PropagateAtenWhereSelfPattern,
1601+
PropagateAtenBroadcastToPattern, PropagateAtenTransposeIntPattern,
1602+
PropagateAtenToDtypePattern,
15141603
PropagateAtenUnaryPattern<AtenNegOp, AtenNegIntOp>,
15151604
PropagateAtenArithmeticPattern<AtenAddTensorOp, AtenAddIntOp>,
15161605
PropagateAtenArithmeticPattern<AtenSubTensorOp, AtenSubIntOp>,

test/Dialect/Torch/scalarize-shapes.mlir

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -709,3 +709,81 @@ func.func @transpose$prop_3d_m1_0(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !
709709
%12 = torch.prim.ListConstruct %11 : (!torch.int) -> !torch.list<int>
710710
return %7 : !torch.vtensor<[2,2,2],si64>
711711
}
712+
713+
// -----
714+
715+
// select.int on a cat of three constant tensors and one dynamic scalar.
// Each select.int + item pair reads one position (0..3) of the 4-element
// cat. After scalarization the reshape's size list is expected to be built
// directly from the scalars: the constants 1, 16, 128 and the dynamic %arg1
// in position 1.
716+
// CHECK-LABEL: @select_int_from_cat_fold
717+
func.func @select_int_from_cat_fold(%arg0: !torch.vtensor<[1,?,2048],f16>, %arg1: !torch.int) -> !torch.vtensor<[?,?,?,?],f16> {
718+
// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
719+
// CHECK-DAG: %[[INT16:.*]] = torch.constant.int 16
720+
// CHECK-DAG: %[[INT128:.*]] = torch.constant.int 128
721+
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[INT1]], %arg1, %[[INT16]], %[[INT128]]
722+
// CHECK: %[[RESULT:.*]] = torch.aten.reshape %arg0, %[[LIST]]
723+
// CHECK: return %[[RESULT]]
724+
%int0 = torch.constant.int 0
725+
%int1 = torch.constant.int 1
726+
%int2 = torch.constant.int 2
727+
%int3 = torch.constant.int 3
728+
// Constant one-element pieces of the shape tensor.
%c1 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
729+
%c16 = torch.vtensor.literal(dense<16> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
730+
%c128 = torch.vtensor.literal(dense<128> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
731+
// Dynamic piece: the scalar %arg1 turned into a rank-0 tensor, then
// unsqueezed to shape [1].
%dyn = torch.prim.NumToTensor.Scalar %arg1 : !torch.int -> !torch.vtensor<[],si64>
732+
%dyn_unsq = torch.aten.unsqueeze %dyn, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
733+
%list = torch.prim.ListConstruct %c1, %dyn_unsq, %c16, %c128 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
734+
%cat = torch.aten.cat %list, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4],si64>
735+
// Read back every position of the concatenated tensor as a scalar.
%s0 = torch.aten.select.int %cat, %int0, %int0 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
736+
%d0 = torch.aten.item %s0 : !torch.vtensor<[1],si64> -> !torch.int
737+
%s1 = torch.aten.select.int %cat, %int0, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
738+
%d1 = torch.aten.item %s1 : !torch.vtensor<[1],si64> -> !torch.int
739+
%s2 = torch.aten.select.int %cat, %int0, %int2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
740+
%d2 = torch.aten.item %s2 : !torch.vtensor<[1],si64> -> !torch.int
741+
%s3 = torch.aten.select.int %cat, %int0, %int3 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
742+
%d3 = torch.aten.item %s3 : !torch.vtensor<[1],si64> -> !torch.int
743+
%shape = torch.prim.ListConstruct %d0, %d1, %d2, %d3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
744+
%result = torch.aten.reshape %arg0, %shape : !torch.vtensor<[1,?,2048],f16>, !torch.list<int> -> !torch.vtensor<[?,?,?,?],f16>
745+
return %result : !torch.vtensor<[?,?,?,?],f16>
746+
}
747+
748+
// -----
749+
750+
// select.int with a negative index. The cat holds three elements
// (1, %arg0, 128); index -1 addresses the last one, so the returned list is
// expected to contain only the constant 128.
751+
// CHECK-LABEL: @select_int_negative_index
752+
func.func @select_int_negative_index(%arg0: !torch.int) -> !torch.list<int> {
753+
// CHECK-DAG: %[[INT128:.*]] = torch.constant.int 128
754+
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[INT128]]
755+
// CHECK: return %[[LIST]]
756+
%int0 = torch.constant.int 0
757+
%int_neg1 = torch.constant.int -1
758+
%c1 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
759+
%c128 = torch.vtensor.literal(dense<128> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
760+
// Dynamic middle element: %arg0 as a rank-0 tensor, unsqueezed to shape [1].
%dyn = torch.prim.NumToTensor.Scalar %arg0 : !torch.int -> !torch.vtensor<[],si64>
761+
%dyn_unsq = torch.aten.unsqueeze %dyn, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
762+
%list = torch.prim.ListConstruct %c1, %dyn_unsq, %c128 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
763+
%cat = torch.aten.cat %list, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
764+
// Index -1 along dim 0 of the 3-element tensor.
%sel = torch.aten.select.int %cat, %int0, %int_neg1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
765+
%result = torch.aten.item %sel : !torch.vtensor<[1],si64> -> !torch.int
766+
%shape = torch.prim.ListConstruct %result : (!torch.int) -> !torch.list<int>
767+
return %shape : !torch.list<int>
768+
}
769+
770+
// -----
771+
772+
// select.int on a cat whose first operand contributes more than one element.
// cat([vtensor<[2]> = [10, 42], vtensor<[1]> = [99]]) has shape [3];
// selecting index 1 lands inside the first operand, so the returned list is
// expected to contain only the constant 42.
774+
// CHECK-LABEL: @select_int_multi_element_subtensor
775+
func.func @select_int_multi_element_subtensor() -> !torch.list<int> {
776+
// CHECK-DAG: %[[INT42:.*]] = torch.constant.int 42
777+
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[INT42]]
778+
// CHECK: return %[[LIST]]
779+
%int0 = torch.constant.int 0
780+
%int1 = torch.constant.int 1
781+
// Two-element literal followed by a one-element literal.
%c = torch.vtensor.literal(dense<[10, 42]> : tensor<2xsi64>) : !torch.vtensor<[2],si64>
782+
%c2 = torch.vtensor.literal(dense<99> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
783+
%list = torch.prim.ListConstruct %c, %c2 : (!torch.vtensor<[2],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
784+
%cat = torch.aten.cat %list, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
785+
%sel = torch.aten.select.int %cat, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
786+
%result = torch.aten.item %sel : !torch.vtensor<[1],si64> -> !torch.int
787+
%shape = torch.prim.ListConstruct %result : (!torch.int) -> !torch.list<int>
788+
return %shape : !torch.list<int>
789+
}

0 commit comments

Comments
 (0)