11use ruff_python_ast:: name:: Name ;
2+ use rustc_hash:: FxHashSet ;
23
34use crate :: { Db , place:: PlaceAndQualifiers } ;
45
@@ -109,10 +110,39 @@ fn equality_result<'db>(
109110 let left = left. resolve_type_alias ( db) ;
110111 let right = right. resolve_type_alias ( db) ;
111112
112- if let Some ( alternatives) = equality_alternatives ( db, left) {
113+ if !finite_domain_expansion_is_bounded ( db, left, right, ComparisonOperator :: Equality ) {
114+ return ComparisonResult :: Ambiguous ;
115+ }
116+
117+ let left_alternatives = finite_alternatives ( db, left, ComparisonOperator :: Equality ) ;
118+ if left == right
119+ && let Some ( alternatives) = left_alternatives. as_deref ( )
120+ {
121+ return if alternatives. len ( ) == 1 {
122+ ComparisonResult :: AlwaysTrue
123+ } else {
124+ ComparisonResult :: Ambiguous
125+ } ;
126+ }
127+
128+ let right_alternatives = finite_alternatives ( db, right, ComparisonOperator :: Equality ) ;
129+ if let ( Some ( left_alternatives) , Some ( right_alternatives) ) =
130+ ( left_alternatives. as_deref ( ) , right_alternatives. as_deref ( ) )
131+ {
132+ return evaluate_finite_domains (
133+ db,
134+ left_alternatives,
135+ right_alternatives,
136+ is_positive,
137+ ComparisonOperator :: Equality ,
138+ equality_result,
139+ ) ;
140+ }
141+
142+ if let Some ( alternatives) = left_alternatives {
113143 return evaluate_union_left ( db, & alternatives, right, is_positive, equality_result) ;
114144 }
115- if let Some ( alternatives) = equality_alternatives ( db , right ) {
145+ if let Some ( alternatives) = right_alternatives {
116146 return evaluate_union_right ( db, left, & alternatives, is_positive, equality_result) ;
117147 }
118148
@@ -378,10 +408,39 @@ fn inequality_result<'db>(
378408 let left = left. resolve_type_alias ( db) ;
379409 let right = right. resolve_type_alias ( db) ;
380410
381- if let Some ( alternatives) = inequality_alternatives ( db, left) {
411+ if !finite_domain_expansion_is_bounded ( db, left, right, ComparisonOperator :: Inequality ) {
412+ return ComparisonResult :: Ambiguous ;
413+ }
414+
415+ let left_alternatives = finite_alternatives ( db, left, ComparisonOperator :: Inequality ) ;
416+ if left == right
417+ && let Some ( alternatives) = left_alternatives. as_deref ( )
418+ {
419+ return if alternatives. len ( ) == 1 {
420+ ComparisonResult :: AlwaysFalse
421+ } else {
422+ ComparisonResult :: Ambiguous
423+ } ;
424+ }
425+
426+ let right_alternatives = finite_alternatives ( db, right, ComparisonOperator :: Inequality ) ;
427+ if let ( Some ( left_alternatives) , Some ( right_alternatives) ) =
428+ ( left_alternatives. as_deref ( ) , right_alternatives. as_deref ( ) )
429+ {
430+ return evaluate_finite_domains (
431+ db,
432+ left_alternatives,
433+ right_alternatives,
434+ is_positive,
435+ ComparisonOperator :: Inequality ,
436+ inequality_result,
437+ ) ;
438+ }
439+
440+ if let Some ( alternatives) = left_alternatives {
382441 return evaluate_union_left ( db, & alternatives, right, is_positive, inequality_result) ;
383442 }
384- if let Some ( alternatives) = inequality_alternatives ( db , right ) {
443+ if let Some ( alternatives) = right_alternatives {
385444 return evaluate_union_right ( db, left, & alternatives, is_positive, inequality_result) ;
386445 }
387446
@@ -766,12 +825,167 @@ fn evaluate_intersection_left<'db>(
766825 }
767826}
768827
769- fn equality_alternatives < ' db > ( db : & ' db dyn Db , ty : Type < ' db > ) -> Option < Vec < Type < ' db > > > {
770- finite_alternatives ( db, ty, ComparisonOperator :: Equality )
828+ fn evaluate_finite_domains < ' db > (
829+ db : & ' db dyn Db ,
830+ left : & [ Type < ' db > ] ,
831+ right : & [ Type < ' db > ] ,
832+ is_positive : bool ,
833+ operator : ComparisonOperator ,
834+ evaluate : fn ( & ' db dyn Db , Type < ' db > , Type < ' db > , bool ) -> ComparisonResult < ' db > ,
835+ ) -> ComparisonResult < ' db > {
836+ if left. is_empty ( ) || right. is_empty ( ) {
837+ return ComparisonResult :: Ambiguous ;
838+ }
839+
840+ let Some ( other_keys) = right
841+ . iter ( )
842+ . map ( |alternative| finite_comparison_key ( db, * alternative, operator) )
843+ . collect :: < Option < FxHashSet < _ > > > ( )
844+ else {
845+ return ComparisonResult :: Ambiguous ;
846+ } ;
847+
848+ evaluate_target_union ( db, left, is_positive, |alternative| {
849+ if let Some ( key) = finite_comparison_key ( db, alternative, operator) {
850+ let equality = if !other_keys. contains ( & key) {
851+ ComparisonResult :: AlwaysFalse
852+ } else if other_keys. len ( ) == 1 {
853+ ComparisonResult :: AlwaysTrue
854+ } else {
855+ ComparisonResult :: Ambiguous
856+ } ;
857+ if operator == ComparisonOperator :: Equality {
858+ equality
859+ } else {
860+ equality. negate ( )
861+ }
862+ } else {
863+ evaluate_against_results (
864+ db,
865+ alternative,
866+ is_positive,
867+ right
868+ . iter ( )
869+ . map ( |other| evaluate ( db, alternative, * other, is_positive) ) ,
870+ )
871+ }
872+ } )
873+ }
874+
875+ fn finite_comparison_key < ' db > (
876+ db : & ' db dyn Db ,
877+ ty : Type < ' db > ,
878+ operator : ComparisonOperator ,
879+ ) -> Option < Type < ' db > > {
880+ let literal = match ty {
881+ Type :: LiteralValue ( literal) => literal. kind ( ) ,
882+ Type :: Intersection ( intersection) => LiteralValueTypeKind :: Enum (
883+ intersection
884+ . positive ( db)
885+ . iter ( )
886+ . find_map ( |element| element. as_enum_literal ( ) ) ?,
887+ ) ,
888+ _ if has_known_identity_comparison_semantics ( db, ty, operator) => return Some ( ty) ,
889+ _ => return None ,
890+ } ;
891+
892+ match literal {
893+ LiteralValueTypeKind :: Bool ( value) => Some ( Type :: int_literal ( i64:: from ( value) ) ) ,
894+ LiteralValueTypeKind :: Int ( _)
895+ | LiteralValueTypeKind :: String ( _)
896+ | LiteralValueTypeKind :: Bytes ( _) => Some ( ty) ,
897+ LiteralValueTypeKind :: LiteralString => None ,
898+ LiteralValueTypeKind :: Enum ( enum_literal) => {
899+ match known_instance_semantics ( db, enum_literal. enum_class_instance ( db) , operator) ? {
900+ KnownComparisonSemantics :: Object => {
901+ let metadata = enum_metadata ( db, enum_literal. enum_class ( db) ) ?;
902+ let name = metadata. resolve_member ( enum_literal. name ( db) ) ?;
903+ Some ( Type :: enum_literal ( EnumLiteralType :: new (
904+ db,
905+ enum_literal. enum_class ( db) ,
906+ name. clone ( ) ,
907+ ) ) )
908+ }
909+ KnownComparisonSemantics :: Int
910+ | KnownComparisonSemantics :: Str
911+ | KnownComparisonSemantics :: Bytes => {
912+ finite_comparison_key ( db, enum_literal_value ( db, enum_literal) ?, operator)
913+ }
914+ KnownComparisonSemantics :: Tuple | KnownComparisonSemantics :: Dict => None ,
915+ }
916+ }
917+ }
918+ }
919+
920+ #[ derive( Debug , Copy , Clone , PartialEq , Eq ) ]
921+ enum FiniteComparisonDomain {
922+ None ,
923+ Finite ,
924+ Mixed ,
771925}
772926
773- fn inequality_alternatives < ' db > ( db : & ' db dyn Db , ty : Type < ' db > ) -> Option < Vec < Type < ' db > > > {
774- finite_alternatives ( db, ty, ComparisonOperator :: Inequality )
927+ fn finite_domain_expansion_is_bounded (
928+ db : & dyn Db ,
929+ target : Type ,
930+ other : Type ,
931+ operator : ComparisonOperator ,
932+ ) -> bool {
933+ !matches ! ( other, Type :: Union ( _) )
934+ || finite_comparison_domain ( db, target, operator) == FiniteComparisonDomain :: None
935+ || finite_comparison_domain ( db, other, operator) == FiniteComparisonDomain :: Finite
936+ }
937+
938+ fn finite_comparison_domain (
939+ db : & dyn Db ,
940+ ty : Type ,
941+ operator : ComparisonOperator ,
942+ ) -> FiniteComparisonDomain {
943+ match ty {
944+ Type :: Union ( union) => {
945+ let mut has_finite = false ;
946+ let mut has_open = false ;
947+ for element in union. elements ( db) {
948+ match finite_comparison_domain ( db, * element, operator) {
949+ FiniteComparisonDomain :: None => has_open = true ,
950+ FiniteComparisonDomain :: Finite => has_finite = true ,
951+ FiniteComparisonDomain :: Mixed => return FiniteComparisonDomain :: Mixed ,
952+ }
953+ }
954+ match ( has_finite, has_open) {
955+ ( true , true ) => FiniteComparisonDomain :: Mixed ,
956+ ( true , false ) => FiniteComparisonDomain :: Finite ,
957+ ( false , _) => FiniteComparisonDomain :: None ,
958+ }
959+ }
960+ Type :: EnumComplement ( _) => {
961+ if comparison_semantics ( db, ty, operator) . is_some ( ) {
962+ FiniteComparisonDomain :: Finite
963+ } else {
964+ FiniteComparisonDomain :: None
965+ }
966+ }
967+ Type :: Intersection ( intersection) => {
968+ if intersection. enum_complement ( db) . is_some ( )
969+ && comparison_semantics ( db, ty, operator) . is_some ( )
970+ {
971+ FiniteComparisonDomain :: Finite
972+ } else {
973+ FiniteComparisonDomain :: None
974+ }
975+ }
976+ _ if finite_comparison_key ( db, ty, operator) . is_some ( ) => FiniteComparisonDomain :: Finite ,
977+ Type :: NominalInstance ( instance) => {
978+ if instance. has_known_class ( db, KnownClass :: Bool )
979+ || ( enum_metadata ( db, instance. class_literal ( db) ) . is_some ( )
980+ && comparison_semantics ( db, ty, operator) . is_some ( ) )
981+ {
982+ FiniteComparisonDomain :: Finite
983+ } else {
984+ FiniteComparisonDomain :: None
985+ }
986+ }
987+ _ => FiniteComparisonDomain :: None ,
988+ }
775989}
776990
777991fn finite_alternatives < ' db > (
@@ -780,6 +994,23 @@ fn finite_alternatives<'db>(
780994 operator : ComparisonOperator ,
781995) -> Option < Vec < Type < ' db > > > {
782996 match ty {
997+ Type :: Union ( union) => {
998+ let mut alternatives = Vec :: new ( ) ;
999+ let mut expanded_finite_domain = false ;
1000+ for element in union. elements ( db) {
1001+ if let Some ( element_alternatives) = finite_alternatives ( db, * element, operator) {
1002+ alternatives. extend ( element_alternatives) ;
1003+ expanded_finite_domain = true ;
1004+ } else {
1005+ alternatives. push ( * element) ;
1006+ }
1007+ }
1008+ ( expanded_finite_domain
1009+ || alternatives
1010+ . iter ( )
1011+ . all ( |alternative| finite_comparison_key ( db, * alternative, operator) . is_some ( ) ) )
1012+ . then_some ( alternatives)
1013+ }
7831014 Type :: EnumComplement ( complement) => comparison_semantics ( db, ty, operator)
7841015 . is_some ( )
7851016 . then ( || complement. remaining_literal_types ( db) ) ,
0 commit comments