We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 78ec2b4 commit cb08a9eCopy full SHA for cb08a9e
1 file changed
backends/arm/_passes/dim_maps.py
@@ -525,8 +525,8 @@ def _build_groups(
525
!= Counter(factor.key for factor in target_factors)
526
):
527
return None
528
- source_factors = cast(list[_Factor], source_factors)
529
- target_factors = cast(list[_Factor], target_factors)
+ source_factors = source_factors
+ target_factors = target_factors
530
531
# Compute prime factor permutation between input and output shapes
532
factor_count = len(source_factors)
@@ -635,6 +635,7 @@ def __init__(self, permute_node: Node) -> None:
635
if normalized is None:
636
raise ValueError(f"Invalid permute dims: {permute_dims}")
637
self.permute_dims = normalized
638
+
639
def map_dims(self, dims: int | Sequence[int]) -> list[int]:
640
"""Computes mapped dims s.t.
641
0 commit comments