@@ -16,14 +16,18 @@ def setUp(self):
1616 torch ._dynamo .reset ()
1717
1818 class MeanDim (torch .nn .Module ):
19- def __init__ (self , dims , keepdim = True ):
19+ def __init__ (self , dims , keepdim = True , dtype = None ):
2020 super ().__init__ ()
2121 self .dims = dims
2222 self .keepdim = keepdim
23+ self .dtype = dtype
2324
2425 def forward (self , x ):
2526 y = x + x
26- z = torch .mean (y , self .dims , keepdim = self .keepdim )
27+ if self .dtype is None :
28+ z = torch .mean (y , self .dims , keepdim = self .keepdim )
29+ else :
30+ z = torch .mean (y , self .dims , keepdim = self .keepdim , dtype = self .dtype )
2731 return z
2832
2933 def _test_mean_dim (self , inputs , dims = (- 1 , - 2 )):
@@ -88,6 +92,16 @@ def test_fp32_mean_dim_unsupported_keepdim_false(self):
8892 .check_count ({"executorch_exir_dialects_edge__ops_aten_mean_dim" : 1 })
8993 )
9094
95+ def test_fp32_mean_dim_unsupported_dtype (self ):
96+ inputs = (torch .randn (1 , 5 , 4 , 4 ),)
97+ (
98+ Tester (self .MeanDim ((- 1 , - 2 ), dtype = torch .float64 ), inputs )
99+ .export ()
100+ .check_count ({"torch.ops.aten.mean.dim" : 1 })
101+ .to_edge_transform_and_lower ()
102+ .check_count ({"executorch_exir_dialects_edge__ops_aten_mean_dim" : 1 })
103+ )
104+
91105 def test_qs8_mean_dim (self ):
92106 inputs = (torch .randn (1 , 5 , 4 , 4 ),)
93107 (
0 commit comments