Skip to content

Commit 0489cf3

Browse files
committed
Resolve Co-pilot review comments
1 parent 2c1369e commit 0489cf3

4 files changed

Lines changed: 48 additions & 10 deletions

File tree

backends/xnnpack/operators/op_mean_dim.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,13 @@ def define_node(
5858
len(input_shape) == 4, "Require input to mean.dim be 4 dimensional"
5959
)
6060

61+
# This visitor serializes mean.dim as Global Average Pooling, which has
62+
# no field for an explicit dtype override.
63+
check_or_raise(
64+
node.kwargs.get("dtype") is None,
65+
"XNNPACK does not support mean.dim with dtype",
66+
)
67+
6168
# mean dims
6269
mean_dims = normalize_mean_dims(node.args[1], len(input_shape))
6370
check_or_raise(

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -517,20 +517,30 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
517517
return False
518518

519519
input_rank = get_input_node(node, 0).meta["val"].dim()
520-
keepdim = len(node.args) >= 3 and bool(node.args[2])
521-
dims = normalize_mean_dims(node.args[1], input_rank)
522-
523-
if sorted(dims) != [2, 3]:
520+
if input_rank != 4:
524521
why(
525522
node,
526-
reason="mean.dim only supports averaging 4D tensors across the innermost dimensions",
523+
reason=f"mean.dim only supports averaging 4D tensors, got tensor of rank {input_rank}",
527524
)
528525
return False
529526

530-
if input_rank != 4:
527+
# This path lowers mean.dim to XNNPACK Global Average Pooling, which
528+
# cannot encode an explicit dtype override.
529+
if node.kwargs.get("dtype") is not None:
530+
why(node, reason="mean.dim does not support dtype")
531+
return False
532+
533+
keepdim = len(node.args) >= 3 and bool(node.args[2])
534+
try:
535+
dims = normalize_mean_dims(node.args[1], input_rank)
536+
except ValueError as error:
537+
why(node, reason=f"mean.dim has invalid dims: {error}")
538+
return False
539+
540+
if sorted(dims) != [2, 3]:
531541
why(
532542
node,
533-
reason=f"mean.dim only supports averaging 4D tensors, got tensor of rank {input_rank}",
543+
reason="mean.dim only supports averaging 4D tensors across the innermost dimensions",
534544
)
535545
return False
536546

backends/xnnpack/test/ops/test_mean_dim.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,18 @@ def setUp(self):
1616
torch._dynamo.reset()
1717

1818
class MeanDim(torch.nn.Module):
19-
def __init__(self, dims, keepdim=True):
19+
def __init__(self, dims, keepdim=True, dtype=None):
2020
super().__init__()
2121
self.dims = dims
2222
self.keepdim = keepdim
23+
self.dtype = dtype
2324

2425
def forward(self, x):
2526
y = x + x
26-
z = torch.mean(y, self.dims, keepdim=self.keepdim)
27+
if self.dtype is None:
28+
z = torch.mean(y, self.dims, keepdim=self.keepdim)
29+
else:
30+
z = torch.mean(y, self.dims, keepdim=self.keepdim, dtype=self.dtype)
2731
return z
2832

2933
def _test_mean_dim(self, inputs, dims=(-1, -2)):
@@ -88,6 +92,16 @@ def test_fp32_mean_dim_unsupported_keepdim_false(self):
8892
.check_count({"executorch_exir_dialects_edge__ops_aten_mean_dim": 1})
8993
)
9094

95+
def test_fp32_mean_dim_unsupported_dtype(self):
96+
inputs = (torch.randn(1, 5, 4, 4),)
97+
(
98+
Tester(self.MeanDim((-1, -2), dtype=torch.float64), inputs)
99+
.export()
100+
.check_count({"torch.ops.aten.mean.dim": 1})
101+
.to_edge_transform_and_lower()
102+
.check_count({"executorch_exir_dialects_edge__ops_aten_mean_dim": 1})
103+
)
104+
91105
def test_qs8_mean_dim(self):
92106
inputs = (torch.randn(1, 5, 4, 4),)
93107
(

backends/xnnpack/utils/utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,18 @@ def get_input_node(node: torch.fx.Node, input_index: int) -> torch.fx.Node:
6060

6161
def normalize_mean_dims(mean_dims: Sequence[int] | int | None, rank: int) -> List[int]:
6262
"""Return mean dims as non-negative indices for the given rank."""
63+
if rank <= 0:
64+
raise ValueError(f"Expected rank > 0, got {rank}")
6365
if mean_dims is None:
6466
return list(range(rank))
6567
if isinstance(mean_dims, int):
6668
mean_dims = [mean_dims]
67-
return [dim % rank for dim in mean_dims]
69+
normalized_dims = []
70+
for dim in mean_dims:
71+
if dim < -rank or dim >= rank:
72+
raise ValueError(f"Dimension out of range: {dim} for rank {rank}")
73+
normalized_dims.append(dim % rank)
74+
return normalized_dims
6875

6976

7077
def get_relu_fused_node(node: torch.fx.Node) -> Optional[torch.fx.Node]:

0 commit comments

Comments
 (0)