Skip to content

Commit c502b02

Browse files
committed
Add functools.lru_cache plugin support
- Add lru_cache callback to functools plugin for type validation - Register callbacks in default plugin for decorator and wrapper calls - Support different lru_cache patterns: @lru_cache, @lru_cache(), @lru_cache(maxsize=N) Fixes issue #16261
1 parent db67888 commit c502b02

File tree

4 files changed

+238
-0
lines changed

4 files changed

+238
-0
lines changed

mypy/plugins/default.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,10 @@
4949
)
5050
from mypy.plugins.enums import enum_member_callback, enum_name_callback, enum_value_callback
5151
from mypy.plugins.functools import (
52+
functools_lru_cache_callback,
5253
functools_total_ordering_maker_callback,
5354
functools_total_ordering_makers,
55+
lru_cache_wrapper_call_callback,
5456
partial_call_callback,
5557
partial_new_callback,
5658
)
@@ -101,6 +103,8 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type]
101103
return create_singledispatch_function_callback
102104
elif fullname == "functools.partial":
103105
return partial_new_callback
106+
elif fullname == "functools.lru_cache":
107+
return functools_lru_cache_callback
104108
elif fullname == "enum.member":
105109
return enum_member_callback
106110
return None
@@ -160,6 +164,8 @@ def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | No
160164
return call_singledispatch_function_after_register_argument
161165
elif fullname == "functools.partial.__call__":
162166
return partial_call_callback
167+
elif fullname == "functools._lru_cache_wrapper.__call__":
168+
return lru_cache_wrapper_call_callback
163169
return None
164170

165171
def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:

mypy/plugins/functools.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
_ORDERING_METHODS: Final = {"__lt__", "__le__", "__gt__", "__ge__"}
4242

4343
PARTIAL: Final = "functools.partial"
44+
LRU_CACHE: Final = "functools.lru_cache"
4445

4546

4647
class _MethodInfo(NamedTuple):
@@ -393,3 +394,119 @@ def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
393394
ctx.api.expr_checker.msg.too_few_arguments(partial_type, ctx.context, actual_arg_names)
394395

