Skip to content

Add functools.lru_cache plugin support Fixes #16261 #19432

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mypy/plugins/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,10 @@
)
from mypy.plugins.enums import enum_member_callback, enum_name_callback, enum_value_callback
from mypy.plugins.functools import (
functools_lru_cache_callback,
functools_total_ordering_maker_callback,
functools_total_ordering_makers,
lru_cache_wrapper_call_callback,
partial_call_callback,
partial_new_callback,
)
Expand Down Expand Up @@ -101,6 +103,8 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type]
return create_singledispatch_function_callback
elif fullname == "functools.partial":
return partial_new_callback
elif fullname == "functools.lru_cache":
return functools_lru_cache_callback
elif fullname == "enum.member":
return enum_member_callback
return None
Expand Down Expand Up @@ -160,6 +164,8 @@ def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | No
return call_singledispatch_function_after_register_argument
elif fullname == "functools.partial.__call__":
return partial_call_callback
elif fullname == "functools._lru_cache_wrapper.__call__":
return lru_cache_wrapper_call_callback
return None

def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:
Expand Down
133 changes: 133 additions & 0 deletions mypy/plugins/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
_ORDERING_METHODS: Final = {"__lt__", "__le__", "__gt__", "__ge__"}

PARTIAL: Final = "functools.partial"
LRU_CACHE: Final = "functools.lru_cache"


class _MethodInfo(NamedTuple):
Expand Down Expand Up @@ -393,3 +394,135 @@ def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
ctx.api.expr_checker.msg.too_few_arguments(partial_type, ctx.context, actual_arg_names)

return result


def functools_lru_cache_callback(ctx: mypy.plugin.FunctionContext) -> Type:
"""Infer a more precise return type for functools.lru_cache decorator"""
if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals
return ctx.default_return_type

# Only handle the very specific case: @lru_cache (without parentheses)
# where a single function is passed directly as the only argument
if (
len(ctx.arg_types) == 1
and len(ctx.arg_types[0]) == 1
and len(ctx.args) == 1
and len(ctx.args[0]) == 1
):

first_arg_type = ctx.arg_types[0][0]

# Explicitly reject literal types, instances, and None
from mypy.types import Instance, LiteralType, NoneType

proper_first_arg_type = get_proper_type(first_arg_type)
if isinstance(proper_first_arg_type, (LiteralType, Instance, NoneType)):
return ctx.default_return_type

# Try to extract callable type
fn_type = ctx.api.extract_callable_type(first_arg_type, ctx=ctx.default_return_type)
if fn_type is not None:
# This is the @lru_cache case (function passed directly)
return fn_type

# For all other cases (parameterized, multiple args, etc.), don't interfere
return ctx.default_return_type


def lru_cache_wrapper_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
"""Handle calls to functools._lru_cache_wrapper objects to provide parameter validation"""
if not isinstance(ctx.api, mypy.checker.TypeChecker):
return ctx.default_return_type

# Safety check: ensure we have the required context
if not ctx.context or not ctx.args or not ctx.arg_types:
return ctx.default_return_type

# Try to find the original function signature using AST/symbol table analysis
original_signature = _find_original_function_signature(ctx)

if original_signature is not None:
# Validate the call against the original function signature
actual_args = []
actual_arg_kinds = []
actual_arg_names = []
seen_args = set()

for i, param in enumerate(ctx.args):
for j, a in enumerate(param):
if a in seen_args:
continue
seen_args.add(a)
actual_args.append(a)
actual_arg_kinds.append(ctx.arg_kinds[i][j])
actual_arg_names.append(ctx.arg_names[i][j])

# Check the call against the original signature
try:
result, _ = ctx.api.expr_checker.check_call(
callee=original_signature,
args=actual_args,
arg_kinds=actual_arg_kinds,
arg_names=actual_arg_names,
context=ctx.context,
)
return result
except Exception:
# If check_call fails, fall back gracefully
pass

return ctx.default_return_type


def _find_original_function_signature(ctx: mypy.plugin.MethodContext) -> CallableType | None:
"""
Attempt to find the original function signature from the call context.

Returns the CallableType of the original function if found, None otherwise.
This function safely traverses the AST structure to locate the original
function signature that was decorated with @lru_cache.
"""
from mypy.nodes import CallExpr, Decorator, NameExpr

try:
# Ensure we have the required context structure
if not isinstance(ctx.context, CallExpr):
return None

callee = ctx.context.callee
if not isinstance(callee, NameExpr) or not callee.name:
return None

func_name = callee.name

# Safely access the API globals
if not hasattr(ctx.api, "globals") or not isinstance(ctx.api.globals, dict):
return None

if func_name not in ctx.api.globals:
return None

symbol = ctx.api.globals[func_name]

# Validate symbol structure before accessing node
if not hasattr(symbol, "node") or symbol.node is None:
return None

# Check if this is a decorator node containing our function
if isinstance(symbol.node, Decorator):
decorator_node = symbol.node

# Safely access the decorated function
if not hasattr(decorator_node, "func") or decorator_node.func is None:
return None

func_def = decorator_node.func

# Verify we have a callable type
if hasattr(func_def, "type") and isinstance(func_def.type, CallableType):
return func_def.type

return None
except (AttributeError, TypeError, KeyError):
# If anything goes wrong in AST traversal, fail gracefully
return None
132 changes: 132 additions & 0 deletions test-data/unit/check-functools.test
Original file line number Diff line number Diff line change
Expand Up @@ -726,3 +726,135 @@ def outer_c(arg: Tc) -> None:
use_int_callable(partial(inner, b="")) # E: Argument 1 to "use_int_callable" has incompatible type "partial[str]"; expected "Callable[[int], int]" \
# N: "partial[str].__call__" has type "Callable[[VarArg(Any), KwArg(Any)], str]"
[builtins fixtures/tuple.pyi]

