Skip to content

Commit 16f5914

Browse files
committed
fix(graphrag): align llm typing and mind map annotations
What problem does this solve? Enabling runtime type checking with beartype against the current RAGFlow GraphRAG code surfaced several latent mismatches between type annotations and actual runtime behavior. The failures fall into three groups: - GraphRAG components annotate their LLM dependency as the concrete chat model Base, but callers may pass compatible wrappers such as LLMBundle. - MindMapExtractor contains several narrow or incorrect annotations that do not match the actual structures returned by markdown_to_json.dictify(), especially list and string leaves. - general/index.py contains signatures that are stricter than real KB-level call sites, where extractor is passed as a class and doc_id can be None. Without these fixes, runtime type checking can cause upload_and_parse and related GraphRAG flows to fail before or during mind map / graph processing. Type of change - Bug fix - Type annotation correction - Test coverage improvement Why is this change recommended? These issues are latent in the current GraphRAG code and may remain unnoticed in default setups, but they represent real mismatches between declared contracts and actual runtime behavior. This change: - makes GraphRAG depend on the minimal async chat capability it actually uses; - aligns function signatures with real call patterns; - fixes MindMapExtractor annotations to match its real input and output shapes; - preserves list-based mind map leaves correctly during post-processing. How is it fixed? 1. Introduce rag.graphrag.llm_protocol.GraphRAGCompletionLLM as the minimal protocol GraphRAG actually needs: - llm_name - max_length - async_chat(...) 2. 
Replace concrete Base annotations with this protocol across GraphRAG components: - rag/graphrag/general/extractor.py - rag/graphrag/general/mind_map_extractor.py - rag/graphrag/general/community_reports_extractor.py - rag/graphrag/general/graph_extractor.py - rag/graphrag/light/graph_extractor.py - rag/graphrag/entity_resolution.py - rag/graphrag/search.py 3. Fix MindMapExtractor annotations and leaf handling: - widen _todict() to support Mapping | list | str - widen _be_children() to support Mapping | list[str] | str - change _process_document() return type from str to None - preserve plain string list leaves in _list_to_kv() instead of collapsing them into an empty dict 4. Normalize GraphRAG chat responses to support both return shapes currently present in the codebase. RAGFlow commit 67937a6 fixed one concrete extraction failure by indexing response[0], which addresses the tuple-returning path. However, LLMBundle.async_chat in the current codebase returns plain text, so GraphRAG can also receive a string depending on the caller. This change makes GraphRAG explicitly handle both tuple[str, int] and str responses instead of depending on only one path. 5. Relax general/index.py signatures to match real usage: - generate_subgraph(extractor: type[Extractor], ...) - resolve_entities(doc_id: str | None, ...) - extract_community(doc_id: str | None, ...) Validation - python3 -m py_compile rag/graphrag/general/index.py - python3 -m compileall rag/graphrag - PYTHONPATH=/Users/dxl/project/python/local/ragflow /Users/dxl/project/python/multirag/.venv/bin/python -m pytest test/unit_test/rag/graphrag/test_llm_protocol.py The targeted regression tests cover: - protocol-based LLM acceptance in MindMapExtractor - tuple-based chat responses remaining supported - _todict() preserving list leaves - _be_children() accepting list leaves - _process_document() returning None
1 parent 1399c60 commit 16f5914

File tree

10 files changed

+160
-34
lines changed

10 files changed

+160
-34
lines changed

rag/graphrag/entity_resolution.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@
2424
import networkx as nx
2525

2626
from rag.graphrag.general.extractor import Extractor
27+
from rag.graphrag.llm_protocol import GraphRAGCompletionLLM
2728
from rag.nlp import is_english
2829
import editdistance
2930
from rag.graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
30-
from rag.llm.chat_model import Base as CompletionLLM
3131
from rag.graphrag.utils import perform_variable_replacements, chat_limiter, GraphChange
3232
from api.db.services.task_service import has_canceled
3333
from common.exceptions import TaskCanceledException
@@ -57,7 +57,7 @@ class EntityResolution(Extractor):
5757