395396
return result
397+
398+
399+
def functools_lru_cache_callback(ctx: mypy.plugin.FunctionContext) -> Type:
400+
"""Infer a more precise return type for functools.lru_cache decorator"""
401+
if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals
402+
return ctx.default_return_type
403+
404+
# Only handle the simple case: @lru_cache (without parentheses)
405+
# where a function is passed directly as the first argument
406+
if (len(ctx.arg_types) >= 1 and
407+
len(ctx.arg_types[0]) == 1 and
408+
len(ctx.arg_types) <= 2): # Ensure we don't have extra args indicating parameterized call
409+
410+
first_arg_type = ctx.arg_types[0][0]
411+
412+
# Explicitly check that this is NOT a literal or other non-function type
413+
from mypy.types import LiteralType, Instance
414+
if isinstance(first_arg_type, (LiteralType, Instance)):
415+
# This is likely maxsize=128 or similar - let MyPy handle it
416+
return ctx.default_return_type
417+
418+
# Try to extract callable type
419+
fn_type = ctx.api.extract_callable_type(first_arg_type, ctx=ctx.default_return_type)
420+
if fn_type is not None:
421+
# This is the @lru_cache case (function passed directly)
422+
return fn_type
423+
424+
# For all parameterized cases, don't interfere
425+
return ctx.default_return_type
426+
427+
428+
def lru_cache_wrapper_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
429+
"""Handle calls to functools._lru_cache_wrapper objects to provide parameter validation"""
430+
if not isinstance(ctx.api, mypy.checker.TypeChecker):
431+
return ctx.default_return_type
432+
433+
# Try to find the original function signature using AST/symbol table analysis
434+
original_signature = _find_original_function_signature(ctx)
435+
436+
if original_signature is not None:
437+
# Validate the call against the original function signature
438+
actual_args = []
439+
actual_arg_kinds = []
440+
actual_arg_names = []
441+
seen_args = set()
442+
443+
for i, param in enumerate(ctx.args):
444+
for j, a in enumerate(param):
445+
if a in seen_args:
446+
continue
447+
seen_args.add(a)
448+
actual_args.append(a)
449+
actual_arg_kinds.append(ctx.arg_kinds[i][j])
450+
actual_arg_names.append(ctx.arg_names[i][j])
451+
452+
# Check the call against the original signature
453+
result, _ = ctx.api.expr_checker.check_call(
454+
callee=original_signature,
455+
args=actual_args,
456+
arg_kinds=actual_arg_kinds,
457+
arg_names=actual_arg_names,
458+
context=ctx.context,
459+
)
460+
return result
461+
462+
return ctx.default_return_type
463+
464+
465+
def _find_original_function_signature(ctx: mypy.plugin.MethodContext) -> CallableType | None:
466+
"""
467+
Attempt to find the original function signature from the call context.
468+
469+
Returns the CallableType of the original function if found, None otherwise.
470+
This function safely traverses the AST structure to locate the original
471+
function signature that was decorated with @lru_cache.
472+
"""
473+
from mypy.nodes import CallExpr, Decorator, NameExpr
474+
475+
# Ensure we have the required context structure
476+
if not isinstance(ctx.context, CallExpr):
477+
return None
478+
479+
callee = ctx.context.callee
480+
if not isinstance(callee, NameExpr) or not callee.name:
481+
return None
482+
483+
func_name = callee.name
484+
485+
# Safely access the API globals
486+
if not hasattr(ctx.api, 'globals') or not isinstance(ctx.api.globals, dict):
487+
return None
488+
489+
if func_name not in ctx.api.globals:
490+
return None
491+
492+
symbol = ctx.api.globals[func_name]
493+
494+
# Validate symbol structure before accessing node
495+
if not hasattr(symbol, 'node') or symbol.node is None:
496+
return None
497+
498+
# Check if this is a decorator node containing our function
499+
if isinstance(symbol.node, Decorator):
500+
decorator_node = symbol.node
501+
502+
# Safely access the decorated function
503+
if not hasattr(decorator_node, 'func') or decorator_node.func is None:
504+
return None
505+
506+
func_def = decorator_node.func
507+
508+
# Verify we have a callable type
509+
if hasattr(func_def, 'type') and isinstance(func_def.type, CallableType):
510+
return func_def.type
511+
512+
return None

