Skip to content

Commit d03a3b3

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent c502b02 commit d03a3b3

File tree

1 file changed

+31
-30
lines changed

1 file changed

+31
-30
lines changed

mypy/plugins/functools.py

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -400,27 +400,28 @@ def functools_lru_cache_callback(ctx: mypy.plugin.FunctionContext) -> Type:
400400
"""Infer a more precise return type for functools.lru_cache decorator"""
401401
if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals
402402
return ctx.default_return_type
403-
403+
404404
# Only handle the simple case: @lru_cache (without parentheses)
405405
# where a function is passed directly as the first argument
406-
if (len(ctx.arg_types) >= 1 and
407-
len(ctx.arg_types[0]) == 1 and
408-
len(ctx.arg_types) <= 2): # Ensure we don't have extra args indicating parameterized call
409-
406+
if (
407+
len(ctx.arg_types) >= 1 and len(ctx.arg_types[0]) == 1 and len(ctx.arg_types) <= 2
408+
): # Ensure we don't have extra args indicating parameterized call
409+
410410
first_arg_type = ctx.arg_types[0][0]
411-
411+
412412
# Explicitly check that this is NOT a literal or other non-function type
413-
from mypy.types import LiteralType, Instance
413+
from mypy.types import Instance, LiteralType
414+
414415
if isinstance(first_arg_type, (LiteralType, Instance)):
415416
# This is likely maxsize=128 or similar - let MyPy handle it
416417
return ctx.default_return_type
417-
418+
418419
# Try to extract callable type
419420
fn_type = ctx.api.extract_callable_type(first_arg_type, ctx=ctx.default_return_type)
420421
if fn_type is not None:
421422
# This is the @lru_cache case (function passed directly)
422423
return fn_type
423-
424+
424425
# For all parameterized cases, don't interfere
425426
return ctx.default_return_type
426427

@@ -429,17 +430,17 @@ def lru_cache_wrapper_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
429430
"""Handle calls to functools._lru_cache_wrapper objects to provide parameter validation"""
430431
if not isinstance(ctx.api, mypy.checker.TypeChecker):
431432
return ctx.default_return_type
432-
433+
433434
# Try to find the original function signature using AST/symbol table analysis
434435
original_signature = _find_original_function_signature(ctx)
435-
436+
436437
if original_signature is not None:
437438
# Validate the call against the original function signature
438439
actual_args = []
439440
actual_arg_kinds = []
440441
actual_arg_names = []
441442
seen_args = set()
442-
443+
443444
for i, param in enumerate(ctx.args):
444445
for j, a in enumerate(param):
445446
if a in seen_args:
@@ -458,55 +459,55 @@ def lru_cache_wrapper_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
458459
context=ctx.context,
459460
)
460461
return result
461-
462+
462463
return ctx.default_return_type
463464

464465

465466
def _find_original_function_signature(ctx: mypy.plugin.MethodContext) -> CallableType | None:
466467
"""
467468
Attempt to find the original function signature from the call context.
468-
469+
469470
Returns the CallableType of the original function if found, None otherwise.
470471
This function safely traverses the AST structure to locate the original
471472
function signature that was decorated with @lru_cache.
472473
"""
473474
from mypy.nodes import CallExpr, Decorator, NameExpr
474-
475+
475476
# Ensure we have the required context structure
476477
if not isinstance(ctx.context, CallExpr):
477478
return None
478-
479+
479480
callee = ctx.context.callee
480481
if not isinstance(callee, NameExpr) or not callee.name:
481482
return None
482-
483+
483484
func_name = callee.name
484-
485+
485486
# Safely access the API globals
486-
if not hasattr(ctx.api, 'globals') or not isinstance(ctx.api.globals, dict):
487+
if not hasattr(ctx.api, "globals") or not isinstance(ctx.api.globals, dict):
487488
return None
488-
489+
489490
if func_name not in ctx.api.globals:
490491
return None
491-
492+
492493
symbol = ctx.api.globals[func_name]
493-
494+
494495
# Validate symbol structure before accessing node
495-
if not hasattr(symbol, 'node') or symbol.node is None:
496+
if not hasattr(symbol, "node") or symbol.node is None:
496497
return None
497-
498+
498499
# Check if this is a decorator node containing our function
499500
if isinstance(symbol.node, Decorator):
500501
decorator_node = symbol.node
501-
502+
502503
# Safely access the decorated function
503-
if not hasattr(decorator_node, 'func') or decorator_node.func is None:
504+
if not hasattr(decorator_node, "func") or decorator_node.func is None:
504505
return None
505-
506+
506507
func_def = decorator_node.func
507-
508+
508509
# Verify we have a callable type
509-
if hasattr(func_def, 'type') and isinstance(func_def.type, CallableType):
510+
if hasattr(func_def, "type") and isinstance(func_def.type, CallableType):
510511
return func_def.type
511-
512+
512513
return None

0 commit comments

Comments
 (0)