diff --git a/mypy/checker.py b/mypy/checker.py index 7579c36a97d0..f6ae27ed3e31 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -7994,20 +7994,47 @@ def conditional_types( ) -> tuple[Type | None, Type | None]: """Takes in the current type and a proposed type of an expression. - Returns a 2-tuple: The first element is the proposed type, if the expression - can be the proposed type. The second element is the type it would hold - if it was not the proposed type, if any. UninhabitedType means unreachable. - None means no new information can be inferred. If default is set it is returned - instead.""" + Returns a 2-tuple: + The first element is the proposed type, if the expression can be the proposed type. + The second element is the type it would hold if it was not the proposed type, if any. + UninhabitedType means unreachable. + None means no new information can be inferred. + If default is set it is returned instead. + """ + if proposed_type_ranges and len(proposed_type_ranges) == 1: + # expand e.g. bool -> Literal[True] | Literal[False] + target = proposed_type_ranges[0].item + target = get_proper_type(target) + if isinstance(target, LiteralType) and ( + target.is_enum_literal() or isinstance(target.value, bool) + ): + enum_name = target.fallback.type.fullname + current_type = try_expanding_sum_type_to_union(current_type, enum_name) + + current_type = get_proper_type(current_type) + if isinstance(current_type, UnionType) and (default == current_type): + # factorize over union types + # if we try to narrow A|B to C, we instead narrow A to C and B to C, and + # return the union of the results + result: list[tuple[Type | None, Type | None]] = [ + conditional_types( + union_item, + proposed_type_ranges, + default=union_item, + consider_runtime_isinstance=consider_runtime_isinstance, + ) + for union_item in get_proper_types(current_type.items) + ] + # separate list of tuples into two lists + yes_types, no_types = zip(*result) + yes_type = make_simplified_union([t for t in yes_types if t is not None]) + no_type = restrict_subtype_away( + current_type, yes_type, consider_runtime_isinstance=consider_runtime_isinstance + ) + + return yes_type, no_type + if proposed_type_ranges: - if len(proposed_type_ranges) == 1: - target = proposed_type_ranges[0].item - target = get_proper_type(target) - if isinstance(target, LiteralType) and ( - target.is_enum_literal() or isinstance(target.value, bool) - ): - enum_name = target.fallback.type.fullname - current_type = try_expanding_sum_type_to_union(current_type, enum_name) proposed_items = [type_range.item for type_range in proposed_type_ranges] proposed_type = make_simplified_union(proposed_items) if isinstance(proposed_type, AnyType): diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index bb8f038eb1eb..d8f8fa592204 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1760,6 +1760,22 @@ def union(x: str | bool) -> None: reveal_type(x) # N: Revealed type is "Union[builtins.str, Literal[False]]" [builtins fixtures/tuple.pyi] +[case testMatchNarrowDownUnionUsingClassPattern] + +class Foo: ... +class Bar(Foo): ... + +def test_1(bar: Bar) -> None: + match bar: + case Foo() as foo: + reveal_type(foo) # N: Revealed type is "__main__.Bar" + +def test_2(bar: Bar | str) -> None: + match bar: + case Foo() as foo: + reveal_type(foo) # N: Revealed type is "__main__.Bar" + + [case testMatchAssertFalseToSilenceFalsePositives] class C: a: int | str