Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 0 additions & 50 deletions codebase_rag/graph_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,62 +390,12 @@ def _collect_eligible_files(self) -> list[Path]:
eligible.append(filepath)
return eligible

def _should_force_full_reindex(
self, force: bool, old_hashes: FileHashCache
) -> bool:
if force or not old_hashes:
return False

fetch_all = getattr(self.ingestor, "fetch_all", None)
if not callable(fetch_all):
return False

try:
results = fetch_all(
(
"MATCH (n) "
"WHERE toString(n.qualified_name) STARTS WITH $prefix "
"RETURN count(n) AS c"
),
{"prefix": f"{self.project_name}."},
)
except Exception as e:
logger.debug(
"Incremental reindex graph-state probe failed for {name}: {error}",
name=self.project_name,
error=e,
)
return False

if not results:
logger.info(
"No graph-state probe results for {name}; forcing full reindex",
name=self.project_name,
)
return True

symbol_count = results[0].get("c", 0)
if not isinstance(symbol_count, int):
return False

if symbol_count == 0:
logger.info(
"No existing graph symbols found for {name}; ignoring hash cache and forcing full reindex",
name=self.project_name,
)
return True

return False

def _process_files(self, force: bool = False) -> None:
cache_path = self.repo_path / cs.HASH_CACHE_FILENAME
old_hashes = _load_hash_cache(cache_path) if not force else {}
if force:
logger.info(ls.INCREMENTAL_FORCE)

if self._should_force_full_reindex(force, old_hashes):
old_hashes = {}

eligible_files = self._collect_eligible_files()
new_hashes: FileHashCache = {}
skipped_count = 0
Expand Down
70 changes: 29 additions & 41 deletions codebase_rag/parsers/call_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

_SEPARATOR_PATTERN = re.compile(r"[.:]|::")
_CHAINED_METHOD_PATTERN = re.compile(r"\.([^.()]+)$")
_JS_INSTANCE_PREFIXES = {cs.KEYWORD_SELF, "this"}