5858
def __init__(
5959
self,
60-
llm_invoker: CompletionLLM,
60+
llm_invoker: GraphRAGCompletionLLM,
6161
):
6262
super().__init__(llm_invoker)
6363
"""Init method definition."""
@@ -294,4 +294,3 @@ def is_similarity(self, a, b):
294294
return len(a & b) > 1
295295

296296
return len(a & b)*1./max_l >= 0.8
297-

rag/graphrag/general/community_reports_extractor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@
2222
from common.exceptions import TaskCanceledException
2323
from common.connection_utils import timeout
2424
from rag.graphrag.general import leiden
25+
from rag.graphrag.llm_protocol import GraphRAGCompletionLLM
2526
from rag.graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
2627
from rag.graphrag.general.extractor import Extractor
2728
from rag.graphrag.general.leiden import add_community_info2graph
28-
from rag.llm.chat_model import Base as CompletionLLM
2929
from rag.graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
3030
from common.token_utils import num_tokens_from_string
3131

@@ -46,7 +46,7 @@ class CommunityReportsExtractor(Extractor):
4646

4747
def __init__(
4848
self,
49-
llm_invoker: CompletionLLM,
49+
llm_invoker: GraphRAGCompletionLLM,
5050
max_report_length: int | None = None,
5151
):
5252
super().__init__(llm_invoker)

rag/graphrag/general/extractor.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from api.db.services.task_service import has_canceled
2727
from common.connection_utils import timeout
2828
from common.token_utils import truncate
29+
from rag.graphrag.llm_protocol import GraphRAGCompletionLLM, unwrap_graphrag_chat_response
2930
from rag.graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
3031
from rag.graphrag.utils import (
3132
GraphChange,
@@ -39,7 +40,6 @@
3940
split_string_by_multi_markers,
4041
)
4142
from common.misc_utils import thread_pool_exec
42-
from rag.llm.chat_model import Base as CompletionLLM
4343
from rag.prompts.generator import message_fit_in
4444
from common.exceptions import TaskCanceledException
4545

@@ -50,11 +50,11 @@
5050

5151

5252
class Extractor:
53-
_llm: CompletionLLM
53+
_llm: GraphRAGCompletionLLM
5454

5555
def __init__(
5656
self,
57-
llm_invoker: CompletionLLM,
57+
llm_invoker: GraphRAGCompletionLLM,
5858
language: str | None = "English",
5959
entity_types: list[str] | None = None,
6060
):
@@ -78,7 +78,8 @@ def _chat(self, system, history, gen_conf={}, task_id=""):
7878
raise TaskCanceledException(f"Task {task_id} was cancelled")
7979
try:
8080
response = asyncio.run(self._llm.async_chat(system_msg[0]["content"], hist, conf))
81-
response = re.sub(r"^.*</think>", "", response[0], flags=re.DOTALL)
81+
response = unwrap_graphrag_chat_response(response)
82+
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
8283
if response.find("**ERROR**") >= 0:
8384
raise Exception(response)
8485
set_llm_cache(self._llm.llm_name, system, response, history, gen_conf)

rag/graphrag/general/graph_extractor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
import tiktoken
1515

1616
from rag.graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS
17+
from rag.graphrag.llm_protocol import GraphRAGCompletionLLM
1718
from rag.graphrag.general.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
1819
from rag.graphrag.utils import ErrorHandlerFn, perform_variable_replacements, chat_limiter, split_string_by_multi_markers
19-
from rag.llm.chat_model import Base as CompletionLLM
2020
import networkx as nx
2121
from common.token_utils import num_tokens_from_string
2222

@@ -52,7 +52,7 @@ class GraphExtractor(Extractor):
5252

