Skip to content

Commit 801415f

Browse files
committed
Avoid quadratic enum comparison narrowing
1 parent f167a9c commit 801415f

3 files changed

Lines changed: 311 additions & 8 deletions

File tree

crates/ty_python_semantic/resources/mdtest/narrow/conditionals/eq.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,28 @@ def _(answer: IndependentEquality):
186186
reveal_type(answer) # revealed: IndependentEquality
187187
```
188188

189+
Finite domains remain narrowable when the other operand also includes an identity singleton:
190+
191+
```py
192+
from enum import Enum
193+
from typing import Literal
194+
195+
class Finite(Enum):
196+
FIRST = 1
197+
SECOND = 2
198+
199+
def _(value: Finite, other: Literal[Finite.FIRST] | None):
200+
if value == other:
201+
reveal_type(value) # revealed: Literal[Finite.FIRST]
202+
else:
203+
reveal_type(value) # revealed: Finite
204+
205+
if value != other:
206+
reveal_type(value) # revealed: Finite
207+
else:
208+
reveal_type(value) # revealed: Literal[Finite.FIRST]
209+
```
210+
189211
## Equality between concrete runtime classes
190212

191213
Types such as `bool`, `LiteralString`, and `TypedDict` correspond to specific runtime classes.

crates/ty_python_semantic/src/types/equality.rs

Lines changed: 239 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use ruff_python_ast::name::Name;
2+
use rustc_hash::FxHashSet;
23

34
use crate::{Db, place::PlaceAndQualifiers};
45

@@ -109,10 +110,39 @@ fn equality_result<'db>(
109110
let left = left.resolve_type_alias(db);
110111
let right = right.resolve_type_alias(db);
111112

112-
if let Some(alternatives) = equality_alternatives(db, left) {
113+
if !finite_domain_expansion_is_bounded(db, left, right, ComparisonOperator::Equality) {
114+
return ComparisonResult::Ambiguous;
115+
}
116+
117+
let left_alternatives = finite_alternatives(db, left, ComparisonOperator::Equality);
118+
if left == right
119+
&& let Some(alternatives) = left_alternatives.as_deref()
120+
{
121+
return if alternatives.len() == 1 {
122+
ComparisonResult::AlwaysTrue
123+
} else {
124+
ComparisonResult::Ambiguous
125+
};
126+
}
127+
128+
let right_alternatives = finite_alternatives(db, right, ComparisonOperator::Equality);
129+
if let (Some(left_alternatives), Some(right_alternatives)) =
130+
(left_alternatives.as_deref(), right_alternatives.as_deref())
131+
{
132+
return evaluate_finite_domains(
133+
db,
134+
left_alternatives,
135+
right_alternatives,
136+
is_positive,
137+
ComparisonOperator::Equality,
138+
equality_result,
139+
);
140+
}
141+
142+
if let Some(alternatives) = left_alternatives {
113143
return evaluate_union_left(db, &alternatives, right, is_positive, equality_result);
114144
}
115-
if let Some(alternatives) = equality_alternatives(db, right) {
145+
if let Some(alternatives) = right_alternatives {
116146
return evaluate_union_right(db, left, &alternatives, is_positive, equality_result);
117147
}
118148

@@ -378,10 +408,39 @@ fn inequality_result<'db>(
378408
let left = left.resolve_type_alias(db);
379409
let right = right.resolve_type_alias(db);
380410

381-
if let Some(alternatives) = inequality_alternatives(db, left) {
411+
if !finite_domain_expansion_is_bounded(db, left, right, ComparisonOperator::Inequality) {
412+
return ComparisonResult::Ambiguous;
413+
}
414+
415+
let left_alternatives = finite_alternatives(db, left, ComparisonOperator::Inequality);
416+
if left == right
417+
&& let Some(alternatives) = left_alternatives.as_deref()
418+
{
419+
return if alternatives.len() == 1 {
420+
ComparisonResult::AlwaysFalse
421+
} else {
422+
ComparisonResult::Ambiguous
423+
};
424+
}
425+
426+
let right_alternatives = finite_alternatives(db, right, ComparisonOperator::Inequality);
427+
if let (Some(left_alternatives), Some(right_alternatives)) =
428+
(left_alternatives.as_deref(), right_alternatives.as_deref())
429+
{
430+
return evaluate_finite_domains(
431+
db,
432+
left_alternatives,
433+
right_alternatives,
434+
is_positive,
435+
ComparisonOperator::Inequality,
436+
inequality_result,
437+
);
438+
}
439+
440+
if let Some(alternatives) = left_alternatives {
382441
return evaluate_union_left(db, &alternatives, right, is_positive, inequality_result);
383442
}
384-
if let Some(alternatives) = inequality_alternatives(db, right) {
443+
if let Some(alternatives) = right_alternatives {
385444
return evaluate_union_right(db, left, &alternatives, is_positive, inequality_result);
386445
}
387446

@@ -766,12 +825,167 @@ fn evaluate_intersection_left<'db>(
766825
}
767826
}
768827

769-
fn equality_alternatives<'db>(db: &'db dyn Db, ty: Type<'db>) -> Option<Vec<Type<'db>>> {
770-
finite_alternatives(db, ty, ComparisonOperator::Equality)
828+
fn evaluate_finite_domains<'db>(
829+
db: &'db dyn Db,
830+
left: &[Type<'db>],
831+
right: &[Type<'db>],
832+
is_positive: bool,
833+
operator: ComparisonOperator,
834+
evaluate: fn(&'db dyn Db, Type<'db>, Type<'db>, bool) -> ComparisonResult<'db>,
835+
) -> ComparisonResult<'db> {
836+
if left.is_empty() || right.is_empty() {
837+
return ComparisonResult::Ambiguous;
838+
}
839+
840+
let Some(other_keys) = right
841+
.iter()
842+
.map(|alternative| finite_comparison_key(db, *alternative, operator))
843+
.collect::<Option<FxHashSet<_>>>()
844+
else {
845+
return ComparisonResult::Ambiguous;
846+
};
847+
848+
evaluate_target_union(db, left, is_positive, |alternative| {
849+
if let Some(key) = finite_comparison_key(db, alternative, operator) {
850+
let equality = if !other_keys.contains(&key) {
851+
ComparisonResult::AlwaysFalse
852+
} else if other_keys.len() == 1 {
853+
ComparisonResult::AlwaysTrue
854+
} else {
855+
ComparisonResult::Ambiguous
856+
};
857+
if operator == ComparisonOperator::Equality {
858+
equality
859+
} else {
860+
equality.negate()
861+
}
862+
} else {
863+
evaluate_against_results(
864+
db,
865+
alternative,
866+
is_positive,
867+
right
868+
.iter()
869+
.map(|other| evaluate(db, alternative, *other, is_positive)),
870+
)
871+
}
872+
})
873+
}
874+
875+
fn finite_comparison_key<'db>(
876+
db: &'db dyn Db,
877+
ty: Type<'db>,
878+
operator: ComparisonOperator,
879+
) -> Option<Type<'db>> {
880+
let literal = match ty {
881+
Type::LiteralValue(literal) => literal.kind(),
882+
Type::Intersection(intersection) => LiteralValueTypeKind::Enum(
883+
intersection
884+
.positive(db)
885+
.iter()
886+
.find_map(|element| element.as_enum_literal())?,
887+
),
888+
_ if has_known_identity_comparison_semantics(db, ty, operator) => return Some(ty),
889+
_ => return None,
890+
};
891+
892+
match literal {
893+
LiteralValueTypeKind::Bool(value) => Some(Type::int_literal(i64::from(value))),
894+
LiteralValueTypeKind::Int(_)
895+
| LiteralValueTypeKind::String(_)
896+
| LiteralValueTypeKind::Bytes(_) => Some(ty),
897+
LiteralValueTypeKind::LiteralString => None,
898+
LiteralValueTypeKind::Enum(enum_literal) => {
899+
match known_instance_semantics(db, enum_literal.enum_class_instance(db), operator)? {
900+
KnownComparisonSemantics::Object => {
901+
let metadata = enum_metadata(db, enum_literal.enum_class(db))?;
902+
let name = metadata.resolve_member(enum_literal.name(db))?;
903+
Some(Type::enum_literal(EnumLiteralType::new(
904+
db,
905+
enum_literal.enum_class(db),
906+
name.clone(),
907+
)))
908+
}
909+
KnownComparisonSemantics::Int
910+
| KnownComparisonSemantics::Str
911+
| KnownComparisonSemantics::Bytes => {
912+
finite_comparison_key(db, enum_literal_value(db, enum_literal)?, operator)
913+
}
914+
KnownComparisonSemantics::Tuple | KnownComparisonSemantics::Dict => None,
915+
}
916+
}
917+
}
918+
}
919+
920+
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
921+
enum FiniteComparisonDomain {
922+
None,
923+
Finite,
924+
Mixed,
771925
}
772926

773-
fn inequality_alternatives<'db>(db: &'db dyn Db, ty: Type<'db>) -> Option<Vec<Type<'db>>> {
774-
finite_alternatives(db, ty, ComparisonOperator::Inequality)
927+
fn finite_domain_expansion_is_bounded(
928+
db: &dyn Db,
929+
target: Type,
930+
other: Type,
931+
operator: ComparisonOperator,
932+
) -> bool {
933+
!matches!(other, Type::Union(_))
934+
|| finite_comparison_domain(db, target, operator) == FiniteComparisonDomain::None
935+
|| finite_comparison_domain(db, other, operator) == FiniteComparisonDomain::Finite
936+
}
937+
938+
fn finite_comparison_domain(
939+
db: &dyn Db,
940+
ty: Type,
941+
operator: ComparisonOperator,
942+
) -> FiniteComparisonDomain {
943+
match ty {
944+
Type::Union(union) => {
945+
let mut has_finite = false;
946+
let mut has_open = false;
947+
for element in union.elements(db) {
948+
match finite_comparison_domain(db, *element, operator) {
949+
FiniteComparisonDomain::None => has_open = true,
950+
FiniteComparisonDomain::Finite => has_finite = true,
951+
FiniteComparisonDomain::Mixed => return FiniteComparisonDomain::Mixed,
952+
}
953+
}
954+
match (has_finite, has_open) {
955+
(true, true) => FiniteComparisonDomain::Mixed,
956+
(true, false) => FiniteComparisonDomain::Finite,
957+
(false, _) => FiniteComparisonDomain::None,
958+
}
959+
}
960+
Type::EnumComplement(_) => {
961+
if comparison_semantics(db, ty, operator).is_some() {
962+
FiniteComparisonDomain::Finite
963+
} else {
964+
FiniteComparisonDomain::None
965+
}
966+
}
967+
Type::Intersection(intersection) => {
968+
if intersection.enum_complement(db).is_some()
969+
&& comparison_semantics(db, ty, operator).is_some()
970+
{
971+
FiniteComparisonDomain::Finite
972+
} else {
973+
FiniteComparisonDomain::None
974+
}
975+
}
976+
_ if finite_comparison_key(db, ty, operator).is_some() => FiniteComparisonDomain::Finite,
977+
Type::NominalInstance(instance) => {
978+
if instance.has_known_class(db, KnownClass::Bool)
979+
|| (enum_metadata(db, instance.class_literal(db)).is_some()
980+
&& comparison_semantics(db, ty, operator).is_some())
981+
{
982+
FiniteComparisonDomain::Finite
983+
} else {
984+
FiniteComparisonDomain::None
985+
}
986+
}
987+
_ => FiniteComparisonDomain::None,
988+
}
775989
}
776990

777991
fn finite_alternatives<'db>(
@@ -780,6 +994,23 @@ fn finite_alternatives<'db>(
780994
operator: ComparisonOperator,
781995
) -> Option<Vec<Type<'db>>> {
782996
match ty {
997+
Type::Union(union) => {
998+
let mut alternatives = Vec::new();
999+
let mut expanded_finite_domain = false;
1000+
for element in union.elements(db) {
1001+
if let Some(element_alternatives) = finite_alternatives(db, *element, operator) {
1002+
alternatives.extend(element_alternatives);
1003+
expanded_finite_domain = true;
1004+
} else {
1005+
alternatives.push(*element);
1006+
}
1007+
}
1008+
(expanded_finite_domain
1009+
|| alternatives
1010+
.iter()
1011+
.all(|alternative| finite_comparison_key(db, *alternative, operator).is_some()))
1012+
.then_some(alternatives)
1013+
}
7831014
Type::EnumComplement(complement) => comparison_semantics(db, ty, operator)
7841015
.is_some()
7851016
.then(|| complement.remaining_literal_types(db)),

crates/ty_python_semantic/src/types/infer/tests.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,56 @@ def classify(message: Message) -> int:
113113
Ok(())
114114
}
115115

116+
#[test]
117+
fn enum_comparison_narrowing_avoids_quadratic_expansion() -> anyhow::Result<()> {
118+
let mut db = setup_db();
119+
let mut members = String::new();
120+
for index in 0..250 {
121+
writeln!(&mut members, " MEMBER_{index} = {index}")
122+
.expect("writing to a String cannot fail");
123+
}
124+
let source = format!(
125+
r#"
126+
from enum import Enum
127+
128+
class Left(Enum):
129+
{members}
130+
class Right(Enum):
131+
{members}
132+
def consume(value: object) -> None: ...
133+
134+
def compare_same(left: Left, right: Left) -> None:
135+
if left == right:
136+
consume(left)
137+
if left != right:
138+
consume(left)
139+
140+
def compare_optional(left: Left, right: Left | None) -> None:
141+
if left == right:
142+
consume(left)
143+
if left != right:
144+
consume(left)
145+
146+
def compare_different(left: Left, right: Right) -> None:
147+
if left == right:
148+
consume(left)
149+
if left != right:
150+
consume(left)
151+
"#
152+
);
153+
db.write_file("/src/a.py", source)?;
154+
155+
let start = Instant::now();
156+
assert_file_diagnostics(&db, "/src/a.py", &[]);
157+
assert!(
158+
start.elapsed() < Duration::from_secs(10),
159+
"enum comparison narrowing took {:?}",
160+
start.elapsed()
161+
);
162+
163+
Ok(())
164+
}
165+
116166
#[test]
117167
fn not_literal_string() -> anyhow::Result<()> {
118168
let mut db = setup_db();

0 commit comments

Comments
 (0)