Skip to content

Commit 2946e79

Browse files
committed
[ty] Reachability and narrowing for enum methods
1 parent aca8ba7 commit 2946e79

File tree

4 files changed

+109
-0
lines changed

4 files changed

+109
-0
lines changed

crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,3 +379,22 @@ def as_pattern_non_exhaustive(subject: int | str):
379379
# this diagnostic is correct: the inferred type of `subject` is `str`
380380
assert_never(subject) # error: [type-assertion-failure]
381381
```
382+
383+
## Exhaustiveness checking for methods of enums
384+
385+
```py
386+
from enum import Enum
387+
388+
class Answer(Enum):
389+
YES = "yes"
390+
NO = "no"
391+
392+
def is_yes(self) -> bool:
393+
reveal_type(self) # revealed: Self@is_yes
394+
395+
match self:
396+
case Answer.YES:
397+
return True
398+
case Answer.NO:
399+
return False
400+
```

crates/ty_python_semantic/resources/mdtest/narrow/match.md

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,3 +252,53 @@ match x:
252252

253253
reveal_type(x) # revealed: object
254254
```
255+
256+
## Narrowing on `Self` in `match` statements
257+
258+
When performing narrowing on `self` inside methods on enums, we take into account that `Self` might
259+
refer to a subtype of the enum class, like `Literal[Answer.YES]`. This is why we do not simplify
260+
`Self & ~Literal[Answer.YES]` to `Literal[Answer.NO, Answer.MAYBE]`. Otherwise, we wouldn't be able
261+
to return `self` in the `assert_yes` method below:
262+
263+
```py
264+
from enum import Enum
265+
from typing_extensions import Self, assert_never
266+
267+
class Answer(Enum):
268+
NO = 0
269+
YES = 1
270+
MAYBE = 2
271+
272+
def is_yes(self) -> bool:
273+
reveal_type(self) # revealed: Self@is_yes
274+
275+
match self:
276+
case Answer.YES:
277+
reveal_type(self) # revealed: Self@is_yes
278+
return True
279+
case Answer.NO | Answer.MAYBE:
280+
reveal_type(self) # revealed: Self@is_yes & ~Literal[Answer.YES]
281+
return False
282+
case _:
283+
assert_never(self) # no error
284+
285+
def assert_yes(self) -> Self:
286+
reveal_type(self) # revealed: Self@assert_yes
287+
288+
match self:
289+
case Answer.YES:
290+
reveal_type(self) # revealed: Self@assert_yes
291+
return self
292+
case _:
293+
reveal_type(self) # revealed: Self@assert_yes & ~Literal[Answer.YES]
294+
raise ValueError("Answer is not YES")
295+
296+
Answer.YES.is_yes()
297+
298+
try:
299+
reveal_type(Answer.MAYBE.assert_yes()) # revealed: Literal[Answer.MAYBE]
300+
except ValueError:
301+
pass
302+
```
303+
304+
We do

crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -807,6 +807,15 @@ impl ReachabilityConstraints {
807807
.add_negative(type_excluded_by_previous_patterns(db, predicate))
808808
.build();
809809

810+
let next_narrowed_subject_ty = IntersectionBuilder::new(db)
811+
.add_positive(narrowed_subject_ty)
812+
.add_negative(pattern_kind_to_type(db, predicate.kind(db)))
813+
.build();
814+
815+
if next_narrowed_subject_ty.is_never() {
816+
return Truthiness::AlwaysTrue;
817+
}
818+
810819
let truthiness = Self::analyze_single_pattern_predicate_kind(
811820
db,
812821
predicate.kind(db),

crates/ty_python_semantic/src/types/builder.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -780,6 +780,37 @@ impl<'db> IntersectionBuilder<'db> {
780780
seen_aliases,
781781
)
782782
}
783+
Type::EnumLiteral(enum_literal) => {
784+
let enum_class = enum_literal.enum_class(self.db);
785+
let metadata =
786+
enum_metadata(self.db, enum_class).expect("Class of enum literal is an enum");
787+
788+
let enum_members_in_negative_part = self
789+
.intersections
790+
.iter()
791+
.flat_map(|intersection| &intersection.negative)
792+
.filter_map(|ty| ty.as_enum_literal())
793+
.filter(|lit| lit.enum_class(self.db) == enum_class)
794+
.map(|lit| lit.name(self.db).clone())
795+
.chain(std::iter::once(enum_literal.name(self.db).clone()))
796+
.collect::<FxOrderSet<_>>();
797+
798+
let all_members_are_in_negative_part = metadata
799+
.members
800+
.keys()
801+
.all(|name| enum_members_in_negative_part.contains(name));
802+
803+
if all_members_are_in_negative_part {
804+
for inner in &mut self.intersections {
805+
inner.add_negative(self.db, enum_literal.enum_class_instance(self.db));
806+
}
807+
} else {
808+
for inner in &mut self.intersections {
809+
inner.add_negative(self.db, ty);
810+
}
811+
}
812+
self
813+
}
783814
_ => {
784815
for inner in &mut self.intersections {
785816
inner.add_negative(self.db, ty);

0 commit comments

Comments
 (0)