Skip to content

[match-case] Fix narrowing of class pattern with union-argument. #19473

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 40 additions & 13 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7994,20 +7994,47 @@ def conditional_types(
) -> tuple[Type | None, Type | None]:
"""Takes in the current type and a proposed type of an expression.

Returns a 2-tuple: The first element is the proposed type, if the expression
can be the proposed type. The second element is the type it would hold
if it was not the proposed type, if any. UninhabitedType means unreachable.
None means no new information can be inferred. If default is set it is returned
instead."""
Returns a 2-tuple:
The first element is the proposed type, if the expression can be the proposed type.
The second element is the type it would hold if it was not the proposed type, if any.
UninhabitedType means unreachable.
None means no new information can be inferred.
If default is set it is returned instead.
"""
if proposed_type_ranges and len(proposed_type_ranges) == 1:
# expand e.g. bool -> Literal[True] | Literal[False]
target = proposed_type_ranges[0].item
target = get_proper_type(target)
if isinstance(target, LiteralType) and (
target.is_enum_literal() or isinstance(target.value, bool)
):
enum_name = target.fallback.type.fullname
current_type = try_expanding_sum_type_to_union(current_type, enum_name)

current_type = get_proper_type(current_type)
if isinstance(current_type, UnionType) and (default == current_type):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default == current_type was necessary, otherwise quite a few tests break.
Adding or (default is None) and passing default=None in the recursive call also seems to break stuff. Not entirely sure what's going on there.

# factorize over union types
# if we try to narrow A|B to C, we instead narrow A to C and B to C, and
# return the union of the results
result: list[tuple[Type | None, Type | None]] = [
conditional_types(
union_item,
proposed_type_ranges,
default=union_item,
consider_runtime_isinstance=consider_runtime_isinstance,
)
for union_item in get_proper_types(current_type.items)
]
# separate list of tuples into two lists
yes_types, no_types = zip(*result)
yes_type = make_simplified_union([t for t in yes_types if t is not None])
no_type = restrict_subtype_away(
current_type, yes_type, consider_runtime_isinstance=consider_runtime_isinstance
)

return yes_type, no_type

if proposed_type_ranges:
if len(proposed_type_ranges) == 1:
target = proposed_type_ranges[0].item
target = get_proper_type(target)
if isinstance(target, LiteralType) and (
target.is_enum_literal() or isinstance(target.value, bool)
):
enum_name = target.fallback.type.fullname
current_type = try_expanding_sum_type_to_union(current_type, enum_name)
proposed_items = [type_range.item for type_range in proposed_type_ranges]
proposed_type = make_simplified_union(proposed_items)
if isinstance(proposed_type, AnyType):
Expand Down
16 changes: 16 additions & 0 deletions test-data/unit/check-python310.test
Original file line number Diff line number Diff line change
Expand Up @@ -1760,6 +1760,22 @@ def union(x: str | bool) -> None:
reveal_type(x) # N: Revealed type is "Union[builtins.str, Literal[False]]"
[builtins fixtures/tuple.pyi]

[case testMatchNarrowDownUnionUsingClassPattern]

class Foo: ...
class Bar(Foo): ...

def test_1(bar: Bar) -> None:
match bar:
case Foo() as foo:
reveal_type(foo) # N: Revealed type is "__main__.Bar"

def test_2(bar: Bar | str) -> None:
match bar:
case Foo() as foo:
reveal_type(foo) # N: Revealed type is "__main__.Bar"


[case testMatchAssertFalseToSilenceFalsePositives]
class C:
a: int | str
Expand Down