diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 8223ccfe4ca0..d9e4a704155c 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -4155,10 +4155,9 @@ def check_op( """ if allow_reverse: - left_variants = [base_type] + left_variants = self._union_items_from_typevar(base_type) base_type = get_proper_type(base_type) - if isinstance(base_type, UnionType): - left_variants = list(flatten_nested_unions(base_type.relevant_items())) + right_type = self.accept(arg) # Step 1: We first try leaving the right arguments alone and destructure @@ -4196,13 +4195,17 @@ def check_op( # We don't do the same for the base expression because it could lead to weird # type inference errors -- e.g. see 'testOperatorDoubleUnionSum'. # TODO: Can we use `type_overrides_set()` here? - right_variants = [(right_type, arg)] - right_type = get_proper_type(right_type) - if isinstance(right_type, UnionType): + right_variants: list[tuple[Type, Expression]] + p_right = get_proper_type(right_type) + if isinstance(p_right, (UnionType, TypeVarType)): right_variants = [ (item, TempNode(item, context=context)) - for item in flatten_nested_unions(right_type.relevant_items()) + for item in self._union_items_from_typevar(right_type) ] + else: + # Preserve argument identity if we do not intend to modify it + right_variants = [(right_type, arg)] + right_type = p_right all_results = [] all_inferred = [] @@ -4252,6 +4255,20 @@ def check_op( context=context, ) + def _union_items_from_typevar(self, typ: Type) -> list[Type]: + variants = [typ] + typ = get_proper_type(typ) + base_type = typ + if unwrapped := (isinstance(typ, TypeVarType) and not typ.values): + typ = get_proper_type(typ.upper_bound) + if is_union := isinstance(typ, UnionType): + variants = list(flatten_nested_unions(typ.relevant_items())) + if is_union and unwrapped: + # If not a union, keep the original type + assert isinstance(base_type, TypeVarType) + variants = [base_type.copy_modified(upper_bound=item) for item in variants] + return variants + def check_boolean_op(self, e: OpExpr, context: Context) -> Type: """Type check a boolean operation ('and' or 'or').""" diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index 33271a3cc04c..5c2cad914bce 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -706,6 +706,32 @@ if int(): class C: def __lt__(self, o: object, x: str = "") -> int: ... +[case testReversibleOpOnTypeVarBound] +from typing import TypeVar, Union + +class A: + def __lt__(self, a: A) -> bool: ... + def __gt__(self, a: A) -> bool: ... + +class B(A): + def __lt__(self, b: B) -> bool: ... # type: ignore[override] + def __gt__(self, b: B) -> bool: ... # type: ignore[override] + +_T = TypeVar("_T", bound=Union[A, B]) + +def check(x: _T, y: _T) -> bool: + return x < y + +[case testReversibleOpOnTypeVarBoundPromotion] +from typing import TypeVar, Union + +_T = TypeVar("_T", bound=Union[int, float]) + +def check(x: _T, y: _T) -> bool: + return x < y +[builtins fixtures/ops.pyi] + + [case testErrorContextAndBinaryOperators] import typing class A: diff --git a/test-data/unit/fixtures/ops.pyi b/test-data/unit/fixtures/ops.pyi index 67bc74b35c51..34e512b34984 100644 --- a/test-data/unit/fixtures/ops.pyi +++ b/test-data/unit/fixtures/ops.pyi @@ -61,6 +61,12 @@ class float: def __rdiv__(self, x: 'float') -> 'float': pass def __truediv__(self, x: 'float') -> 'float': pass def __rtruediv__(self, x: 'float') -> 'float': pass + def __eq__(self, x: object) -> bool: pass + def __ne__(self, x: object) -> bool: pass + def __lt__(self, x: 'float') -> bool: pass + def __le__(self, x: 'float') -> bool: pass + def __gt__(self, x: 'float') -> bool: pass + def __ge__(self, x: 'float') -> bool: pass class complex: def __add__(self, x: complex) -> complex: pass