From e3ea99303a1fc86f25db339bc0756db6ab66e935 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Tue, 8 Jul 2025 21:39:56 +0200 Subject: [PATCH 1/8] Cache inner contexts of overloads and binary ops --- mypy/checkexpr.py | 95 +++++++++++++++++++++++++++++------------------ 1 file changed, 59 insertions(+), 36 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 8223ccfe4ca0..65d433a58a1f 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -356,6 +356,9 @@ def __init__( self._arg_infer_context_cache = None + self.overload_stack_depth = 0 + self._args_cache = {} + def reset(self) -> None: self.resolved_type = {} @@ -1613,9 +1616,10 @@ def check_call( object_type, ) elif isinstance(callee, Overloaded): - return self.check_overload_call( - callee, args, arg_kinds, arg_names, callable_name, object_type, context - ) + with self.overload_context(callee.name()): + return self.check_overload_call( + callee, args, arg_kinds, arg_names, callable_name, object_type, context + ) elif isinstance(callee, AnyType) or not self.chk.in_checked_function(): return self.check_any_type_call(args, callee) elif isinstance(callee, UnionType): @@ -1674,6 +1678,14 @@ def check_call( else: return self.msg.not_callable(callee, context), AnyType(TypeOfAny.from_error) + @contextmanager + def overload_context(self, fn): + self.overload_stack_depth += 1 + yield + self.overload_stack_depth -= 1 + if self.overload_stack_depth == 0: + self._args_cache.clear() + def check_callable_call( self, callee: CallableType, @@ -1937,6 +1949,10 @@ def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type] In short, we basically recurse on each argument without considering in what context the argument was called. """ + can_cache = not any(isinstance(t, TempNode) for t in args) + key = tuple(map(id, args)) + if can_cache and key in self._args_cache: + return self._args_cache[key] res: list[Type] = [] for arg in args: @@ -1945,6 +1961,8 @@ def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type] res.append(NoneType()) else: res.append(arg_type) + if can_cache: + self._args_cache[key] = res return res def infer_more_unions_for_recursive_type(self, type_context: Type) -> bool: @@ -2917,17 +2935,16 @@ def infer_overload_return_type( for typ in plausible_targets: assert self.msg is self.chk.msg - with self.msg.filter_errors() as w: - with self.chk.local_type_map() as m: - ret_type, infer_type = self.check_call( - callee=typ, - args=args, - arg_kinds=arg_kinds, - arg_names=arg_names, - context=context, - callable_name=callable_name, - object_type=object_type, - ) + with self.msg.filter_errors() as w, self.chk.local_type_map() as m: + ret_type, infer_type = self.check_call( + callee=typ, + args=args, + arg_kinds=arg_kinds, + arg_names=arg_names, + context=context, + callable_name=callable_name, + object_type=object_type, + ) is_match = not w.has_new_errors() if is_match: # Return early if possible; otherwise record info, so we can @@ -3474,6 +3491,10 @@ def visit_op_expr(self, e: OpExpr) -> Type: return self.strfrm_checker.check_str_interpolation(e.left, e.right) if isinstance(e.left, StrExpr): return self.strfrm_checker.check_str_interpolation(e.left, e.right) + + key = id(e) + if key in self._args_cache: + return self._args_cache[key] left_type = self.accept(e.left) proper_left_type = get_proper_type(left_type) @@ -3543,28 +3564,30 @@ def visit_op_expr(self, e: OpExpr) -> Type: ) if e.op in operators.op_methods: - method = operators.op_methods[e.op] - if use_reverse is UseReverse.DEFAULT or use_reverse is UseReverse.NEVER: - result, method_type = self.check_op( - method, - base_type=left_type, - arg=e.right, - context=e, - allow_reverse=use_reverse is UseReverse.DEFAULT, - ) - elif use_reverse is UseReverse.ALWAYS: - result, method_type = self.check_op( - # The reverse operator here gives better error messages: - operators.reverse_op_methods[method], - base_type=self.accept(e.right), - arg=e.left, - context=e, - allow_reverse=False, - ) - else: - assert_never(use_reverse) - e.method_type = method_type - return result + with self.overload_context(e.op): + method = operators.op_methods[e.op] + if use_reverse is UseReverse.DEFAULT or use_reverse is UseReverse.NEVER: + result, method_type = self.check_op( + method, + base_type=left_type, + arg=e.right, + context=e, + allow_reverse=use_reverse is UseReverse.DEFAULT, + ) + elif use_reverse is UseReverse.ALWAYS: + result, method_type = self.check_op( + # The reverse operator here gives better error messages: + operators.reverse_op_methods[method], + base_type=self.accept(e.right), + arg=e.left, + context=e, + allow_reverse=False, + ) + else: + assert_never(use_reverse) + e.method_type = method_type + self._args_cache[key] = result + return result else: raise RuntimeError(f"Unknown operator {e.op}") From 03af5b21d2c1c93b25fdd381687cc722bcf1c062 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Tue, 8 Jul 2025 21:59:14 +0200 Subject: [PATCH 2/8] Fix selfcheck --- mypy/checkexpr.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 65d433a58a1f..89d0d0d75b61 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -357,7 +357,9 @@ def __init__( self._arg_infer_context_cache = None self.overload_stack_depth = 0 - self._args_cache = {} + self.ops_stack_depth = 0 + self._args_cache: dict[tuple[int, ...], list[Type]] = {} + self._ops_cache: dict[int, Type] = {} def reset(self) -> None: self.resolved_type = {} @@ -1616,7 +1618,7 @@ def check_call( object_type, ) elif isinstance(callee, Overloaded): - with self.overload_context(callee.name()): + with self.overload_context(): return self.check_overload_call( callee, args, arg_kinds, arg_names, callable_name, object_type, context ) @@ -1679,13 +1681,21 @@ def check_call( return self.msg.not_callable(callee, context), AnyType(TypeOfAny.from_error) @contextmanager - def overload_context(self, fn): + def overload_context(self) -> Iterator[None]: self.overload_stack_depth += 1 yield self.overload_stack_depth -= 1 if self.overload_stack_depth == 0: self._args_cache.clear() + @contextmanager + def ops_context(self) -> Iterator[None]: + self.ops_stack_depth += 1 + yield + self.ops_stack_depth -= 1 + if self.ops_stack_depth == 0: + self._ops_cache.clear() + def check_callable_call( self, callee: CallableType, @@ -3493,8 +3503,8 @@ def visit_op_expr(self, e: OpExpr) -> Type: return self.strfrm_checker.check_str_interpolation(e.left, e.right) key = id(e) - if key in self._args_cache: - return self._args_cache[key] + if key in self._ops_cache: + return self._ops_cache[key] left_type = self.accept(e.left) proper_left_type = get_proper_type(left_type) @@ -3564,7 +3574,7 @@ def visit_op_expr(self, e: OpExpr) -> Type: ) if e.op in operators.op_methods: - with self.overload_context(e.op): + with self.ops_context(): method = operators.op_methods[e.op] if use_reverse is UseReverse.DEFAULT or use_reverse is UseReverse.NEVER: result, method_type = self.check_op( @@ -3586,7 +3596,7 @@ def visit_op_expr(self, e: OpExpr) -> Type: else: assert_never(use_reverse) e.method_type = method_type - self._args_cache[key] = result + self._ops_cache[key] = result return result else: raise RuntimeError(f"Unknown operator {e.op}") From 665a3119605ca9b3909ac37d2320385cef1f64ef Mon Sep 17 00:00:00 2001 From: STerliakov Date: Wed, 9 Jul 2025 00:19:08 +0200 Subject: [PATCH 3/8] Only retain the overloads part --- mypy/checkexpr.py | 59 ++++++++++++++++++----------------------------- 1 file changed, 22 insertions(+), 37 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 89d0d0d75b61..b7292d8c82c3 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -357,9 +357,7 @@ def __init__( self._arg_infer_context_cache = None self.overload_stack_depth = 0 - self.ops_stack_depth = 0 self._args_cache: dict[tuple[int, ...], list[Type]] = {} - self._ops_cache: dict[int, Type] = {} def reset(self) -> None: self.resolved_type = {} @@ -1688,14 +1686,6 @@ def overload_context(self) -> Iterator[None]: if self.overload_stack_depth == 0: self._args_cache.clear() - @contextmanager - def ops_context(self) -> Iterator[None]: - self.ops_stack_depth += 1 - yield - self.ops_stack_depth -= 1 - if self.ops_stack_depth == 0: - self._ops_cache.clear() - def check_callable_call( self, callee: CallableType, @@ -3502,9 +3492,6 @@ def visit_op_expr(self, e: OpExpr) -> Type: if isinstance(e.left, StrExpr): return self.strfrm_checker.check_str_interpolation(e.left, e.right) - key = id(e) - if key in self._ops_cache: - return self._ops_cache[key] left_type = self.accept(e.left) proper_left_type = get_proper_type(left_type) @@ -3574,30 +3561,28 @@ def visit_op_expr(self, e: OpExpr) -> Type: ) if e.op in operators.op_methods: - with self.ops_context(): - method = operators.op_methods[e.op] - if use_reverse is UseReverse.DEFAULT or use_reverse is UseReverse.NEVER: - result, method_type = self.check_op( - method, - base_type=left_type, - arg=e.right, - context=e, - allow_reverse=use_reverse is UseReverse.DEFAULT, - ) - elif use_reverse is UseReverse.ALWAYS: - result, method_type = self.check_op( - # The reverse operator here gives better error messages: - operators.reverse_op_methods[method], - base_type=self.accept(e.right), - arg=e.left, - context=e, - allow_reverse=False, - ) - else: - assert_never(use_reverse) - e.method_type = method_type - self._ops_cache[key] = result - return result + method = operators.op_methods[e.op] + if use_reverse is UseReverse.DEFAULT or use_reverse is UseReverse.NEVER: + result, method_type = self.check_op( + method, + base_type=left_type, + arg=e.right, + context=e, + allow_reverse=use_reverse is UseReverse.DEFAULT, + ) + elif use_reverse is UseReverse.ALWAYS: + result, method_type = self.check_op( + # The reverse operator here gives better error messages: + operators.reverse_op_methods[method], + base_type=self.accept(e.right), + arg=e.left, + context=e, + allow_reverse=False, + ) + else: + assert_never(use_reverse) + e.method_type = method_type + return result else: raise RuntimeError(f"Unknown operator {e.op}") From a1f8c7721e28d42089353420b727a5f4ee4ff099 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Wed, 9 Jul 2025 16:57:36 +0200 Subject: [PATCH 4/8] Fix: the cache should not be touched outside of overloads --- mypy/checkexpr.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index b7292d8c82c3..ce1c58c38b21 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -360,6 +360,8 @@ def __init__( self._args_cache: dict[tuple[int, ...], list[Type]] = {} def reset(self) -> None: + assert self.overload_stack_depth == 0 + assert not self._args_cache self.resolved_type = {} def visit_name_expr(self, e: NameExpr) -> Type: @@ -1949,7 +1951,12 @@ def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type] In short, we basically recurse on each argument without considering in what context the argument was called. """ - can_cache = not any(isinstance(t, TempNode) for t in args) + # We can only use this hack locally while checking a single nested overloaded + # call. This saves a lot of rechecking, but is not generally safe. Cache is + # pruned upon leaving the outermost overload. + can_cache = self.overload_stack_depth > 0 and not any( + isinstance(t, TempNode) for t in args + ) key = tuple(map(id, args)) if can_cache and key in self._args_cache: return self._args_cache[key] From 479550b97e33b0ef0c361042cc9610e4f77581c0 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Wed, 9 Jul 2025 18:14:05 +0200 Subject: [PATCH 5/8] Poison cache when we encounter any lambda --- mypy/checkexpr.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index ce1c58c38b21..95c863407b37 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -233,6 +233,8 @@ "builtins.memoryview", } +POISON_KEY: Final = (-1,) + class TooManyUnions(Exception): """Indicates that we need to stop splitting unions in an attempt @@ -1954,8 +1956,10 @@ def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type] # We can only use this hack locally while checking a single nested overloaded # call. This saves a lot of rechecking, but is not generally safe. Cache is # pruned upon leaving the outermost overload. - can_cache = self.overload_stack_depth > 0 and not any( - isinstance(t, TempNode) for t in args + can_cache = ( + self.overload_stack_depth > 0 + and POISON_KEY not in self._args_cache + and not any(isinstance(t, TempNode) for t in args) ) key = tuple(map(id, args)) if can_cache and key in self._args_cache: @@ -5426,6 +5430,9 @@ def find_typeddict_context( def visit_lambda_expr(self, e: LambdaExpr) -> Type: """Type check lambda expression.""" + if self.overload_stack_depth > 0: + # Poison cache when we encounter lambdas - it isn't safe to cache their types. + self._args_cache[POISON_KEY] = [] self.chk.check_default_args(e, body_is_trivial=False) inferred_type, type_override = self.infer_lambda_type_using_context(e) if not inferred_type: From 8904125f4e919101abdd5dfbe3c6b36018cc0e61 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Fri, 18 Jul 2025 17:17:59 +0200 Subject: [PATCH 6/8] Only cache when explicitly requested --- mypy/checkexpr.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 95c863407b37..5ce9447a64a5 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1947,7 +1947,9 @@ def analyze_type_type_callee(self, item: ProperType, context: Context) -> Type: self.msg.unsupported_type_type(item, context) return AnyType(TypeOfAny.from_error) - def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type]: + def infer_arg_types_in_empty_context( + self, args: list[Expression], *, allow_cache: bool + ) -> list[Type]: """Infer argument expression types in an empty context. In short, we basically recurse on each argument without considering @@ -1957,7 +1959,7 @@ def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type] # call. This saves a lot of rechecking, but is not generally safe. Cache is # pruned upon leaving the outermost overload. can_cache = ( - self.overload_stack_depth > 0 + allow_cache and POISON_KEY not in self._args_cache and not any(isinstance(t, TempNode) for t in args) ) @@ -2737,7 +2739,7 @@ def check_overload_call( """Checks a call to an overloaded function.""" # Normalize unpacked kwargs before checking the call. callee = callee.with_unpacked_kwargs() - arg_types = self.infer_arg_types_in_empty_context(args) + arg_types = self.infer_arg_types_in_empty_context(args, allow_cache=True) # Step 1: Filter call targets to remove ones where the argument counts don't match plausible_targets = self.plausible_overload_call_targets( arg_types, arg_kinds, arg_names, callee @@ -3331,7 +3333,7 @@ def apply_generic_arguments( ) def check_any_type_call(self, args: list[Expression], callee: Type) -> tuple[Type, Type]: - self.infer_arg_types_in_empty_context(args) + self.infer_arg_types_in_empty_context(args, allow_cache=False) callee = get_proper_type(callee) if isinstance(callee, AnyType): return ( From 1cbfab901e16eb04263a0549d4004de36260c816 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Fri, 18 Jul 2025 17:19:11 +0200 Subject: [PATCH 7/8] Poison cache in presence of assignments --- mypy/checkexpr.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 5ce9447a64a5..64191fd9cf2e 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -4377,6 +4377,9 @@ def check_list_multiply(self, e: OpExpr) -> Type: return result def visit_assignment_expr(self, e: AssignmentExpr) -> Type: + if self.overload_stack_depth > 0: + # Poison cache when we encounter assignments in overloads - they affect the binder. + self._args_cache[POISON_KEY] = [] value = self.accept(e.value) self.chk.check_assignment(e.target, e.value) self.chk.check_final(e) From 72fdf5d29c98fc5177131996ea57c09443f57818 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Fri, 18 Jul 2025 17:22:21 +0200 Subject: [PATCH 8/8] Do not cache if errors are encountered in arguments themselves --- mypy/checkexpr.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 64191fd9cf2e..3af50d04216c 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1968,13 +1968,18 @@ def infer_arg_types_in_empty_context( return self._args_cache[key] res: list[Type] = [] - for arg in args: - arg_type = self.accept(arg) - if has_erased_component(arg_type): - res.append(NoneType()) - else: - res.append(arg_type) - if can_cache: + with self.msg.filter_errors(filter_errors=True, save_filtered_errors=True) as w: + for arg in args: + arg_type = self.accept(arg) + if has_erased_component(arg_type): + res.append(NoneType()) + else: + res.append(arg_type) + + if w.has_new_errors(): + self.msg.add_errors(w.filtered_errors()) + elif can_cache: + # Do not cache if new diagnostics were emitted: they may impact parent overload self._args_cache[key] = res return res