Skip to content

Commit 236427f

Browse files
committed
[ty] Reachability and narrowing for enum methods
1 parent aca8ba7 commit 236427f

File tree

4 files changed

+117
-2
lines changed

4 files changed

+117
-2
lines changed

crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,3 +379,22 @@ def as_pattern_non_exhaustive(subject: int | str):
379379
# this diagnostic is correct: the inferred type of `subject` is `str`
380380
assert_never(subject) # error: [type-assertion-failure]
381381
```
382+
383+
## Exhaustiveness checking for methods of enums
384+
385+
```py
386+
from enum import Enum
387+
388+
class Answer(Enum):
389+
YES = "yes"
390+
NO = "no"
391+
392+
def is_yes(self) -> bool:
393+
reveal_type(self) # revealed: Self@is_yes
394+
395+
match self:
396+
case Answer.YES:
397+
return True
398+
case Answer.NO:
399+
return False
400+
```

crates/ty_python_semantic/resources/mdtest/narrow/match.md

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,3 +252,51 @@ match x:
252252

253253
reveal_type(x) # revealed: object
254254
```
255+
256+
## Narrowing on `Self` in `match` statements
257+
258+
When performing narrowing on `self` inside methods on enums, we take into account that `Self` might
259+
refer to a subtype of the enum class, like `Literal[Answer.YES]`. This is why we do not simplify
260+
`Self & ~Literal[Answer.YES]` to `Literal[Answer.NO, Answer.MAYBE]`. Otherwise, we wouldn't be able
261+
to return `self` in the `assert_yes` method below:
262+
263+
```py
264+
from enum import Enum
265+
from typing_extensions import Self, assert_never
266+
267+
class Answer(Enum):
268+
NO = 0
269+
YES = 1
270+
MAYBE = 2
271+
272+
def is_yes(self) -> bool:
273+
reveal_type(self) # revealed: Self@is_yes
274+
275+
match self:
276+
case Answer.YES:
277+
reveal_type(self) # revealed: Self@is_yes
278+
return True
279+
case Answer.NO | Answer.MAYBE:
280+
reveal_type(self) # revealed: Self@is_yes & ~Literal[Answer.YES]
281+
return False
282+
case _:
283+
assert_never(self) # no error
284+
285+
def assert_yes(self) -> Self:
286+
reveal_type(self) # revealed: Self@assert_yes
287+
288+
match self:
289+
case Answer.YES:
290+
reveal_type(self) # revealed: Self@assert_yes
291+
return self
292+
case _:
293+
reveal_type(self) # revealed: Self@assert_yes & ~Literal[Answer.YES]
294+
raise ValueError("Answer is not YES")
295+
296+
Answer.YES.is_yes()
297+
298+
try:
299+
reveal_type(Answer.MAYBE.assert_yes()) # revealed: Literal[Answer.MAYBE]
300+
except ValueError:
301+
pass
302+
```

crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -802,10 +802,27 @@ impl ReachabilityConstraints {
802802
fn analyze_single_pattern_predicate(db: &dyn Db, predicate: PatternPredicate) -> Truthiness {
803803
let subject_ty = infer_expression_type(db, predicate.subject(db), TypeContext::default());
804804

805-
let narrowed_subject_ty = IntersectionBuilder::new(db)
805+
let narrowed_subject = IntersectionBuilder::new(db)
806806
.add_positive(subject_ty)
807-
.add_negative(type_excluded_by_previous_patterns(db, predicate))
807+
.add_negative(type_excluded_by_previous_patterns(db, predicate));
808+
809+
let narrowed_subject_ty = narrowed_subject.clone().build();
810+
811+
// Consider a case where we match on a subject type of `Self` with an upper bound of `Answer`, where
812+
// `Answer` is a {YES, NO} enum. After a previous pattern matching on `NO`, the narrowed subject
813+
// type is `Self & ~Literal[NO]`. This type is *not* equivalent to `Literal[YES]`, because `Self`
814+
// could also specialize to `Literal[NO]` or `Never`, making the intersection empty. However, if the
815+
// current pattern matches on `YES`, the *next* narrowed subject type will be `Self & ~Literal[NO] &
816+
// ~Literal[YES]`, which *is* equivalent to `Never`. This means that subsequent patterns can never
817+
// match. And we know that if we reach this point, the current pattern will have to match. We return
818+
// `AlwaysTrue` here, since the call to `analyze_single_pattern_predicate_kind` below would return
819+
// `Ambiguous` in this case.
820+
let next_narrowed_subject_ty = narrowed_subject
821+
.add_negative(pattern_kind_to_type(db, predicate.kind(db)))
808822
.build();
823+
if !narrowed_subject_ty.is_never() && next_narrowed_subject_ty.is_never() {
824+
return Truthiness::AlwaysTrue;
825+
}
809826

810827
let truthiness = Self::analyze_single_pattern_predicate_kind(
811828
db,

crates/ty_python_semantic/src/types/builder.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -780,6 +780,37 @@ impl<'db> IntersectionBuilder<'db> {
780780
seen_aliases,
781781
)
782782
}
783+
Type::EnumLiteral(enum_literal) => {
784+
let enum_class = enum_literal.enum_class(self.db);
785+
let metadata =
786+
enum_metadata(self.db, enum_class).expect("Class of enum literal is an enum");
787+
788+
let enum_members_in_negative_part = self
789+
.intersections
790+
.iter()
791+
.flat_map(|intersection| &intersection.negative)
792+
.filter_map(|ty| ty.as_enum_literal())
793+
.filter(|lit| lit.enum_class(self.db) == enum_class)
794+
.map(|lit| lit.name(self.db).clone())
795+
.chain(std::iter::once(enum_literal.name(self.db).clone()))
796+
.collect::<FxOrderSet<_>>();
797+
798+
let all_members_are_in_negative_part = metadata
799+
.members
800+
.keys()
801+
.all(|name| enum_members_in_negative_part.contains(name));
802+
803+
if all_members_are_in_negative_part {
804+
for inner in &mut self.intersections {
805+
inner.add_negative(self.db, enum_literal.enum_class_instance(self.db));
806+
}
807+
} else {
808+
for inner in &mut self.intersections {
809+
inner.add_negative(self.db, ty);
810+
}
811+
}
812+
self
813+
}
783814
_ => {
784815
for inner in &mut self.intersections {
785816
inner.add_negative(self.db, ty);

0 commit comments

Comments
 (0)