From 412f63bfe1a8bcc0682f5986b0cf462e59bcae0d Mon Sep 17 00:00:00 2001 From: STerliakov Date: Wed, 16 Jul 2025 03:31:28 +0200 Subject: [PATCH 1/3] Apply union expansion when checking ops to typevars --- mypy/checkexpr.py | 31 +++++++++++++++++++++------ test-data/unit/check-expressions.test | 26 ++++++++++++++++++++++ test-data/unit/fixtures/ops.pyi | 6 ++++++ 3 files changed, 56 insertions(+), 7 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 8223ccfe4ca0..9a9df97be358 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -4155,10 +4155,9 @@ def check_op( """ if allow_reverse: - left_variants = [base_type] + left_variants = self._union_items_from_typevar(base_type) base_type = get_proper_type(base_type) - if isinstance(base_type, UnionType): - left_variants = list(flatten_nested_unions(base_type.relevant_items())) + right_type = self.accept(arg) # Step 1: We first try leaving the right arguments alone and destructure @@ -4196,13 +4195,18 @@ def check_op( # We don't do the same for the base expression because it could lead to weird # type inference errors -- e.g. see 'testOperatorDoubleUnionSum'. # TODO: Can we use `type_overrides_set()` here? - right_variants = [(right_type, arg)] - right_type = get_proper_type(right_type) - if isinstance(right_type, UnionType): + right_variants: list[tuple[Type, Expression]] + if isinstance(right_type, ProperType) and isinstance( + right_type, (UnionType, TypeVarType) + ): right_variants = [ (item, TempNode(item, context=context)) - for item in flatten_nested_unions(right_type.relevant_items()) + for item in self._union_items_from_typevar(right_type) ] + else: + # Preserve argument identity if we do not intend to modify it + right_variants = [(right_type, arg)] + right_type = get_proper_type(right_type) all_results = [] all_inferred = [] @@ -4252,6 +4256,19 @@ def check_op( context=context, ) + def _union_items_from_typevar(self, typ: Type) -> list[Type]: + variants = [typ] + typ = get_proper_type(typ) + base_type = typ + if unwrapped := (isinstance(typ, TypeVarType) and not typ.values): + typ = get_proper_type(typ.upper_bound) + if isinstance(typ, UnionType): + variants = list(flatten_nested_unions(typ.relevant_items())) + if unwrapped: + assert isinstance(base_type, TypeVarType) + variants = [base_type.copy_modified(upper_bound=item) for item in variants] + return variants + def check_boolean_op(self, e: OpExpr, context: Context) -> Type: """Type check a boolean operation ('and' or 'or').""" diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index 33271a3cc04c..5c2cad914bce 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -706,6 +706,32 @@ if int(): class C: def __lt__(self, o: object, x: str = "") -> int: ... +[case testReversibleOpOnTypeVarBound] +from typing import TypeVar, Union + +class A: + def __lt__(self, a: A) -> bool: ... + def __gt__(self, a: A) -> bool: ... + +class B(A): + def __lt__(self, b: B) -> bool: ... # type: ignore[override] + def __gt__(self, b: B) -> bool: ... # type: ignore[override] + +_T = TypeVar("_T", bound=Union[A, B]) + +def check(x: _T, y: _T) -> bool: + return x < y + +[case testReversibleOpOnTypeVarBoundPromotion] +from typing import TypeVar, Union + +_T = TypeVar("_T", bound=Union[int, float]) + +def check(x: _T, y: _T) -> bool: + return x < y +[builtins fixtures/ops.pyi] + + [case testErrorContextAndBinaryOperators] import typing class A: diff --git a/test-data/unit/fixtures/ops.pyi b/test-data/unit/fixtures/ops.pyi index 67bc74b35c51..34e512b34984 100644 --- a/test-data/unit/fixtures/ops.pyi +++ b/test-data/unit/fixtures/ops.pyi @@ -61,6 +61,12 @@ class float: def __rdiv__(self, x: 'float') -> 'float': pass def __truediv__(self, x: 'float') -> 'float': pass def __rtruediv__(self, x: 'float') -> 'float': pass + def __eq__(self, x: object) -> bool: pass + def __ne__(self, x: object) -> bool: pass + def __lt__(self, x: 'float') -> bool: pass + def __le__(self, x: 'float') -> bool: pass + def __gt__(self, x: 'float') -> bool: pass + def __ge__(self, x: 'float') -> bool: pass class complex: def __add__(self, x: complex) -> complex: pass From fdab83007fae33a93020f282887bd14c911b70e4 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Wed, 16 Jul 2025 03:58:44 +0200 Subject: [PATCH 2/3] Preserve type identity --- mypy/checkexpr.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 9a9df97be358..3c870f90ee2b 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -4195,17 +4195,10 @@ def check_op( # We don't do the same for the base expression because it could lead to weird # type inference errors -- e.g. see 'testOperatorDoubleUnionSum'. # TODO: Can we use `type_overrides_set()` here? - right_variants: list[tuple[Type, Expression]] - if isinstance(right_type, ProperType) and isinstance( - right_type, (UnionType, TypeVarType) - ): - right_variants = [ - (item, TempNode(item, context=context)) - for item in self._union_items_from_typevar(right_type) - ] - else: - # Preserve argument identity if we do not intend to modify it - right_variants = [(right_type, arg)] + right_variants = [ + (item, TempNode(item, context=context)) + for item in self._union_items_from_typevar(right_type) + ] right_type = get_proper_type(right_type) all_results = [] @@ -4262,9 +4255,10 @@ def _union_items_from_typevar(self, typ: Type) -> list[Type]: base_type = typ if unwrapped := (isinstance(typ, TypeVarType) and not typ.values): typ = get_proper_type(typ.upper_bound) - if isinstance(typ, UnionType): + if is_union := isinstance(typ, UnionType): variants = list(flatten_nested_unions(typ.relevant_items())) - if unwrapped: + if is_union and unwrapped: + # If not a union, keep the original type assert isinstance(base_type, TypeVarType) variants = [base_type.copy_modified(upper_bound=item) for item in variants] return variants From 30c0e32e81c0dd7415b09d86dbee381a5efd9921 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Wed, 16 Jul 2025 04:39:21 +0200 Subject: [PATCH 3/3] Retain original arg if possible --- mypy/checkexpr.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 3c870f90ee2b..d9e4a704155c 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -4195,11 +4195,17 @@ def check_op( # We don't do the same for the base expression because it could lead to weird # type inference errors -- e.g. see 'testOperatorDoubleUnionSum'. # TODO: Can we use `type_overrides_set()` here? - right_variants = [ - (item, TempNode(item, context=context)) - for item in self._union_items_from_typevar(right_type) - ] - right_type = get_proper_type(right_type) + right_variants: list[tuple[Type, Expression]] + p_right = get_proper_type(right_type) + if isinstance(p_right, (UnionType, TypeVarType)): + right_variants = [ + (item, TempNode(item, context=context)) + for item in self._union_items_from_typevar(right_type) + ] + else: + # Preserve argument identity if we do not intend to modify it + right_variants = [(right_type, arg)] + right_type = p_right all_results = [] all_inferred = []