@@ -709,3 +709,81 @@ func.func @transpose$prop_3d_m1_0(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !
709709 %12 = torch.prim.ListConstruct %11 : (!torch.int ) -> !torch.list <int >
710710 return %7 : !torch.vtensor <[2 ,2 ,2 ],si64 >
711711}
712+
713+ // -----
714+
715+ // select.int on cat of constants and dynamic — folds constant elements.
716+ // CHECK-LABEL: @select_int_from_cat_fold
717+ func.func @select_int_from_cat_fold (%arg0: !torch.vtensor <[1 ,?,2048 ],f16 >, %arg1: !torch.int ) -> !torch.vtensor <[?,?,?,?],f16 > {
718+ // CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
719+ // CHECK-DAG: %[[INT16:.*]] = torch.constant.int 16
720+ // CHECK-DAG: %[[INT128:.*]] = torch.constant.int 128
721+ // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[INT1]], %arg1, %[[INT16]], %[[INT128]]
722+ // CHECK: %[[RESULT:.*]] = torch.aten.reshape %arg0, %[[LIST]]
723+ // CHECK: return %[[RESULT]]
724+ %int0 = torch.constant.int 0
725+ %int1 = torch.constant.int 1
726+ %int2 = torch.constant.int 2
727+ %int3 = torch.constant.int 3
728+ %c1 = torch.vtensor.literal (dense <1 > : tensor <1 xsi64 >) : !torch.vtensor <[1 ],si64 >
729+ %c16 = torch.vtensor.literal (dense <16 > : tensor <1 xsi64 >) : !torch.vtensor <[1 ],si64 >
730+ %c128 = torch.vtensor.literal (dense <128 > : tensor <1 xsi64 >) : !torch.vtensor <[1 ],si64 >
731+ %dyn = torch.prim.NumToTensor.Scalar %arg1 : !torch.int -> !torch.vtensor <[],si64 >
732+ %dyn_unsq = torch.aten.unsqueeze %dyn , %int0 : !torch.vtensor <[],si64 >, !torch.int -> !torch.vtensor <[1 ],si64 >
733+ %list = torch.prim.ListConstruct %c1 , %dyn_unsq , %c16 , %c128 : (!torch.vtensor <[1 ],si64 >, !torch.vtensor <[1 ],si64 >, !torch.vtensor <[1 ],si64 >, !torch.vtensor <[1 ],si64 >) -> !torch.list <vtensor >
734+ %cat = torch.aten.cat %list , %int0 : !torch.list <vtensor >, !torch.int -> !torch.vtensor <[4 ],si64 >
735+ %s0 = torch.aten.select.int %cat , %int0 , %int0 : !torch.vtensor <[4 ],si64 >, !torch.int , !torch.int -> !torch.vtensor <[1 ],si64 >
736+ %d0 = torch.aten.item %s0 : !torch.vtensor <[1 ],si64 > -> !torch.int
737+ %s1 = torch.aten.select.int %cat , %int0 , %int1 : !torch.vtensor <[4 ],si64 >, !torch.int , !torch.int -> !torch.vtensor <[1 ],si64 >
738+ %d1 = torch.aten.item %s1 : !torch.vtensor <[1 ],si64 > -> !torch.int
739+ %s2 = torch.aten.select.int %cat , %int0 , %int2 : !torch.vtensor <[4 ],si64 >, !torch.int , !torch.int -> !torch.vtensor <[1 ],si64 >
740+ %d2 = torch.aten.item %s2 : !torch.vtensor <[1 ],si64 > -> !torch.int
741+ %s3 = torch.aten.select.int %cat , %int0 , %int3 : !torch.vtensor <[4 ],si64 >, !torch.int , !torch.int -> !torch.vtensor <[1 ],si64 >
742+ %d3 = torch.aten.item %s3 : !torch.vtensor <[1 ],si64 > -> !torch.int
743+ %shape = torch.prim.ListConstruct %d0 , %d1 , %d2 , %d3 : (!torch.int , !torch.int , !torch.int , !torch.int ) -> !torch.list <int >
744+ %result = torch.aten.reshape %arg0 , %shape : !torch.vtensor <[1 ,?,2048 ],f16 >, !torch.list <int > -> !torch.vtensor <[?,?,?,?],f16 >
745+ return %result : !torch.vtensor <[?,?,?,?],f16 >
746+ }
747+
748+ // -----
749+
750+ // select.int with negative index — selects last element.
751+ // CHECK-LABEL: @select_int_negative_index
752+ func.func @select_int_negative_index (%arg0: !torch.int ) -> !torch.list <int > {
753+ // CHECK-DAG: %[[INT128:.*]] = torch.constant.int 128
754+ // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[INT128]]
755+ // CHECK: return %[[LIST]]
756+ %int0 = torch.constant.int 0
757+ %int_neg1 = torch.constant.int -1
758+ %c1 = torch.vtensor.literal (dense <1 > : tensor <1 xsi64 >) : !torch.vtensor <[1 ],si64 >
759+ %c128 = torch.vtensor.literal (dense <128 > : tensor <1 xsi64 >) : !torch.vtensor <[1 ],si64 >
760+ %dyn = torch.prim.NumToTensor.Scalar %arg0 : !torch.int -> !torch.vtensor <[],si64 >
761+ %dyn_unsq = torch.aten.unsqueeze %dyn , %int0 : !torch.vtensor <[],si64 >, !torch.int -> !torch.vtensor <[1 ],si64 >
762+ %list = torch.prim.ListConstruct %c1 , %dyn_unsq , %c128 : (!torch.vtensor <[1 ],si64 >, !torch.vtensor <[1 ],si64 >, !torch.vtensor <[1 ],si64 >) -> !torch.list <vtensor >
763+ %cat = torch.aten.cat %list , %int0 : !torch.list <vtensor >, !torch.int -> !torch.vtensor <[3 ],si64 >
764+ %sel = torch.aten.select.int %cat , %int0 , %int_neg1 : !torch.vtensor <[3 ],si64 >, !torch.int , !torch.int -> !torch.vtensor <[1 ],si64 >
765+ %result = torch.aten.item %sel : !torch.vtensor <[1 ],si64 > -> !torch.int
766+ %shape = torch.prim.ListConstruct %result : (!torch.int ) -> !torch.list <int >
767+ return %shape : !torch.list <int >
768+ }
769+
770+ // -----
771+
772+ // select.int on cat with multi-element sub-tensor.
773+ // cat([vtensor<[2]>, vtensor<[1]>]) produces [3], select at index 1.
774+ // CHECK-LABEL: @select_int_multi_element_subtensor
775+ func.func @select_int_multi_element_subtensor () -> !torch.list <int > {
776+ // CHECK-DAG: %[[INT42:.*]] = torch.constant.int 42
777+ // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[INT42]]
778+ // CHECK: return %[[LIST]]
779+ %int0 = torch.constant.int 0
780+ %int1 = torch.constant.int 1
781+ %c = torch.vtensor.literal (dense <[10 , 42 ]> : tensor <2 xsi64 >) : !torch.vtensor <[2 ],si64 >
782+ %c2 = torch.vtensor.literal (dense <99 > : tensor <1 xsi64 >) : !torch.vtensor <[1 ],si64 >
783+ %list = torch.prim.ListConstruct %c , %c2 : (!torch.vtensor <[2 ],si64 >, !torch.vtensor <[1 ],si64 >) -> !torch.list <vtensor >
784+ %cat = torch.aten.cat %list , %int0 : !torch.list <vtensor >, !torch.int -> !torch.vtensor <[3 ],si64 >
785+ %sel = torch.aten.select.int %cat , %int0 , %int1 : !torch.vtensor <[3 ],si64 >, !torch.int , !torch.int -> !torch.vtensor <[1 ],si64 >
786+ %result = torch.aten.item %sel : !torch.vtensor <[1 ],si64 > -> !torch.int
787+ %shape = torch.prim.ListConstruct %result : (!torch.int ) -> !torch.list <int >
788+ return %shape : !torch.list <int >
789+ }
0 commit comments