File tree Expand file tree Collapse file tree 4 files changed +109
-0
lines changed
crates/ty_python_semantic Expand file tree Collapse file tree 4 files changed +109
-0
lines changed Original file line number Diff line number Diff line change @@ -379,3 +379,22 @@ def as_pattern_non_exhaustive(subject: int | str):
379379 # this diagnostic is correct: the inferred type of `subject` is `str`
380380 assert_never(subject) # error: [type-assertion-failure]
381381```
382+
383+ ## Exhaustiveness checking for methods of enums
384+
385+ ``` py
386+ from enum import Enum
387+
388+ class Answer (Enum ):
389+ YES = " yes"
390+ NO = " no"
391+
392+ def is_yes (self ) -> bool :
393+ reveal_type(self ) # revealed: Self@is_yes
394+
395+ match self :
396+ case Answer.YES :
397+ return True
398+ case Answer.NO :
399+ return False
400+ ```
Original file line number Diff line number Diff line change @@ -252,3 +252,53 @@ match x:
252252
253253reveal_type(x) # revealed: object
254254```
255+
256+ ## Narrowing on ` Self ` in ` match ` statements
257+
258+ When performing narrowing on ` self ` inside methods on enums, we take into account that ` Self ` might
259+ refer to a subtype of the enum class, like ` Literal[Answer.YES] ` . This is why we do not simplify
260+ ` Self & ~Literal[Answer.YES] ` to ` Literal[Answer.NO, Answer.MAYBE] ` . Otherwise, we wouldn't be able
261+ to return ` self ` in the ` assert_yes ` method below:
262+
263+ ``` py
264+ from enum import Enum
265+ from typing_extensions import Self, assert_never
266+
267+ class Answer (Enum ):
268+ NO = 0
269+ YES = 1
270+ MAYBE = 2
271+
272+ def is_yes (self ) -> bool :
273+ reveal_type(self ) # revealed: Self@is_yes
274+
275+ match self :
276+ case Answer.YES :
277+ reveal_type(self ) # revealed: Self@is_yes
278+ return True
279+ case Answer.NO | Answer.MAYBE :
280+ reveal_type(self ) # revealed: Self@is_yes & ~Literal[Answer.YES]
281+ return False
282+ case _:
283+ assert_never(self ) # no error
284+
285+ def assert_yes (self ) -> Self:
286+ reveal_type(self ) # revealed: Self@assert_yes
287+
288+ match self :
289+ case Answer.YES :
290+ reveal_type(self ) # revealed: Self@assert_yes
291+ return self
292+ case _:
293+ reveal_type(self ) # revealed: Self@assert_yes & ~Literal[Answer.YES]
294+ raise ValueError (" Answer is not YES" )
295+
296+ Answer.YES .is_yes()
297+
298+ try :
299+ reveal_type(Answer.MAYBE .assert_yes()) # revealed: Literal[Answer.MAYBE]
300+ except ValueError :
301+ pass
302+ ```
303+
304+ We do
Original file line number Diff line number Diff line change @@ -807,6 +807,15 @@ impl ReachabilityConstraints {
807807 . add_negative ( type_excluded_by_previous_patterns ( db, predicate) )
808808 . build ( ) ;
809809
810+ let next_narrowed_subject_ty = IntersectionBuilder :: new ( db)
811+ . add_positive ( narrowed_subject_ty)
812+ . add_negative ( pattern_kind_to_type ( db, predicate. kind ( db) ) )
813+ . build ( ) ;
814+
815+ if next_narrowed_subject_ty. is_never ( ) {
816+ return Truthiness :: AlwaysTrue ;
817+ }
818+
810819 let truthiness = Self :: analyze_single_pattern_predicate_kind (
811820 db,
812821 predicate. kind ( db) ,
Original file line number Diff line number Diff line change @@ -780,6 +780,37 @@ impl<'db> IntersectionBuilder<'db> {
780780 seen_aliases,
781781 )
782782 }
783+ Type :: EnumLiteral ( enum_literal) => {
784+ let enum_class = enum_literal. enum_class ( self . db ) ;
785+ let metadata =
786+ enum_metadata ( self . db , enum_class) . expect ( "Class of enum literal is an enum" ) ;
787+
788+ let enum_members_in_negative_part = self
789+ . intersections
790+ . iter ( )
791+ . flat_map ( |intersection| & intersection. negative )
792+ . filter_map ( |ty| ty. as_enum_literal ( ) )
793+ . filter ( |lit| lit. enum_class ( self . db ) == enum_class)
794+ . map ( |lit| lit. name ( self . db ) . clone ( ) )
795+ . chain ( std:: iter:: once ( enum_literal. name ( self . db ) . clone ( ) ) )
796+ . collect :: < FxOrderSet < _ > > ( ) ;
797+
798+ let all_members_are_in_negative_part = metadata
799+ . members
800+ . keys ( )
801+ . all ( |name| enum_members_in_negative_part. contains ( name) ) ;
802+
803+ if all_members_are_in_negative_part {
804+ for inner in & mut self . intersections {
805+ inner. add_negative ( self . db , enum_literal. enum_class_instance ( self . db ) ) ;
806+ }
807+ } else {
808+ for inner in & mut self . intersections {
809+ inner. add_negative ( self . db , ty) ;
810+ }
811+ }
812+ self
813+ }
783814 _ => {
784815 for inner in & mut self . intersections {
785816 inner. add_negative ( self . db , ty) ;
You can’t perform that action at this time.
0 commit comments