Skip to content

Commit 2d3dafa

Browse files
committed
feat: support eval in dynamic pipeline
Signed-off-by: Keming <kemingyang@tensorchord.ai>
1 parent 0f0de75 commit 2d3dafa

File tree

5 files changed

+77
-17
lines changed

5 files changed

+77
-17
lines changed

examples/dynamic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ async def main():
6767
dir = Path.home() / "Pictures"
6868
await ingest(dir.glob("*.jpg"))
6969
res = await search("cat")
70-
for item in res:
70+
for item in res.chunks:
7171
print("=>", file_uuids.get(item.doc_id, "Unknown file"))
7272

7373

vechord/evaluate.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import msgspec
66
import pytrec_eval
77

8-
from vechord.errors import DecodeStructuredOutputError
8+
from vechord.errors import DecodeStructuredOutputError, RequestError
99
from vechord.model import GeminiGenerateRequest, RetrievedChunk, UMBRELAScore
1010
from vechord.provider import GeminiGenerateProvider
1111

@@ -115,10 +115,15 @@ class GeminiUMBRELAEvaluator(BaseEvaluator, GeminiGenerateProvider):
115115
- paper: https://arxiv.org/pdf/2406.06519
116116
"""
117117

118-
def __init__(self, model: str = "gemini-2.5-flash", relevant_threshold: int = 2):
118+
def __init__(
119+
self,
120+
model: str = "gemini-2.5-flash",
121+
relevant_threshold: int = 2,
122+
k_values: Sequence[int] = (3, 5, 10),
123+
):
119124
super().__init__(model)
120125
self.relevant_threshold = relevant_threshold
121-
self.k_values = (3, 5, 10)
126+
self.k_values = k_values
122127
self.score_schema = msgspec.json.schema(UMBRELAScore)
123128
self.prompt = """
124129
Given a query and a passage, you must provide a score on an
@@ -153,6 +158,8 @@ def name(self) -> str:
153158
return f"gemini_umbrela_{self.model}"
154159

155160
async def estimate(self, query: str, passage: str) -> int:
161+
if not passage:
162+
return 0
156163
content = self.prompt.format(query=query, passage=passage)
157164
resp = await self.query(
158165
GeminiGenerateRequest.from_prompt_structure_response(
@@ -171,6 +178,8 @@ async def evaluate_with_estimation(
171178
self, query: str, passages: list[str]
172179
) -> dict[str, float]:
173180
"""Calculate the Precision@K and Mean Reciprocal Rank (MRR)."""
181+
if not query or not passages or all(not p.strip() for p in passages):
182+
raise RequestError("Query and passages must be non-empty strings.")
174183
scores = [await self.estimate(query, p) for p in passages]
175184
is_relevant = [score >= self.relevant_threshold for score in scores]
176185
metric = defaultdict(float)

vechord/model/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,13 @@
2424
VoyageEmbeddingResponse,
2525
VoyageMultiModalEmbeddingRequest,
2626
)
27-
from vechord.model.web import InputType, ResourceRequest, RunAck, RunRequest
27+
from vechord.model.web import (
28+
InputType,
29+
ResourceRequest,
30+
RunAck,
31+
RunRequest,
32+
RunResponse,
33+
)
2834

2935
__all__ = [
3036
"Document",
@@ -45,6 +51,7 @@
4551
"RetrievedChunk",
4652
"RunAck",
4753
"RunRequest",
54+
"RunResponse",
4855
"SparseEmbedding",
4956
"UMBRELAScore",
5057
"VoyageEmbeddingRequest",

vechord/model/web.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,40 @@ class RunRequest(msgspec.Struct, kw_only=True, frozen=True):
3535

3636

3737
class RunAck(msgspec.Struct, kw_only=True, frozen=True):
    """Acknowledgment of an index request."""

    # Name of the pipeline/run being acknowledged — presumably the pipeline
    # name echoed back to the caller; confirm against the index handler.
    name: str
    # Human-readable status message.
    msg: str
    # Unique id assigned to this run.
    uid: UUID
43+
44+
45+
class SearchResponse(msgspec.Struct, kw_only=True):
    """A single retrieved chunk returned by a search request."""

    # Id of the chunk itself.
    uid: UUID
    # Id of the document the chunk belongs to (used by callers to map a
    # result back to its source file, e.g. ``file_uuids.get(item.doc_id)``).
    doc_id: UUID
    # Text content of the chunk.
    text: str
49+
50+
51+
class RunResponse(msgspec.Struct, kw_only=True, omit_defaults=True):
    """Response to a search request.

    ``metrics`` (filled only when an evaluator is configured) may contain:
    - MRR
    - precision@k
    - average precision@k
    """

    # Retrieved chunks, in ranking order.
    chunks: list[SearchResponse] = msgspec.field(default_factory=list)
    # Metric name -> score, produced by the optional evaluation step.
    metrics: dict[str, float] = msgspec.field(default_factory=dict)

    def __iter__(self):
        # msgspec.Struct does not define ``__iter__``; delegate to ``chunks``
        # so call sites that iterate the response directly (e.g. the rerank
        # step's ``[chunk.text for chunk in resp]``) work instead of raising
        # ``TypeError``.
        return iter(self.chunks)

    def extend(self, chunks: list):
        """Append retrieved chunks, converting each to a :class:`SearchResponse`.

        Each item must expose ``uid``, ``doc_id`` and ``text`` attributes.
        """
        for chunk in chunks:
            self.chunks.append(
                SearchResponse(
                    uid=chunk.uid,
                    doc_id=chunk.doc_id,
                    text=chunk.text,
                )
            )

    def reorder(self, indices: list[int]):
        """Reorder (and possibly drop) ``chunks`` according to ``indices``.

        ``indices`` are positions into the current ``chunks`` list; entries
        not referenced are discarded (this is how rerank results are applied).
        """
        self.chunks = [self.chunks[i] for i in indices]

vechord/pipeline.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
)
2222
from vechord.entity import GeminiEntityRecognizer
2323
from vechord.errors import RequestError
24+
from vechord.evaluate import GeminiUMBRELAEvaluator
2425
from vechord.extract import GeminiExtractor, LlamaParseExtractor
2526
from vechord.model import (
2627
GraphEntity,
@@ -29,6 +30,7 @@
2930
ResourceRequest,
3031
RunAck,
3132
RunRequest,
33+
RunResponse,
3234
)
3335
from vechord.rerank import CohereReranker
3436
from vechord.spec import (
@@ -122,6 +124,7 @@ class _Relation(Table, kw_only=True):
122124
"graph": {"gemini": GeminiEntityRecognizer},
123125
"index": {"vectorchord": IndexOption},
124126
"search": {"vectorchord": SearchOption},
127+
"evaluate": {"gemini": GeminiUMBRELAEvaluator},
125128
}
126129

127130

@@ -160,6 +163,7 @@ class DynamicPipeline(msgspec.Struct, kw_only=True):
160163
index: Optional[IndexOption] = None
161164
search: Optional[SearchOption] = None
162165
graph: Optional[GeminiEntityRecognizer] = None
166+
evaluate: Optional[GeminiUMBRELAEvaluator] = None
163167

164168
def __post_init__(self):
165169
if not (self.text_emb or self.multimodal_emb):
@@ -195,7 +199,9 @@ def from_steps(cls, steps: list[ResourceRequest]) -> Self:
195199
calls[(step.kind).replace("-", "_")] = provider(**args)
196200
return msgspec.convert(calls, DynamicPipeline)
197201

198-
async def run(self, request: RunRequest, vr: "VechordRegistry"):
202+
async def run(
203+
self, request: RunRequest, vr: "VechordRegistry"
204+
) -> RunAck | RunResponse:
199205
"""Run the dynamic pipeline with the given request."""
200206
if self.index:
201207
return await self.run_index(request, vr)
@@ -374,7 +380,9 @@ async def graph_insert(
374380
rel.vec = await self.text_emb.vectorize_chunk(f"{rel.description}")
375381
await vr.insert(rel)
376382

377-
async def run_search(self, request: RunRequest, vr: "VechordRegistry"):
383+
async def run_search(
384+
self, request: RunRequest, vr: "VechordRegistry"
385+
) -> RunResponse:
378386
query = request.data.decode("utf-8")
379387

380388
# for type hint and compatibility
@@ -387,32 +395,34 @@ class Entity(_Entity):
387395
class Relation(_Relation):
388396
pass
389397

390-
retrieved: list[Chunk] = []
398+
resp = RunResponse()
391399
if self.search.vector:
392400
vec = (
393401
await self.text_emb.vectorize_query(query)
394402
if self.text_emb
395403
else await self.multimodal_emb.vectorize_multimodal_query(text=query)
396404
)
397-
retrieved.extend(
405+
resp.extend(
398406
await vr.search_by_vector(
399407
Chunk, vec, self.search.vector.topk, probe=self.search.vector.probe
400408
)
401409
)
402410
if self.search.keyword:
403-
retrieved.extend(
411+
resp.extend(
404412
await vr.search_by_keyword(Chunk, query, self.search.keyword.topk)
405413
)
406414
if self.search.graph:
407-
retrieved.extend(
408-
await self.graph_search(query, Chunk, Entity, Relation, vr)
409-
)
415+
resp.extend(await self.graph_search(query, Chunk, Entity, Relation, vr))
410416
if self.rerank:
411-
indices = await self.rerank.rerank(
412-
query, [chunk.text for chunk in retrieved]
417+
indices = await self.rerank.rerank(query, [chunk.text for chunk in resp])
418+
resp.reorder(indices)
419+
420+
if self.evaluate:
421+
resp.metrics = await self.evaluate.evaluate_with_estimation(
422+
query, [chunk.text for chunk in resp.chunks]
413423
)
414-
retrieved = [retrieved[i] for i in indices]
415-
return retrieved
424+
425+
return resp
416426

417427
async def graph_search(
418428
self,

0 commit comments

Comments
 (0)