class CallResolver:
Expand Down Expand Up @@ -53,15 +52,6 @@ def _try_resolve_method(
method_qn = f"{class_qn}{separator}{method_name}"
if method_qn in self.function_registry:
return self.function_registry[method_qn], method_qn

class_name = class_qn.split(cs.SEPARATOR_DOT)[-1]
suffix_matches = self.function_registry.find_ending_with(
f"{class_name}{separator}{method_name}"
)
if len(suffix_matches) == 1:
matched_qn = suffix_matches[0]
return self.function_registry[matched_qn], matched_qn

return self._resolve_inherited_method(class_qn, method_name)

def resolve_function_call(
Expand All @@ -81,17 +71,14 @@ def resolve_function_call(
return self._resolve_chained_call(call_name, module_qn, local_var_types)

if result := self._try_resolve_via_imports(
call_name, module_qn, local_var_types, class_context
call_name, module_qn, local_var_types
):
return result

if not self._has_separator(call_name):
if result := self._try_resolve_same_module(call_name, module_qn):
return result
return self._try_resolve_via_trie(call_name, module_qn)
if result := self._try_resolve_same_module(call_name, module_qn):
return result

logger.debug(ls.CALL_UNRESOLVED, call_name=call_name)
return None
return self._try_resolve_via_trie(call_name, module_qn)

def _try_resolve_iife(
self, call_name: str, module_qn: str
Expand Down Expand Up @@ -120,21 +107,21 @@ def _try_resolve_via_imports(
call_name: str,
module_qn: str,
local_var_types: dict[str, str] | None,
class_context: str | None = None,
) -> tuple[str, str] | None:
import_map = self.import_processor.import_mapping.get(module_qn, {})
if module_qn not in self.import_processor.import_mapping:
return None

import_map = self.import_processor.import_mapping[module_qn]

if result := self._try_resolve_direct_import(call_name, import_map):
return result

if result := self._try_resolve_qualified_call(
call_name, import_map, module_qn, local_var_types, class_context
call_name, import_map, module_qn, local_var_types
):
return result

if import_map:
return self._try_resolve_wildcard_imports(call_name, import_map)
return None
return self._try_resolve_wildcard_imports(call_name, import_map)

def _try_resolve_direct_import(
self, call_name: str, import_map: dict[str, str]
Expand All @@ -153,7 +140,6 @@ def _try_resolve_qualified_call(
import_map: dict[str, str],
module_qn: str,
local_var_types: dict[str, str] | None,
class_context: str | None = None,
) -> tuple[str, str] | None:
if not self._has_separator(call_name):
return None
Expand All @@ -163,17 +149,11 @@ def _try_resolve_qualified_call(

if len(parts) == 2:
if result := self._resolve_two_part_call(
parts,
call_name,
separator,
import_map,
module_qn,
local_var_types,
class_context,
parts, call_name, separator, import_map, module_qn, local_var_types
):
return result

if len(parts) >= 3 and parts[0] in _JS_INSTANCE_PREFIXES:
if len(parts) >= 3 and parts[0] == cs.KEYWORD_SELF:
return self._resolve_self_attribute_call(
parts, call_name, import_map, module_qn, local_var_types
)
Expand Down Expand Up @@ -255,14 +235,9 @@ def _resolve_two_part_call(
import_map: dict[str, str],
module_qn: str,
local_var_types: dict[str, str] | None,
class_context: str | None = None,
) -> tuple[str, str] | None:
object_name, method_name = parts

if object_name in _JS_INSTANCE_PREFIXES and class_context:
if result := self._try_resolve_method(class_context, method_name, separator):
return result

if result := self._try_resolve_via_local_type(
object_name,
method_name,
Expand All @@ -279,7 +254,7 @@ def _resolve_two_part_call(
):
return result

return None
return self._try_resolve_module_method(method_name, call_name, module_qn)

def _try_resolve_via_local_type(
self,
Expand Down Expand Up @@ -426,15 +401,28 @@ def _resolve_self_attribute_call(
if class_qn := self._resolve_class_qn_from_type(
var_type, import_map, module_qn
):
if resolved_method := self._try_resolve_method(class_qn, method_name):
method_qn = f"{class_qn}.{method_name}"
if method_qn in self.function_registry:
logger.debug(
ls.CALL_INSTANCE_ATTR,
call_name=call_name,
method_qn=resolved_method[1],
method_qn=method_qn,
attr_ref=attribute_ref,
var_type=var_type,
)
return resolved_method
return self.function_registry[method_qn], method_qn

if inherited_method := self._resolve_inherited_method(
class_qn, method_name
):
logger.debug(
ls.CALL_INSTANCE_ATTR_INHERITED,
call_name=call_name,
method_qn=inherited_method[1],
attr_ref=attribute_ref,
var_type=var_type,
)
return inherited_method

return None

Expand Down
139 changes: 3 additions & 136 deletions codebase_rag/parsers/js_ts/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,8 @@ def build_local_variable_type_map(
) -> dict[str, str]:
local_var_types: dict[str, str] = {}

if class_node := self._find_enclosing_class_node(caller_node):
self._collect_constructor_injected_types(
class_node, module_qn, local_var_types
)

self._collect_parameter_types(caller_node, module_qn, local_var_types)

stack: list[ASTNode] = [caller_node]

declarator_count = 0

while stack:
Expand All @@ -65,7 +59,7 @@ def build_local_variable_type_map(
)

if var_type := self._infer_js_variable_type_from_value(
value_node, module_qn, local_var_types
value_node, module_qn
):
local_var_types[var_name] = var_type
logger.debug(
Expand All @@ -85,138 +79,11 @@ def build_local_variable_type_map(
)
return local_var_types

def _find_enclosing_class_node(self, node: ASTNode) -> ASTNode | None:
current = node
while current is not None:
if current.type == cs.TS_CLASS_DECLARATION:
return current
current = current.parent
return None

def _collect_constructor_injected_types(
self,
class_node: ASTNode,
module_qn: str,
local_var_types: dict[str, str],
) -> None:
body_node = class_node.child_by_field_name(cs.FIELD_BODY)
if body_node is None:
return

for child in body_node.children:
if child.type != cs.TS_METHOD_DEFINITION:
continue

name_node = child.child_by_field_name(cs.FIELD_NAME)
if (
name_node is None
or name_node.text is None
or safe_decode_text(name_node) != cs.KEYWORD_CONSTRUCTOR
):
continue

params_node = child.child_by_field_name(cs.TS_FIELD_PARAMETERS)
if params_node is None:
return

for param in params_node.children:
self._collect_constructor_parameter_type(
param, module_qn, local_var_types
)
return

def _collect_constructor_parameter_type(
self,
param_node: ASTNode,
module_qn: str,
local_var_types: dict[str, str],
) -> None:
if param_node.type not in {
"required_parameter",
"optional_parameter",
cs.TS_FORMAL_PARAMETER,
}:
return

has_accessibility_modifier = any(
child.type == "accessibility_modifier" for child in param_node.children
)
if not has_accessibility_modifier:
return

param_name = self._extract_parameter_name(param_node)
if not param_name:
return

if not (param_type := self._extract_type_annotation_name(param_node)):
return

resolved_type = self._resolve_js_class_name(param_type, module_qn) or param_type
local_var_types[param_name] = resolved_type
local_var_types[f"this.{param_name}"] = resolved_type

def _collect_parameter_types(
self,
caller_node: ASTNode,
module_qn: str,
local_var_types: dict[str, str],
) -> None:
params_node = caller_node.child_by_field_name(cs.TS_FIELD_PARAMETERS)
if params_node is None:
return

for param in params_node.children:
if param.type not in {
"required_parameter",
"optional_parameter",
cs.TS_FORMAL_PARAMETER,
}:
continue

param_name = self._extract_parameter_name(param)
if not param_name or param_name in local_var_types:
continue

if not (param_type := self._extract_type_annotation_name(param)):
continue

resolved_type = self._resolve_js_class_name(param_type, module_qn) or param_type
local_var_types[param_name] = resolved_type

def _extract_parameter_name(self, param_node: ASTNode) -> str | None:
identifier_node = next(
(child for child in param_node.children if child.type == cs.TS_IDENTIFIER),
None,
)
return safe_decode_text(identifier_node) if identifier_node is not None else None

def _extract_type_annotation_name(self, node: ASTNode) -> str | None:
type_node = next(
(child for child in node.children if child.type == "type_annotation"),
None,
)
if type_node is None or type_node.text is None:
return None

type_text = safe_decode_text(type_node)
if not type_text:
return None

return type_text.lstrip(":").strip()

def _infer_js_variable_type_from_value(
self,
value_node: ASTNode,
module_qn: str,
local_var_types: dict[str, str],
self, value_node: ASTNode, module_qn: str
) -> str | None:
logger.debug(ls.JS_INFER_VALUE_NODE, node_type=value_node.type)

if value_node.type == cs.TS_MEMBER_EXPRESSION:
expr_text = safe_decode_text(value_node)
if expr_text and expr_text in local_var_types:
return local_var_types[expr_text]

if value_node.type == cs.TS_NEW_EXPRESSION:
if class_name := ut.extract_constructor_name(value_node):
class_qn = self._resolve_js_class_name(class_name, module_qn)
Expand Down
Loading
Loading