test-data/unit/check-functools.test

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -726,3 +726,110 @@ def outer_c(arg: Tc) -> None:
726726
use_int_callable(partial(inner, b="")) # E: Argument 1 to "use_int_callable" has incompatible type "partial[str]"; expected "Callable[[int], int]" \
727727
# N: "partial[str].__call__" has type "Callable[[VarArg(Any), KwArg(Any)], str]"
728728
[builtins fixtures/tuple.pyi]
729+
730+
[case testLruCacheBasicValidation]
731+
from functools import lru_cache
732+
733+
@lru_cache
734+
def f(v: str, at: int) -> str:
735+
return v
736+
737+
f() # E: Missing positional arguments "v", "at" in call to "f"
738+
f("abc") # E: Missing positional argument "at" in call to "f"
739+
f("abc", 123) # OK
740+
f("abc", at=123) # OK
741+
f("abc", at="wrong_type") # E: Argument "at" to "f" has incompatible type "str"; expected "int"
742+
[builtins fixtures/dict.pyi]
743+
744+
[case testLruCacheWithReturnType]
745+
from functools import lru_cache
746+
747+
@lru_cache
748+
def multiply(x: int, y: int) -> int:
749+
return 42
750+
751+
reveal_type(multiply) # N: Revealed type is "def (x: builtins.int, y: builtins.int) -> builtins.int"
752+
reveal_type(multiply(2, 3)) # N: Revealed type is "builtins.int"
753+
multiply("a", 3) # E: Argument 1 to "multiply" has incompatible type "str"; expected "int"
754+
multiply(2, "b") # E: Argument 2 to "multiply" has incompatible type "str"; expected "int"
755+
multiply(2) # E: Missing positional argument "y" in call to "multiply"
756+
multiply(1, 2, 3) # E: Too many arguments for "multiply"
757+
[builtins fixtures/dict.pyi]
758+
759+
[case testLruCacheWithOptionalArgs]
760+
from functools import lru_cache
761+
762+
@lru_cache
763+
def greet(name: str, greeting: str = "Hello") -> str:
764+
return "result"
765+
766+
greet("World") # OK
767+
greet("World", "Hi") # OK
768+
greet("World", greeting="Hi") # OK
769+
greet() # E: Missing positional argument "name" in call to "greet"
770+
greet(123) # E: Argument 1 to "greet" has incompatible type "int"; expected "str"
771+
greet("World", 123) # E: Argument 2 to "greet" has incompatible type "int"; expected "str"
772+
[builtins fixtures/dict.pyi]
773+
774+
[case testLruCacheGenericFunction]
775+
from functools import lru_cache
776+
from typing import TypeVar
777+
778+
T = TypeVar('T')
779+
780+
@lru_cache
781+
def identity(x: T) -> T:
782+
return x
783+
784+
reveal_type(identity(42)) # N: Revealed type is "builtins.int"
785+
reveal_type(identity("hello")) # N: Revealed type is "builtins.str"
786+
identity() # E: Missing positional argument "x" in call to "identity"
787+
[builtins fixtures/dict.pyi]
788+
789+
[case testLruCacheWithParentheses]
790+
from functools import lru_cache
791+
792+
@lru_cache()
793+
def f(v: str, at: int) -> str:
794+
return v
795+
796+
f() # E: Missing positional arguments "v", "at" in call to "f"
797+
f("abc") # E: Missing positional argument "at" in call to "f"
798+
f("abc", 123) # OK
799+
f("abc", at=123) # OK
800+
f("abc", at="wrong_type") # E: Argument "at" to "f" has incompatible type "str"; expected "int"
801+
[builtins fixtures/dict.pyi]
802+
803+
[case testLruCacheWithMaxsize]
804+
from functools import lru_cache
805+
806+
@lru_cache(maxsize=128)
807+
def g(v: str, at: int) -> str:
808+
return v
809+
810+
g() # E: Missing positional arguments "v", "at" in call to "g"
811+
g("abc") # E: Missing positional argument "at" in call to "g"
812+
g("abc", 123) # OK
813+
g("abc", at=123) # OK
814+
g("abc", at="wrong_type") # E: Argument "at" to "g" has incompatible type "str"; expected "int"
815+
[builtins fixtures/dict.pyi]
816+
817+
[case testLruCacheGenericWithParameters]
818+
from functools import lru_cache
819+
from typing import TypeVar
820+
821+
T = TypeVar('T')
822+
823+
@lru_cache()
824+
def identity_empty(x: T) -> T:
825+
return x
826+
827+
@lru_cache(maxsize=128)
828+
def identity_maxsize(x: T) -> T:
829+
return x
830+
831+
reveal_type(identity_empty(42)) # N: Revealed type is "builtins.int"
832+
reveal_type(identity_maxsize("hello")) # N: Revealed type is "builtins.str"
833+
identity_empty() # E: Missing positional argument "x" in call to "identity_empty"
834+
identity_maxsize() # E: Missing positional argument "x" in call to "identity_maxsize"
835+
[builtins fixtures/dict.pyi]

test-data/unit/lib-stub/functools.pyi

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,11 @@ class cached_property(Generic[_T]):
3737
class partial(Generic[_T]):
3838
def __new__(cls, __func: Callable[..., _T], *args: Any, **kwargs: Any) -> Self: ...
3939
def __call__(__self, *args: Any, **kwargs: Any) -> _T: ...
40+
41+
class _lru_cache_wrapper(Generic[_T]):
42+
def __call__(__self, *args: Any, **kwargs: Any) -> _T: ...
43+
44+
@overload
45+
def lru_cache(maxsize: int | None = 128, typed: bool = False) -> Callable[[Callable[..., _T]], _lru_cache_wrapper[_T]]: ...
46+
@overload
47+
def lru_cache(__func: Callable[..., _T]) -> _lru_cache_wrapper[_T]: ...

0 commit comments

Comments
 (0)