[case testLruCacheBasicValidation]
from functools import lru_cache

@lru_cache
def f(v: str, at: int) -> str:
return v

f() # E: Missing positional arguments "v", "at" in call to "f"
f("abc") # E: Missing positional argument "at" in call to "f"
f("abc", 123) # OK
f("abc", at=123) # OK
f("abc", at="wrong_type") # E: Argument "at" to "f" has incompatible type "str"; expected "int"
[builtins fixtures/dict.pyi]

[case testLruCacheWithReturnType]
from functools import lru_cache

@lru_cache
def multiply(x: int, y: int) -> int:
return 42

reveal_type(multiply) # N: Revealed type is "def (x: builtins.int, y: builtins.int) -> builtins.int"
reveal_type(multiply(2, 3)) # N: Revealed type is "builtins.int"
multiply("a", 3) # E: Argument 1 to "multiply" has incompatible type "str"; expected "int"
multiply(2, "b") # E: Argument 2 to "multiply" has incompatible type "str"; expected "int"
multiply(2) # E: Missing positional argument "y" in call to "multiply"
multiply(1, 2, 3) # E: Too many arguments for "multiply"
[builtins fixtures/dict.pyi]

[case testLruCacheWithOptionalArgs]
from functools import lru_cache

@lru_cache
def greet(name: str, greeting: str = "Hello") -> str:
return "result"

greet("World") # OK
greet("World", "Hi") # OK
greet("World", greeting="Hi") # OK
greet() # E: Missing positional argument "name" in call to "greet"
greet(123) # E: Argument 1 to "greet" has incompatible type "int"; expected "str"
greet("World", 123) # E: Argument 2 to "greet" has incompatible type "int"; expected "str"
[builtins fixtures/dict.pyi]

[case testLruCacheGenericFunction]
from functools import lru_cache
from typing import TypeVar

T = TypeVar('T')

@lru_cache
def identity(x: T) -> T:
return x

reveal_type(identity(42)) # N: Revealed type is "builtins.int"
reveal_type(identity("hello")) # N: Revealed type is "builtins.str"
identity() # E: Missing positional argument "x" in call to "identity"
[builtins fixtures/dict.pyi]

[case testLruCacheWithParentheses]
from functools import lru_cache

@lru_cache()
def f(v: str, at: int) -> str:
return v

f() # E: Missing positional arguments "v", "at" in call to "f"
f("abc") # E: Missing positional argument "at" in call to "f"
f("abc", 123) # OK
f("abc", at=123) # OK
f("abc", at="wrong_type") # E: Argument "at" to "f" has incompatible type "str"; expected "int"
[builtins fixtures/dict.pyi]

[case testLruCacheWithMaxsize]
from functools import lru_cache

@lru_cache(maxsize=128)
def g(v: str, at: int) -> str:
return v

g() # E: Missing positional arguments "v", "at" in call to "g"
g("abc") # E: Missing positional argument "at" in call to "g"
g("abc", 123) # OK
g("abc", at=123) # OK
g("abc", at="wrong_type") # E: Argument "at" to "g" has incompatible type "str"; expected "int"
[builtins fixtures/dict.pyi]

[case testLruCacheGenericWithParameters]
from functools import lru_cache
from typing import TypeVar

T = TypeVar('T')

@lru_cache()
def identity_empty(x: T) -> T:
return x

@lru_cache(maxsize=128)
def identity_maxsize(x: T) -> T:
return x

reveal_type(identity_empty(42)) # N: Revealed type is "builtins.int"
reveal_type(identity_maxsize("hello")) # N: Revealed type is "builtins.str"
identity_empty() # E: Missing positional argument "x" in call to "identity_empty"
identity_maxsize() # E: Missing positional argument "x" in call to "identity_maxsize"
[builtins fixtures/dict.pyi]

[case testLruCacheMaxsizeNone]
from functools import lru_cache

@lru_cache(maxsize=None)
def unlimited_cache(x: int, y: str) -> str:
return y

unlimited_cache(42, "test") # OK
unlimited_cache() # E: Missing positional arguments "x", "y" in call to "unlimited_cache"
unlimited_cache(42) # E: Missing positional argument "y" in call to "unlimited_cache"
unlimited_cache("wrong", "test") # E: Argument 1 to "unlimited_cache" has incompatible type "str"; expected "int"
[builtins fixtures/dict.pyi]

[case testLruCacheMaxsizeZero]
from functools import lru_cache

@lru_cache(maxsize=0)
def no_cache(value: str) -> str:
return value

no_cache("hello") # OK
no_cache() # E: Missing positional argument "value" in call to "no_cache"
no_cache(123) # E: Argument 1 to "no_cache" has incompatible type "int"; expected "str"
[builtins fixtures/dict.pyi]
8 changes: 8 additions & 0 deletions test-data/unit/lib-stub/functools.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,11 @@ class cached_property(Generic[_T]):
class partial(Generic[_T]):
def __new__(cls, __func: Callable[..., _T], *args: Any, **kwargs: Any) -> Self: ...
def __call__(__self, *args: Any, **kwargs: Any) -> _T: ...

class _lru_cache_wrapper(Generic[_T]):
def __call__(__self, *args: Any, **kwargs: Any) -> _T: ...

@overload
def lru_cache(maxsize: int | None = 128, typed: bool = False) -> Callable[[Callable[..., _T]], _lru_cache_wrapper[_T]]: ...
@overload
def lru_cache(__func: Callable[..., _T]) -> _lru_cache_wrapper[_T]: ...