@@ -196,7 +196,7 @@ class ViewMap:
196196 Additional conditions apply for the map being valid depending on if the mapped dim
197197 is a reduction operator or a permutation operator, as described in the respective methods.
198198
199- SymInts are partialy supported by factorizing them as single primes as the true
199+ SymInts are partially supported by factorizing them as single primes as the true
200200 value is not known, causing potentially fewer valid mappings.
201201
202202 """
@@ -515,21 +515,24 @@ def _build_groups(
515515 ) -> _ViewGroups | None :
516516 """Build source/target axis groups from ordered prime factors."""
517517
518- # Compute orderd prime factorizations of input and output shapes
518+ # Compute ordered prime factorizations of input and output shapes
519519 source_factors = _factor_shape (source_shape )
520520 target_factors = _factor_shape (target_shape )
521- assert (
522- source_factors is not None
523- and (target_factors is not None )
524- and Counter (factor .key for factor in source_factors )
525- == Counter (factor .key for factor in target_factors )
526- ), "Invalid view shapes"
521+ if (
522+ source_factors is None
523+ or target_factors is None
524+ or Counter (factor .key for factor in source_factors )
525+ != Counter (factor .key for factor in target_factors )
526+ ):
527+ return None
528+ source_factors = cast (list [_Factor ], source_factors )
529+ target_factors = cast (list [_Factor ], target_factors )
527530
528531 # Compute prime factor permutation between input and output shapes
529532 factor_count = len (source_factors )
530533 permutation = cls ._find_permutation (source_factors , target_factors )
531- assert permutation is not None , "Invalid view shapes"
532-
534+ if permutation is None :
535+ return None
533536 # Find groups of factors that must be mapped together to preserve view equivalence
534537 union_find = _UnionFind (factor_count )
535538 cls ._union_factors_sharing_axes (
@@ -626,8 +629,12 @@ def __init__(self, permute_node: Node) -> None:
626629 assert isinstance (permute_dims , Sequence ) and not isinstance (
627630 permute_dims , (str , bytes )
628631 )
629- self .permute_dims = list (cast (Sequence [int ], permute_dims ))
630-
632+ normalized = _normalize_permutation (
633+ cast (Sequence [int ], permute_dims ), len (cast (Sequence [int ], permute_dims ))
634+ )
635+ if normalized is None :
636+ raise ValueError (f"Invalid permute dims: { permute_dims } " )
637+ self .permute_dims = normalized
631638 def map_dims (self , dims : int | Sequence [int ]) -> list [int ]:
632639 """Computes mapped dims s.t.
633640
0 commit comments