diff --git a/codebase_rag/graph_updater.py b/codebase_rag/graph_updater.py index 6a7eacbaa..86a41aa01 100644 --- a/codebase_rag/graph_updater.py +++ b/codebase_rag/graph_updater.py @@ -387,12 +387,62 @@ def _collect_eligible_files(self) -> list[Path]: eligible.append(filepath) return eligible + def _should_force_full_reindex( + self, force: bool, old_hashes: FileHashCache + ) -> bool: + if force or not old_hashes: + return False + + fetch_all = getattr(self.ingestor, "fetch_all", None) + if not callable(fetch_all): + return False + + try: + results = fetch_all( + ( + "MATCH (n) " + "WHERE toString(n.qualified_name) STARTS WITH $prefix " + "RETURN count(n) AS c" + ), + {"prefix": f"{self.project_name}."}, + ) + except Exception as e: + logger.debug( + "Incremental reindex graph-state probe failed for {name}: {error}", + name=self.project_name, + error=e, + ) + return False + + if not results: + logger.info( + "No graph-state probe results for {name}; forcing full reindex", + name=self.project_name, + ) + return True + + symbol_count = results[0].get("c", 0) + if not isinstance(symbol_count, int): + return False + + if symbol_count == 0: + logger.info( + "No existing graph symbols found for {name}; ignoring hash cache and forcing full reindex", + name=self.project_name, + ) + return True + + return False + def _process_files(self, force: bool = False) -> None: cache_path = self.repo_path / cs.HASH_CACHE_FILENAME old_hashes = _load_hash_cache(cache_path) if not force else {} if force: logger.info(ls.INCREMENTAL_FORCE) + if self._should_force_full_reindex(force, old_hashes): + old_hashes = {} + eligible_files = self._collect_eligible_files() new_hashes: FileHashCache = {} skipped_count = 0 diff --git a/codebase_rag/parsers/call_resolver.py b/codebase_rag/parsers/call_resolver.py index 993647759..eac60f58b 100644 --- a/codebase_rag/parsers/call_resolver.py +++ b/codebase_rag/parsers/call_resolver.py @@ -15,6 +15,7 @@ _SEPARATOR_PATTERN = re.compile(r"[.:]|::") 
_CHAINED_METHOD_PATTERN = re.compile(r"\.([^.()]+)$") +_JS_INSTANCE_PREFIXES = {cs.KEYWORD_SELF, "this"} class CallResolver: @@ -52,6 +53,15 @@ def _try_resolve_method( method_qn = f"{class_qn}{separator}{method_name}" if method_qn in self.function_registry: return self.function_registry[method_qn], method_qn + + class_name = class_qn.split(cs.SEPARATOR_DOT)[-1] + suffix_matches = self.function_registry.find_ending_with( + f"{class_name}{separator}{method_name}" + ) + if len(suffix_matches) == 1: + matched_qn = suffix_matches[0] + return self.function_registry[matched_qn], matched_qn + return self._resolve_inherited_method(class_qn, method_name) def resolve_function_call( @@ -71,14 +81,17 @@ def resolve_function_call( return self._resolve_chained_call(call_name, module_qn, local_var_types) if result := self._try_resolve_via_imports( - call_name, module_qn, local_var_types + call_name, module_qn, local_var_types, class_context ): return result - if result := self._try_resolve_same_module(call_name, module_qn): - return result + if not self._has_separator(call_name): + if result := self._try_resolve_same_module(call_name, module_qn): + return result + return self._try_resolve_via_trie(call_name, module_qn) - return self._try_resolve_via_trie(call_name, module_qn) + logger.debug(ls.CALL_UNRESOLVED, call_name=call_name) + return None def _try_resolve_iife( self, call_name: str, module_qn: str @@ -107,21 +120,21 @@ def _try_resolve_via_imports( call_name: str, module_qn: str, local_var_types: dict[str, str] | None, + class_context: str | None = None, ) -> tuple[str, str] | None: - if module_qn not in self.import_processor.import_mapping: - return None - - import_map = self.import_processor.import_mapping[module_qn] + import_map = self.import_processor.import_mapping.get(module_qn, {}) if result := self._try_resolve_direct_import(call_name, import_map): return result if result := self._try_resolve_qualified_call( - call_name, import_map, module_qn, local_var_types + 
call_name, import_map, module_qn, local_var_types, class_context ): return result - return self._try_resolve_wildcard_imports(call_name, import_map) + if import_map: + return self._try_resolve_wildcard_imports(call_name, import_map) + return None def _try_resolve_direct_import( self, call_name: str, import_map: dict[str, str] @@ -140,6 +153,7 @@ def _try_resolve_qualified_call( import_map: dict[str, str], module_qn: str, local_var_types: dict[str, str] | None, + class_context: str | None = None, ) -> tuple[str, str] | None: if not self._has_separator(call_name): return None @@ -149,11 +163,17 @@ def _try_resolve_qualified_call( if len(parts) == 2: if result := self._resolve_two_part_call( - parts, call_name, separator, import_map, module_qn, local_var_types + parts, + call_name, + separator, + import_map, + module_qn, + local_var_types, + class_context, ): return result - if len(parts) >= 3 and parts[0] == cs.KEYWORD_SELF: + if len(parts) >= 3 and parts[0] in _JS_INSTANCE_PREFIXES: return self._resolve_self_attribute_call( parts, call_name, import_map, module_qn, local_var_types ) @@ -235,9 +255,14 @@ def _resolve_two_part_call( import_map: dict[str, str], module_qn: str, local_var_types: dict[str, str] | None, + class_context: str | None = None, ) -> tuple[str, str] | None: object_name, method_name = parts + if object_name in _JS_INSTANCE_PREFIXES and class_context: + if result := self._try_resolve_method(class_context, method_name, separator): + return result + if result := self._try_resolve_via_local_type( object_name, method_name, @@ -254,7 +279,7 @@ def _resolve_two_part_call( ): return result - return self._try_resolve_module_method(method_name, call_name, module_qn) + return None def _try_resolve_via_local_type( self, @@ -401,28 +426,15 @@ def _resolve_self_attribute_call( if class_qn := self._resolve_class_qn_from_type( var_type, import_map, module_qn ): - method_qn = f"{class_qn}.{method_name}" - if method_qn in self.function_registry: + if resolved_method 
:= self._try_resolve_method(class_qn, method_name): logger.debug( ls.CALL_INSTANCE_ATTR, call_name=call_name, - method_qn=method_qn, + method_qn=resolved_method[1], attr_ref=attribute_ref, var_type=var_type, ) - return self.function_registry[method_qn], method_qn - - if inherited_method := self._resolve_inherited_method( - class_qn, method_name - ): - logger.debug( - ls.CALL_INSTANCE_ATTR_INHERITED, - call_name=call_name, - method_qn=inherited_method[1], - attr_ref=attribute_ref, - var_type=var_type, - ) - return inherited_method + return resolved_method return None diff --git a/codebase_rag/parsers/js_ts/type_inference.py b/codebase_rag/parsers/js_ts/type_inference.py index 29a435c77..d471a93c9 100644 --- a/codebase_rag/parsers/js_ts/type_inference.py +++ b/codebase_rag/parsers/js_ts/type_inference.py @@ -35,8 +35,14 @@ def build_local_variable_type_map( ) -> dict[str, str]: local_var_types: dict[str, str] = {} - stack: list[ASTNode] = [caller_node] + if class_node := self._find_enclosing_class_node(caller_node): + self._collect_constructor_injected_types( + class_node, module_qn, local_var_types + ) + + self._collect_parameter_types(caller_node, module_qn, local_var_types) + stack: list[ASTNode] = [caller_node] declarator_count = 0 while stack: @@ -59,7 +65,7 @@ def build_local_variable_type_map( ) if var_type := self._infer_js_variable_type_from_value( - value_node, module_qn + value_node, module_qn, local_var_types ): local_var_types[var_name] = var_type logger.debug( @@ -79,11 +85,138 @@ def build_local_variable_type_map( ) return local_var_types + def _find_enclosing_class_node(self, node: ASTNode) -> ASTNode | None: + current = node + while current is not None: + if current.type == cs.TS_CLASS_DECLARATION: + return current + current = current.parent + return None + + def _collect_constructor_injected_types( + self, + class_node: ASTNode, + module_qn: str, + local_var_types: dict[str, str], + ) -> None: + body_node = 
def _collect_constructor_parameter_type(
    self,
    param_node: ASTNode,
    module_qn: str,
    local_var_types: dict[str, str],
) -> None:
    """Record the type of a constructor parameter that is also a property.

    TypeScript turns a constructor parameter into a class property when it
    is prefixed with ``public``/``private``/``protected`` or ``readonly``.
    For such a parameter with a type annotation, the type is stored under
    both the bare name and ``this.<name>``, so later lookups for either
    spelling find it. Plain parameters are ignored.

    Args:
        param_node: A child of the constructor's parameter list.
        module_qn: Qualified name of the module, used to resolve the type.
        local_var_types: Mapping updated in place with the new entries.
    """
    if param_node.type not in {
        "required_parameter",
        "optional_parameter",
        cs.TS_FORMAL_PARAMETER,
    }:
        return

    # A lone ``readonly`` also declares a parameter property, so checking
    # only the accessibility modifier would miss it.
    # NOTE(review): assumes the grammar exposes the keyword as a child node
    # of type "readonly" - confirm against the tree-sitter-typescript
    # version in use.
    property_markers = ("accessibility_modifier", "readonly")
    is_parameter_property = any(
        child.type in property_markers for child in param_node.children
    )
    if not is_parameter_property:
        return

    param_name = self._extract_parameter_name(param_node)
    if not param_name:
        return

    if not (param_type := self._extract_type_annotation_name(param_node)):
        return

    # Fall back to the raw annotation text when the class cannot be resolved.
    resolved_type = self._resolve_js_class_name(param_type, module_qn) or param_type
    local_var_types[param_name] = resolved_type
    local_var_types[f"this.{param_name}"] = resolved_type
self._resolve_js_class_name(param_type, module_qn) or param_type + local_var_types[param_name] = resolved_type + + def _extract_parameter_name(self, param_node: ASTNode) -> str | None: + identifier_node = next( + (child for child in param_node.children if child.type == cs.TS_IDENTIFIER), + None, + ) + return safe_decode_text(identifier_node) if identifier_node is not None else None + + def _extract_type_annotation_name(self, node: ASTNode) -> str | None: + type_node = next( + (child for child in node.children if child.type == "type_annotation"), + None, + ) + if type_node is None or type_node.text is None: + return None + + type_text = safe_decode_text(type_node) + if not type_text: + return None + + return type_text.lstrip(":").strip() + def _infer_js_variable_type_from_value( - self, value_node: ASTNode, module_qn: str + self, + value_node: ASTNode, + module_qn: str, + local_var_types: dict[str, str], ) -> str | None: logger.debug(ls.JS_INFER_VALUE_NODE, node_type=value_node.type) + if value_node.type == cs.TS_MEMBER_EXPRESSION: + expr_text = safe_decode_text(value_node) + if expr_text and expr_text in local_var_types: + return local_var_types[expr_text] + if value_node.type == cs.TS_NEW_EXPRESSION: if class_name := ut.extract_constructor_name(value_node): class_qn = self._resolve_js_class_name(class_name, module_qn) diff --git a/codebase_rag/tests/test_call_resolver.py b/codebase_rag/tests/test_call_resolver.py index 0a23ae636..7bad0bb6b 100644 --- a/codebase_rag/tests/test_call_resolver.py +++ b/codebase_rag/tests/test_call_resolver.py @@ -1112,3 +1112,62 @@ def test_matches_deeply_chained(self) -> None: match = _CHAINED_METHOD_PATTERN.search("a.b().c().final_method") assert match is not None assert match[1] == "final_method" + + +class TestJsTsMemberResolution: + def test_resolves_injected_service_member_call_from_local_var_types( + self, call_resolver: CallResolver + ) -> None: + call_resolver.function_registry[ + 
def test_empty_graph_ignores_hash_cache_and_reindexes_all_files(
    self, py_project: Path, mock_ingestor: MagicMock
) -> None:
    """An empty graph must make the second run reprocess every file.

    The first run is a normal indexing pass over the project. The second
    run has the ingestor's probe report zero symbols and must then hand
    both modules to ``_process_single_file`` again.
    """
    parsers, queries = load_parsers()
    updater_kwargs = {
        "ingestor": mock_ingestor,
        "repo_path": py_project,
        "parsers": parsers,
        "queries": queries,
    }

    # First pass: a normal indexing run.
    GraphUpdater(**updater_kwargs).run()

    # Simulate a wiped graph: the probe now reports zero matching nodes.
    mock_ingestor.reset_mock()
    mock_ingestor.fetch_all.return_value = [{"c": 0}]

    second_updater = GraphUpdater(**updater_kwargs)
    with patch.object(
        second_updater,
        "_process_single_file",
        wraps=second_updater._process_single_file,
    ) as spy:
        second_updater.run()

    processed_paths = {call.args[0] for call in spy.call_args_list}
    for module_name in ("module_a.py", "module_b.py"):
        assert py_project / module_name in processed_paths