|
41 | 41 | _ORDERING_METHODS: Final = {"__lt__", "__le__", "__gt__", "__ge__"}
|
42 | 42 |
|
43 | 43 | PARTIAL: Final = "functools.partial"
|
| 44 | +LRU_CACHE: Final = "functools.lru_cache" |
44 | 45 |
|
45 | 46 |
|
46 | 47 | class _MethodInfo(NamedTuple):
|
@@ -393,3 +394,119 @@ def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
|
393 | 394 | ctx.api.expr_checker.msg.too_few_arguments(partial_type, ctx.context, actual_arg_names)
|
394 | 395 |
|
395 | 396 | return result
|
| 397 | + |
| 398 | + |
| 399 | +def functools_lru_cache_callback(ctx: mypy.plugin.FunctionContext) -> Type: |
| 400 | + """Infer a more precise return type for functools.lru_cache decorator""" |
| 401 | + if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals |
| 402 | + return ctx.default_return_type |
| 403 | + |
| 404 | + # Only handle the simple case: @lru_cache (without parentheses) |
| 405 | + # where a function is passed directly as the first argument |
| 406 | + if (len(ctx.arg_types) >= 1 and |
| 407 | + len(ctx.arg_types[0]) == 1 and |
| 408 | + len(ctx.arg_types) <= 2): # Ensure we don't have extra args indicating parameterized call |
| 409 | + |
| 410 | + first_arg_type = ctx.arg_types[0][0] |
| 411 | + |
| 412 | + # Explicitly check that this is NOT a literal or other non-function type |
| 413 | + from mypy.types import LiteralType, Instance |
| 414 | + if isinstance(first_arg_type, (LiteralType, Instance)): |
| 415 | + # This is likely maxsize=128 or similar - let MyPy handle it |
| 416 | + return ctx.default_return_type |
| 417 | + |
| 418 | + # Try to extract callable type |
| 419 | + fn_type = ctx.api.extract_callable_type(first_arg_type, ctx=ctx.default_return_type) |
| 420 | + if fn_type is not None: |
| 421 | + # This is the @lru_cache case (function passed directly) |
| 422 | + return fn_type |
| 423 | + |
| 424 | + # For all parameterized cases, don't interfere |
| 425 | + return ctx.default_return_type |
| 426 | + |
| 427 | + |
| 428 | +def lru_cache_wrapper_call_callback(ctx: mypy.plugin.MethodContext) -> Type: |
| 429 | + """Handle calls to functools._lru_cache_wrapper objects to provide parameter validation""" |
| 430 | + if not isinstance(ctx.api, mypy.checker.TypeChecker): |
| 431 | + return ctx.default_return_type |
| 432 | + |
| 433 | + # Try to find the original function signature using AST/symbol table analysis |
| 434 | + original_signature = _find_original_function_signature(ctx) |
| 435 | + |
| 436 | + if original_signature is not None: |
| 437 | + # Validate the call against the original function signature |
| 438 | + actual_args = [] |
| 439 | + actual_arg_kinds = [] |
| 440 | + actual_arg_names = [] |
| 441 | + seen_args = set() |
| 442 | + |
| 443 | + for i, param in enumerate(ctx.args): |
| 444 | + for j, a in enumerate(param): |
| 445 | + if a in seen_args: |
| 446 | + continue |
| 447 | + seen_args.add(a) |
| 448 | + actual_args.append(a) |
| 449 | + actual_arg_kinds.append(ctx.arg_kinds[i][j]) |
| 450 | + actual_arg_names.append(ctx.arg_names[i][j]) |
| 451 | + |
| 452 | + # Check the call against the original signature |
| 453 | + result, _ = ctx.api.expr_checker.check_call( |
| 454 | + callee=original_signature, |
| 455 | + args=actual_args, |
| 456 | + arg_kinds=actual_arg_kinds, |
| 457 | + arg_names=actual_arg_names, |
| 458 | + context=ctx.context, |
| 459 | + ) |
| 460 | + return result |
| 461 | + |
| 462 | + return ctx.default_return_type |
| 463 | + |
| 464 | + |
| 465 | +def _find_original_function_signature(ctx: mypy.plugin.MethodContext) -> CallableType | None: |
| 466 | + """ |
| 467 | + Attempt to find the original function signature from the call context. |
| 468 | + |
| 469 | + Returns the CallableType of the original function if found, None otherwise. |
| 470 | + This function safely traverses the AST structure to locate the original |
| 471 | + function signature that was decorated with @lru_cache. |
| 472 | + """ |
| 473 | + from mypy.nodes import CallExpr, Decorator, NameExpr |
| 474 | + |
| 475 | + # Ensure we have the required context structure |
| 476 | + if not isinstance(ctx.context, CallExpr): |
| 477 | + return None |
| 478 | + |
| 479 | + callee = ctx.context.callee |
| 480 | + if not isinstance(callee, NameExpr) or not callee.name: |
| 481 | + return None |
| 482 | + |
| 483 | + func_name = callee.name |
| 484 | + |
| 485 | + # Safely access the API globals |
| 486 | + if not hasattr(ctx.api, 'globals') or not isinstance(ctx.api.globals, dict): |
| 487 | + return None |
| 488 | + |
| 489 | + if func_name not in ctx.api.globals: |
| 490 | + return None |
| 491 | + |
| 492 | + symbol = ctx.api.globals[func_name] |
| 493 | + |
| 494 | + # Validate symbol structure before accessing node |
| 495 | + if not hasattr(symbol, 'node') or symbol.node is None: |
| 496 | + return None |
| 497 | + |
| 498 | + # Check if this is a decorator node containing our function |
| 499 | + if isinstance(symbol.node, Decorator): |
| 500 | + decorator_node = symbol.node |
| 501 | + |
| 502 | + # Safely access the decorated function |
| 503 | + if not hasattr(decorator_node, 'func') or decorator_node.func is None: |
| 504 | + return None |
| 505 | + |
| 506 | + func_def = decorator_node.func |
| 507 | + |
| 508 | + # Verify we have a callable type |
| 509 | + if hasattr(func_def, 'type') and isinstance(func_def.type, CallableType): |
| 510 | + return func_def.type |
| 511 | + |
| 512 | + return None |
0 commit comments