5353
def __init__(
5454
self,
55-
llm_invoker: CompletionLLM,
55+
llm_invoker: GraphRAGCompletionLLM,
5656
language: str | None = "English",
5757
entity_types: list[str] | None = None,
5858
tuple_delimiter_key: str | None = None,

rag/graphrag/general/index.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ async def build_one(doc_id: str):
393393

394394

395395
async def generate_subgraph(
396-
extractor: Extractor,
396+
extractor: type[Extractor],
397397
tenant_id: str,
398398
kb_id: str,
399399
doc_id: str,
@@ -504,7 +504,7 @@ async def resolve_entities(
504504
subgraph_nodes: set[str],
505505
tenant_id: str,
506506
kb_id: str,
507-
doc_id: str,
507+
doc_id: str | None,
508508
llm_bdl,
509509
embed_bdl,
510510
callback,
@@ -539,7 +539,7 @@ async def extract_community(
539539
graph,
540540
tenant_id: str,
541541
kb_id: str,
542-
doc_id: str,
542+
doc_id: str | None,
543543
llm_bdl,
544544
embed_bdl,
545545
callback,

rag/graphrag/general/mind_map_extractor.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,14 @@
1818
import logging
1919
import collections
2020
import re
21+
from collections.abc import Mapping
2122
from typing import Any
2223
from dataclasses import dataclass
2324

2425
from rag.graphrag.general.extractor import Extractor
26+
from rag.graphrag.llm_protocol import GraphRAGCompletionLLM
2527
from rag.graphrag.general.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
2628
from rag.graphrag.utils import ErrorHandlerFn, perform_variable_replacements, chat_limiter
27-
from rag.llm.chat_model import Base as CompletionLLM
2829
import markdown_to_json
2930
from functools import reduce
3031
from common.token_utils import num_tokens_from_string
@@ -44,7 +45,7 @@ class MindMapExtractor(Extractor):
4445

4546
def __init__(
4647
self,
47-
llm_invoker: CompletionLLM,
48+
llm_invoker: GraphRAGCompletionLLM,
4849
prompt: str | None = None,
4950
input_text_key: str | None = None,
5051
on_error: ErrorHandlerFn | None = None,
@@ -59,7 +60,7 @@ def __init__(
5960
def _key(self, k):
6061
return re.sub(r"\*+", "", k)
6162

62-
def _be_children(self, obj: dict, keyset: set):
63+
def _be_children(self, obj: Mapping[str, Any] | list[str] | str, keyset: set[str]) -> list[dict[str, Any]]:
6364
if isinstance(obj, str):
6465
obj = [obj]
6566
if isinstance(obj, list):
@@ -150,36 +151,35 @@ def _merge(self, d1, d2):
150151

151152
return d2
152153

153-
def _list_to_kv(self, data):
154+
def _list_to_kv(self, data: dict[str, Any]) -> dict[str, Any]:
154155
for key, value in data.items():
155156
if isinstance(value, dict):
156157
self._list_to_kv(value)
157158
elif isinstance(value, list):
158159
new_value = {}
160+
has_nested_list = False
159161
for i in range(len(value)):
160162
if isinstance(value[i], list) and i > 0:
163+
has_nested_list = True
161164
new_value[value[i - 1]] = value[i][0]
162-
data[key] = new_value
165+
data[key] = new_value if has_nested_list else value
163166
else:
164167
continue
165168
return data
166169

167-
def _todict(self, layer: collections.OrderedDict):
168-
to_ret = layer
169-
if isinstance(layer, collections.OrderedDict):
170+
def _todict(self, layer: Mapping[str, Any] | list[Any] | str) -> dict[str, Any] | list[Any] | str:
171+
if isinstance(layer, collections.OrderedDict | dict):
170172
to_ret = dict(layer)
171-
172-
try:
173173
for key, value in to_ret.items():
174174
to_ret[key] = self._todict(value)
175-
except AttributeError:
176-
pass
177-
178-
return self._list_to_kv(to_ret)
175+
return self._list_to_kv(to_ret)
176+
if isinstance(layer, list):
177+
return [self._todict(value) for value in layer]
178+
return layer
179179

180180
async def _process_document(
181-
self, text: str, prompt_variables: dict[str, str], out_res
182-
) -> str:
181+
self, text: str, prompt_variables: dict[str, str], out_res: list[dict[str, Any] | list[Any] | str]
182+
) -> None:
183183
variables = {
184184
**prompt_variables,
185185
self._input_text_key: text,

rag/graphrag/light/graph_extractor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
import networkx as nx
1717

1818
from rag.graphrag.general.extractor import ENTITY_EXTRACTION_MAX_GLEANINGS, Extractor
19+
from rag.graphrag.llm_protocol import GraphRAGCompletionLLM
1920
from rag.graphrag.light.graph_prompt import PROMPTS
2021
from rag.graphrag.utils import chat_limiter, pack_user_ass_to_openai_messages, split_string_by_multi_markers
21-
from rag.llm.chat_model import Base as CompletionLLM
2222
from common.token_utils import num_tokens_from_string
2323

2424
@dataclass
@@ -34,7 +34,7 @@ class GraphExtractor(Extractor):
3434

3535
def __init__(
3636
self,
37-
llm_invoker: CompletionLLM,
37+
llm_invoker: GraphRAGCompletionLLM,
3838
language: str | None = "English",
3939
entity_types: list[str] | None = None,
4040
example_number: int = 2,

rag/graphrag/llm_protocol.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from typing import Any, Protocol, TypeAlias, runtime_checkable
2+
3+
GraphRAGChatResponse: TypeAlias = str | tuple[str, int]
4+
5+
6+
def unwrap_graphrag_chat_response(response: GraphRAGChatResponse) -> str:
7+
if isinstance(response, tuple):
8+
return response[0]
9+
return response
10+
11+
12+
@runtime_checkable
13+
class GraphRAGCompletionLLM(Protocol):
14+
"""Minimal async chat contract used across GraphRAG components."""
15+
16+
llm_name: str
17+
max_length: int
18+
19+
async def async_chat(
20+
self,
21+
system: str,
22+
history: list[dict[str, Any]],
23+
gen_conf: dict[str, Any] | None = None,
24+
**kwargs,
25+
) -> GraphRAGChatResponse: ...

rag/graphrag/search.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import pandas as pd
2323

2424
from common.misc_utils import get_uuid
25+
from rag.graphrag.llm_protocol import GraphRAGCompletionLLM, unwrap_graphrag_chat_response
2526
from rag.graphrag.query_analyze_prompt import PROMPTS
2627
from rag.graphrag.utils import get_entity_type2samples, get_llm_cache, set_llm_cache, get_relation
2728
from common.token_utils import num_tokens_from_string
@@ -33,17 +34,18 @@
3334

3435

3536
class KGSearch(Dealer):
36-
async def _chat(self, llm_bdl, system, history, gen_conf):
37+
async def _chat(self, llm_bdl: GraphRAGCompletionLLM, system: str, history: list[dict[str, str]], gen_conf: dict):
3738
response = get_llm_cache(llm_bdl.llm_name, system, history, gen_conf)
3839
if response:
3940
return response
4041
response = await llm_bdl.async_chat(system, history, gen_conf)
42+
response = unwrap_graphrag_chat_response(response)
4143
if response.find("**ERROR**") >= 0:
4244
raise Exception(response)
4345
set_llm_cache(llm_bdl.llm_name, system, response, history, gen_conf)
4446
return response
4547

46-
async def query_rewrite(self, llm, question, idxnms, kb_ids):
48+
async def query_rewrite(self, llm: GraphRAGCompletionLLM, question, idxnms, kb_ids):
4749
ty2ents = await get_entity_type2samples(idxnms, kb_ids)
4850
hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question,
4951
TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2))
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
"""Regression tests for the GraphRAG LLM protocol and MindMapExtractor fixes."""

import asyncio
import collections
import sys
import types


def _install_api_stubs() -> None:
    """Register minimal stand-ins for the ``api`` package tree so importing
    the GraphRAG modules does not pull in the real service layer."""
    pkg = types.ModuleType("api")
    pkg.__path__ = []
    db_pkg = types.ModuleType("api.db")
    db_pkg.__path__ = []
    services_pkg = types.ModuleType("api.db.services")
    services_pkg.__path__ = []
    task_service = types.ModuleType("api.db.services.task_service")
    # GraphRAG only probes for task cancellation; never cancel in tests.
    task_service.has_canceled = lambda *_args, **_kwargs: False

    pkg.db = db_pkg
    db_pkg.services = services_pkg
    services_pkg.task_service = task_service

    for name, module in (
        ("api", pkg),
        ("api.db", db_pkg),
        ("api.db.services", services_pkg),
        ("api.db.services.task_service", task_service),
    ):
        sys.modules.setdefault(name, module)


_install_api_stubs()

import rag.graphrag.general.extractor as extractor_module  # noqa: E402
import rag.graphrag.general.mind_map_extractor as mind_map_extractor_module  # noqa: E402
from rag.graphrag.general.mind_map_extractor import MindMapExtractor  # noqa: E402


class FakeLLM:
    """Protocol-compatible chat model returning a plain string response."""

    llm_name = "fake-llm"
    max_length = 4096

    async def async_chat(self, system, history: list[dict[str, str]], gen_conf=None, **kwargs):
        return "{}"


class TupleLLM:
    """Protocol-compatible chat model returning a (text, token_count) tuple."""

    llm_name = "tuple-llm"
    max_length = 4096

    async def async_chat(self, system, history: list[dict[str, str]], gen_conf=None, **kwargs):
        return "{}", 0


def test_mind_map_extractor_accepts_protocol_based_llm():
    mm = MindMapExtractor(FakeLLM())

    assert mm._llm.llm_name == "fake-llm"
    assert mm._llm.max_length == 4096


def test_mind_map_extractor_accepts_tuple_chat_response(monkeypatch):
    mm = MindMapExtractor(TupleLLM())
    # Force a cache miss and swallow the cache write.
    for cache_fn in ("get_llm_cache", "set_llm_cache"):
        monkeypatch.setattr(extractor_module, cache_fn, lambda *args, **kwargs: None)

    assert mm._chat("system", [{"role": "user", "content": "Output:"}], {}) == "{}"


def test_mind_map_extractor_todict_supports_list_leaves():
    mm = MindMapExtractor(FakeLLM())
    layer = collections.OrderedDict(
        {"顶层": collections.OrderedDict({"部分A": ["点1", "点2"]})}
    )

    assert mm._todict(layer) == {"顶层": {"部分A": ["点1", "点2"]}}


def test_mind_map_extractor_be_children_supports_list_leaves():
    mm = MindMapExtractor(FakeLLM())

    expected = [
        {"id": "点1", "children": []},
        {"id": "点2", "children": []},
    ]
    assert mm._be_children(["点1", "点2"], {"顶层"}) == expected


def test_mind_map_extractor_process_document_returns_none(monkeypatch):
    mm = MindMapExtractor(FakeLLM())
    collected = []

    async def fake_thread_pool_exec(*args, **kwargs):
        # Stand-in for the LLM round trip: a canned markdown mind map.
        return "# 顶层\n## 部分A\n- 点1\n- 点2"

    monkeypatch.setattr(mind_map_extractor_module, "thread_pool_exec", fake_thread_pool_exec)

    assert asyncio.run(mm._process_document("课堂纪要", {}, collected)) is None
    assert collected == [{"顶层": {"部分A": ["点1", "点2"]}}]

0 commit comments

Comments
 (0)