Skip to content

Commit 78ec2b4

Browse files
Apply suggestions from code review
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
1 parent 538ce0a commit 78ec2b4

1 file changed

Lines changed: 19 additions & 12 deletions

File tree

backends/arm/_passes/dim_maps.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ class ViewMap:
196196
Additional conditions apply for the map being valid depending on if the mapped dim
197197
is a reduction operator or a permutation operator, as described in the respective methods.
198198
199-
SymInts are partialy supported by factorizing them as single primes as the true
199+
SymInts are partially supported by factorizing them as single primes as the true
200200
value is not known, causing potentially fewer valid mappings.
201201
202202
"""
@@ -515,21 +515,24 @@ def _build_groups(
515515
) -> _ViewGroups | None:
516516
"""Build source/target axis groups from ordered prime factors."""
517517

518-
# Compute orderd prime factorizations of input and output shapes
518+
# Compute ordered prime factorizations of input and output shapes
519519
source_factors = _factor_shape(source_shape)
520520
target_factors = _factor_shape(target_shape)
521-
assert (
522-
source_factors is not None
523-
and (target_factors is not None)
524-
and Counter(factor.key for factor in source_factors)
525-
== Counter(factor.key for factor in target_factors)
526-
), "Invalid view shapes"
521+
if (
522+
source_factors is None
523+
or target_factors is None
524+
or Counter(factor.key for factor in source_factors)
525+
!= Counter(factor.key for factor in target_factors)
526+
):
527+
return None
528+
source_factors = cast(list[_Factor], source_factors)
529+
target_factors = cast(list[_Factor], target_factors)
527530

528531
# Compute prime factor permutation between input and output shapes
529532
factor_count = len(source_factors)
530533
permutation = cls._find_permutation(source_factors, target_factors)
531-
assert permutation is not None, "Invalid view shapes"
532-
534+
if permutation is None:
535+
return None
533536
# Find groups of factors that must be mapped together to preserve view equivalence
534537
union_find = _UnionFind(factor_count)
535538
cls._union_factors_sharing_axes(
@@ -626,8 +629,12 @@ def __init__(self, permute_node: Node) -> None:
626629
assert isinstance(permute_dims, Sequence) and not isinstance(
627630
permute_dims, (str, bytes)
628631
)
629-
self.permute_dims = list(cast(Sequence[int], permute_dims))
630-
632+
normalized = _normalize_permutation(
633+
cast(Sequence[int], permute_dims), len(cast(Sequence[int], permute_dims))
634+
)
635+
if normalized is None:
636+
raise ValueError(f"Invalid permute dims: {permute_dims}")
637+
self.permute_dims = normalized
631638
def map_dims(self, dims: int | Sequence[int]) -> list[int]:
632639
"""Computes mapped dims s.t.
633640

0 commit comments

Comments
 (0)