Skip to content

Commit 1b3ddee

Browse files
committed
Model equality comparison domains
1 parent 86cc4cf commit 1b3ddee

1 file changed

Lines changed: 59 additions & 27 deletions

File tree

crates/ty_python_semantic/src/types/equality.rs

Lines changed: 59 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -69,40 +69,16 @@ pub(super) fn evaluate_type_equality<'db>(
6969
enum_literal_constraint(db, left, right, ComparisonOperator::Equality, is_positive)
7070
.or_else(|| primitive_literal_constraint(db, left, right, is_positive))
7171
.or_else(|| {
72-
if is_equality_narrowing_operand(db, left, right) {
72+
if comparison_domain(db, left, right, ComparisonOperator::Equality)
73+
== ComparisonDomain::Known
74+
{
7375
equality_result(db, left, right, is_positive).constraint(is_positive)
7476
} else {
7577
None
7678
}
7779
})
7880
}
7981

80-
/// Return whether `ty` can constrain `left` through equality.
81-
///
82-
/// The general evaluator distributes over unions, so avoid invoking it for ordinary nominal types
83-
/// whose comparison semantics are unknown.
84-
fn is_equality_narrowing_operand<'db>(db: &'db dyn Db, left: Type<'db>, ty: Type<'db>) -> bool {
85-
match ty.resolve_type_alias(db) {
86-
Type::Union(union) => union
87-
.elements(db)
88-
.iter()
89-
.all(|element| is_equality_narrowing_operand(db, left, *element)),
90-
Type::LiteralValue(_) | Type::EnumComplement(_) => true,
91-
Type::Intersection(intersection) if intersection.enum_complement(db).is_some() => true,
92-
Type::TypedDict(_) => true,
93-
Type::NominalInstance(instance) => {
94-
instance.tuple_spec(db).is_some()
95-
|| instance
96-
.class(db)
97-
.known(db)
98-
.is_some_and(|known| known == KnownClass::Bool || known.is_singleton())
99-
|| left.resolve_type_alias(db).is_union()
100-
&& comparison_semantics(db, ty, ComparisonOperator::Equality).is_some()
101-
}
102-
ty => ty.is_single_valued(db),
103-
}
104-
}
105-
10682
pub(super) fn evaluate_type_inequality<'db>(
10783
db: &'db dyn Db,
10884
left: Type<'db>,
@@ -1035,6 +1011,62 @@ enum KnownComparisonSemantics {
10351011
Dict,
10361012
}
10371013

1014+
/// Whether the non-target operand has a comparison domain that can safely constrain the target.
1015+
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
1016+
enum ComparisonDomain {
1017+
/// The operand may use comparison behavior that `ty` does not model.
1018+
Unknown,
1019+
/// The operand can be handled by `ty`'s equality-narrowing evaluator.
1020+
Known,
1021+
}
1022+
1023+
/// Classify whether `ty` has comparison behavior that can constrain `target`.
1024+
///
1025+
/// Unions only have a known domain if every arm does. Broad nominal types require full dunder
1026+
/// analysis, which is only useful here when it can eliminate an arm from a union target.
1027+
fn comparison_domain<'db>(
1028+
db: &'db dyn Db,
1029+
target: Type<'db>,
1030+
ty: Type<'db>,
1031+
operator: ComparisonOperator,
1032+
) -> ComparisonDomain {
1033+
let target = target.resolve_type_alias(db);
1034+
let ty = ty.resolve_type_alias(db);
1035+
1036+
match ty {
1037+
Type::Union(union) => {
1038+
if union.elements(db).iter().all(|element| {
1039+
comparison_domain(db, target, *element, operator) == ComparisonDomain::Known
1040+
}) {
1041+
ComparisonDomain::Known
1042+
} else {
1043+
ComparisonDomain::Unknown
1044+
}
1045+
}
1046+
Type::LiteralValue(_) | Type::EnumComplement(_) | Type::TypedDict(_) => {
1047+
ComparisonDomain::Known
1048+
}
1049+
Type::Intersection(intersection) if intersection.enum_complement(db).is_some() => {
1050+
ComparisonDomain::Known
1051+
}
1052+
Type::NominalInstance(instance) => {
1053+
if instance.tuple_spec(db).is_some()
1054+
|| instance
1055+
.class(db)
1056+
.known(db)
1057+
.is_some_and(|known| known == KnownClass::Bool || known.is_singleton())
1058+
|| target.is_union() && comparison_semantics(db, ty, operator).is_some()
1059+
{
1060+
ComparisonDomain::Known
1061+
} else {
1062+
ComparisonDomain::Unknown
1063+
}
1064+
}
1065+
_ if ty.is_single_valued(db) => ComparisonDomain::Known,
1066+
_ => ComparisonDomain::Unknown,
1067+
}
1068+
}
1069+
10381070
fn comparison_semantics<'db>(
10391071
db: &'db dyn Db,
10401072
ty: Type<'db>,

0 commit comments

Comments
 (0)