diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 24f0c8c85d61..2b042b4c0c7c 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -233,6 +233,8 @@ "builtins.memoryview", } +POISON_KEY: Final = (-1,) + class TooManyUnions(Exception): """Indicates that we need to stop splitting unions in an attempt @@ -356,7 +358,12 @@ def __init__( self._arg_infer_context_cache = None + self.overload_stack_depth = 0 + self._args_cache: dict[tuple[int, ...], list[Type]] = {} + def reset(self) -> None: + assert self.overload_stack_depth == 0 + assert not self._args_cache self.resolved_type = {} def visit_name_expr(self, e: NameExpr) -> Type: @@ -1613,9 +1620,10 @@ def check_call( object_type, ) elif isinstance(callee, Overloaded): - return self.check_overload_call( - callee, args, arg_kinds, arg_names, callable_name, object_type, context - ) + with self.overload_context(): + return self.check_overload_call( + callee, args, arg_kinds, arg_names, callable_name, object_type, context + ) elif isinstance(callee, AnyType) or not self.chk.in_checked_function(): return self.check_any_type_call(args, callee) elif isinstance(callee, UnionType): @@ -1678,6 +1686,14 @@ def check_call( else: return self.msg.not_callable(callee, context), AnyType(TypeOfAny.from_error) + @contextmanager + def overload_context(self) -> Iterator[None]: + self.overload_stack_depth += 1 + yield + self.overload_stack_depth -= 1 + if self.overload_stack_depth == 0: + self._args_cache.clear() + def check_callable_call( self, callee: CallableType, @@ -1935,20 +1951,40 @@ def analyze_type_type_callee(self, item: ProperType, context: Context) -> Type: self.msg.unsupported_type_type(item, context) return AnyType(TypeOfAny.from_error) - def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type]: + def infer_arg_types_in_empty_context( + self, args: list[Expression], *, allow_cache: bool + ) -> list[Type]: """Infer argument expression types in an empty context. In short, we basically recurse on each argument without considering in what context the argument was called. """ + # We can only use this hack locally while checking a single nested overloaded + # call. This saves a lot of rechecking, but is not generally safe. Cache is + # pruned upon leaving the outermost overload. + can_cache = ( + allow_cache + and POISON_KEY not in self._args_cache + and not any(isinstance(t, TempNode) for t in args) + ) + key = tuple(map(id, args)) + if can_cache and key in self._args_cache: + return self._args_cache[key] res: list[Type] = [] - for arg in args: - arg_type = self.accept(arg) - if has_erased_component(arg_type): - res.append(NoneType()) - else: - res.append(arg_type) + with self.msg.filter_errors(filter_errors=True, save_filtered_errors=True) as w: + for arg in args: + arg_type = self.accept(arg) + if has_erased_component(arg_type): + res.append(NoneType()) + else: + res.append(arg_type) + + if w.has_new_errors(): + self.msg.add_errors(w.filtered_errors()) + elif can_cache: + # Do not cache if new diagnostics were emitted: they may impact parent overload + self._args_cache[key] = res return res def infer_more_unions_for_recursive_type(self, type_context: Type) -> bool: @@ -2712,7 +2748,7 @@ def check_overload_call( """Checks a call to an overloaded function.""" # Normalize unpacked kwargs before checking the call. callee = callee.with_unpacked_kwargs() - arg_types = self.infer_arg_types_in_empty_context(args) + arg_types = self.infer_arg_types_in_empty_context(args, allow_cache=True) # Step 1: Filter call targets to remove ones where the argument counts don't match plausible_targets = self.plausible_overload_call_targets( arg_types, arg_kinds, arg_names, callee @@ -2921,17 +2957,16 @@ def infer_overload_return_type( for typ in plausible_targets: assert self.msg is self.chk.msg - with self.msg.filter_errors() as w: - with self.chk.local_type_map() as m: - ret_type, infer_type = self.check_call( - callee=typ, - args=args, - arg_kinds=arg_kinds, - arg_names=arg_names, - context=context, - callable_name=callable_name, - object_type=object_type, - ) + with self.msg.filter_errors() as w, self.chk.local_type_map() as m: + ret_type, infer_type = self.check_call( + callee=typ, + args=args, + arg_kinds=arg_kinds, + arg_names=arg_names, + context=context, + callable_name=callable_name, + object_type=object_type, + ) is_match = not w.has_new_errors() if is_match: # Return early if possible; otherwise record info, so we can @@ -3307,7 +3342,7 @@ def apply_generic_arguments( ) def check_any_type_call(self, args: list[Expression], callee: Type) -> tuple[Type, Type]: - self.infer_arg_types_in_empty_context(args) + self.infer_arg_types_in_empty_context(args, allow_cache=False) callee = get_proper_type(callee) if isinstance(callee, AnyType): return ( @@ -3478,6 +3513,7 @@ def visit_op_expr(self, e: OpExpr) -> Type: return self.strfrm_checker.check_str_interpolation(e.left, e.right) if isinstance(e.left, StrExpr): return self.strfrm_checker.check_str_interpolation(e.left, e.right) + left_type = self.accept(e.left) proper_left_type = get_proper_type(left_type) @@ -4350,6 +4386,9 @@ def check_list_multiply(self, e: OpExpr) -> Type: return result def visit_assignment_expr(self, e: AssignmentExpr) -> Type: + if self.overload_stack_depth > 0: + # Poison cache when we encounter assignments in overloads - they affect the binder. + self._args_cache[POISON_KEY] = [] value = self.accept(e.value) self.chk.check_assignment(e.target, e.value) self.chk.check_final(e) @@ -5405,6 +5444,9 @@ def find_typeddict_context( def visit_lambda_expr(self, e: LambdaExpr) -> Type: """Type check lambda expression.""" + if self.overload_stack_depth > 0: + # Poison cache when we encounter lambdas - it isn't safe to cache their types. + self._args_cache[POISON_KEY] = [] self.chk.check_default_args(e, body_is_trivial=False) inferred_type, type_override = self.infer_lambda_type_using_context(e) if not inferred_type: