@@ -400,27 +400,28 @@ def functools_lru_cache_callback(ctx: mypy.plugin.FunctionContext) -> Type:
400
400
"""Infer a more precise return type for functools.lru_cache decorator"""
401
401
if not isinstance (ctx .api , mypy .checker .TypeChecker ): # use internals
402
402
return ctx .default_return_type
403
-
403
+
404
404
# Only handle the simple case: @lru_cache (without parentheses)
405
405
# where a function is passed directly as the first argument
406
- if (len ( ctx . arg_types ) >= 1 and
407
- len (ctx .arg_types [0 ]) == 1 and
408
- len ( ctx . arg_types ) <= 2 ): # Ensure we don't have extra args indicating parameterized call
409
-
406
+ if (
407
+ len (ctx .arg_types ) >= 1 and len ( ctx . arg_types [0 ]) == 1 and len ( ctx . arg_types ) <= 2
408
+ ): # Ensure we don't have extra args indicating parameterized call
409
+
410
410
first_arg_type = ctx .arg_types [0 ][0 ]
411
-
411
+
412
412
# Explicitly check that this is NOT a literal or other non-function type
413
- from mypy .types import LiteralType , Instance
413
+ from mypy .types import Instance , LiteralType
414
+
414
415
if isinstance (first_arg_type , (LiteralType , Instance )):
415
416
# This is likely maxsize=128 or similar - let MyPy handle it
416
417
return ctx .default_return_type
417
-
418
+
418
419
# Try to extract callable type
419
420
fn_type = ctx .api .extract_callable_type (first_arg_type , ctx = ctx .default_return_type )
420
421
if fn_type is not None :
421
422
# This is the @lru_cache case (function passed directly)
422
423
return fn_type
423
-
424
+
424
425
# For all parameterized cases, don't interfere
425
426
return ctx .default_return_type
426
427
@@ -429,17 +430,17 @@ def lru_cache_wrapper_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
429
430
"""Handle calls to functools._lru_cache_wrapper objects to provide parameter validation"""
430
431
if not isinstance (ctx .api , mypy .checker .TypeChecker ):
431
432
return ctx .default_return_type
432
-
433
+
433
434
# Try to find the original function signature using AST/symbol table analysis
434
435
original_signature = _find_original_function_signature (ctx )
435
-
436
+
436
437
if original_signature is not None :
437
438
# Validate the call against the original function signature
438
439
actual_args = []
439
440
actual_arg_kinds = []
440
441
actual_arg_names = []
441
442
seen_args = set ()
442
-
443
+
443
444
for i , param in enumerate (ctx .args ):
444
445
for j , a in enumerate (param ):
445
446
if a in seen_args :
@@ -458,55 +459,55 @@ def lru_cache_wrapper_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
458
459
context = ctx .context ,
459
460
)
460
461
return result
461
-
462
+
462
463
return ctx .default_return_type
463
464
464
465
465
466
def _find_original_function_signature (ctx : mypy .plugin .MethodContext ) -> CallableType | None :
466
467
"""
467
468
Attempt to find the original function signature from the call context.
468
-
469
+
469
470
Returns the CallableType of the original function if found, None otherwise.
470
471
This function safely traverses the AST structure to locate the original
471
472
function signature that was decorated with @lru_cache.
472
473
"""
473
474
from mypy .nodes import CallExpr , Decorator , NameExpr
474
-
475
+
475
476
# Ensure we have the required context structure
476
477
if not isinstance (ctx .context , CallExpr ):
477
478
return None
478
-
479
+
479
480
callee = ctx .context .callee
480
481
if not isinstance (callee , NameExpr ) or not callee .name :
481
482
return None
482
-
483
+
483
484
func_name = callee .name
484
-
485
+
485
486
# Safely access the API globals
486
- if not hasattr (ctx .api , ' globals' ) or not isinstance (ctx .api .globals , dict ):
487
+ if not hasattr (ctx .api , " globals" ) or not isinstance (ctx .api .globals , dict ):
487
488
return None
488
-
489
+
489
490
if func_name not in ctx .api .globals :
490
491
return None
491
-
492
+
492
493
symbol = ctx .api .globals [func_name ]
493
-
494
+
494
495
# Validate symbol structure before accessing node
495
- if not hasattr (symbol , ' node' ) or symbol .node is None :
496
+ if not hasattr (symbol , " node" ) or symbol .node is None :
496
497
return None
497
-
498
+
498
499
# Check if this is a decorator node containing our function
499
500
if isinstance (symbol .node , Decorator ):
500
501
decorator_node = symbol .node
501
-
502
+
502
503
# Safely access the decorated function
503
- if not hasattr (decorator_node , ' func' ) or decorator_node .func is None :
504
+ if not hasattr (decorator_node , " func" ) or decorator_node .func is None :
504
505
return None
505
-
506
+
506
507
func_def = decorator_node .func
507
-
508
+
508
509
# Verify we have a callable type
509
- if hasattr (func_def , ' type' ) and isinstance (func_def .type , CallableType ):
510
+ if hasattr (func_def , " type" ) and isinstance (func_def .type , CallableType ):
510
511
return func_def .type
511
-
512
+
512
513
return None
0 commit comments