Skip to content

Commit 5e8bb89

Browse files
authored
fix: force llm to generate descriptions for entity and relation (#63)
Signed-off-by: Keming <kemingyang@tensorchord.ai>
1 parent 28a7cbd commit 5e8bb89

File tree

5 files changed

+83
-76
lines changed

5 files changed

+83
-76
lines changed

vechord/embedding.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,18 @@ async def vectorize_query(self, text: str) -> np.ndarray:
5454
class BaseMultiModalEmbedding(BaseEmbedding):
5555
@abstractmethod
5656
async def vectorize_multimodal_chunk(
57-
self, text: str, image: Optional[bytes] = None, image_url: Optional[str] = None
57+
self,
58+
text: Optional[str] = None,
59+
image: Optional[bytes] = None,
60+
image_url: Optional[str] = None,
5861
) -> np.ndarray:
5962
raise NotImplementedError
6063

6164
async def vectorize_multimodal_query(
62-
self, text: str, image: Optional[bytes] = None, image_url: Optional[str] = None
65+
self,
66+
text: Optional[str] = None,
67+
image: Optional[bytes] = None,
68+
image_url: Optional[str] = None,
6369
) -> np.ndarray:
6470
return await self.vectorize_multimodal_chunk(text, image, image_url)
6571

@@ -184,7 +190,10 @@ def name(self) -> str:
184190
return f"jina_emb_{self.model}_{self.dim}"
185191

186192
async def vectorize_multimodal_chunk(
187-
self, text: str, image: Optional[bytes] = None, image_url: Optional[str] = None
193+
self,
194+
text: Optional[str] = None,
195+
image: Optional[bytes] = None,
196+
image_url: Optional[str] = None,
188197
) -> np.ndarray:
189198
req = await self.query(
190199
JinaEmbeddingRequest.from_text_image(
@@ -198,7 +207,10 @@ async def vectorize_multimodal_chunk(
198207
return req.get_emb()
199208

200209
async def vectorize_multimodal_query(
201-
self, text: str, image: Optional[bytes] = None, image_url: Optional[str] = None
210+
self,
211+
text: Optional[str] = None,
212+
image: Optional[bytes] = None,
213+
image_url: Optional[str] = None,
202214
) -> np.ndarray:
203215
req = await self.query(
204216
JinaEmbeddingRequest.from_text_image(
@@ -268,8 +280,8 @@ def vec_type(self) -> VecType:
268280

269281
async def vectorize_multimodal_chunk(
270282
self,
271-
image: Optional[bytes] = None,
272283
text: Optional[str] = None,
284+
image: Optional[bytes] = None,
273285
image_url: Optional[str] = None,
274286
):
275287
resp = await self.query(
@@ -285,8 +297,8 @@ async def vectorize_multimodal_chunk(
285297

286298
async def vectorize_multimodal_query(
287299
self,
288-
image: Optional[bytes] = None,
289300
text: Optional[str] = None,
301+
image: Optional[bytes] = None,
290302
image_url: Optional[str] = None,
291303
):
292304
resp = await self.query(

vechord/graph.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,10 @@ def recognize_with_relations(
134134
Entity could be person, location, org, event or category.
135135
"""
136136
RECOGNIZE_PROMPT_FIELD = """\n<document>\n{text}\n</document>\n"""
137+
RECOGNIZE_PROMPT_IMAGE = """
138+
Extract the readable text and generate a concise caption describing the image's content
139+
or scene. Use the text and caption as the passage text for named entity extraction.
140+
"""
137141

138142

139143
class GeminiEntityRecognizer(BaseEntityRecognizer, GeminiGenerateProvider):
@@ -184,13 +188,9 @@ async def recognize_image(
184188
self, img: bytes
185189
) -> tuple[list[GraphEntity], list[GraphRelation]]:
186190
"""Recognize entities & relations from the image."""
187-
prompt = (
188-
"Given the image, first summarize it and extract readable text."
189-
f"{self.prompt}"
190-
)
191191
resp = await self.query(
192192
GeminiGenerateRequest.from_prompt_data_structure_resp(
193-
prompt=prompt,
193+
prompt=self.prompt.format(text=RECOGNIZE_PROMPT_IMAGE),
194194
mime_type=GeminiMimeType.JPEG,
195195
data=img,
196196
schema=msgspec.json.schema(list[GraphRelation]),

vechord/model/internal.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ class GraphEntity(msgspec.Struct, kw_only=True, frozen=True):
1414
"""
1515

1616
text: str
17-
label: str = ""
18-
description: str = ""
17+
label: str
18+
description: str
1919

2020

2121
class GraphRelation(msgspec.Struct, kw_only=True, frozen=True):
@@ -28,7 +28,7 @@ class GraphRelation(msgspec.Struct, kw_only=True, frozen=True):
2828

2929
source: GraphEntity
3030
target: GraphEntity
31-
description: str = ""
31+
description: str
3232

3333

3434
class Document(msgspec.Struct, kw_only=True):

vechord/model/voyage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ class VoyageMultiModalEmbeddingRequest(msgspec.Struct, kw_only=True):
8484
def build(
8585
cls,
8686
text: Optional[str],
87-
image_url: Optional[str],
8887
image: Optional[bytes],
88+
image_url: Optional[str],
8989
model: str,
9090
input_type: VOYAGE_INPUT_TYPE,
9191
) -> Self:

vechord/pipeline.py

Lines changed: 56 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
RunRequest,
3838
RunResponse,
3939
)
40-
from vechord.rerank import CohereReranker, JinaReranker
40+
from vechord.rerank import BaseReranker, CohereReranker, JinaReranker
4141
from vechord.spec import (
4242
AnyOf,
4343
DefaultDocument,
@@ -167,7 +167,7 @@ class DynamicPipeline(msgspec.Struct, kw_only=True):
167167
text_emb: Optional[BaseTextEmbedding] = None
168168
multimodal_emb: Optional[BaseMultiModalEmbedding] = None
169169
ocr: Optional[GeminiExtractor] = None
170-
rerank: Optional[CohereReranker] = None
170+
rerank: Optional[BaseReranker] = None
171171
index: Optional[IndexOption] = None
172172
search: Optional[SearchOption] = None
173173
graph: Optional[GeminiEntityRecognizer] = None
@@ -280,71 +280,58 @@ async def run_index(self, request: RunRequest, vr: "VechordRegistry") -> RunAck:
280280
doc_id=doc.uid,
281281
text="",
282282
text_type=request.input_type.value,
283-
Keyword=None,
283+
keyword=None,
284284
vec=None,
285285
)
286-
if self.multimodal_emb:
287-
chunks.append(
288-
Chunk(
289-
doc_id=doc.uid,
290-
text=base64.b64encode(request.data).decode("utf-8"),
291-
text_type=request.input_type.value,
292-
keyword=None,
293-
vec=await self.multimodal_emb.vectorize_multimodal_chunk(
294-
request.data
295-
),
296-
)
286+
if self.multimodal_emb and request.input_type is not InputType.TEXT:
287+
# reuse the fake chunk to ensure the chunk uid is unique
288+
fake_chunk.text = base64.b64encode(request.data).decode("utf-8")
289+
fake_chunk.vec = await self.multimodal_emb.vectorize_multimodal_chunk(
290+
image=request.data
291+
)
292+
chunks.append(fake_chunk)
293+
if request.input_type is InputType.TEXT:
294+
doc.text = request.data.decode("utf-8")
295+
elif self.ocr:
296+
if request.input_type is InputType.PDF:
297+
doc.text = await self.ocr.extract_pdf(request.data)
298+
elif request.input_type is InputType.IMAGE:
299+
doc.text = await self.ocr.extract_image(request.data)
300+
if self.chunk:
301+
sentences.extend(await self.chunk.segment(doc.text))
302+
elif doc.text:
303+
sentences.append(doc.text)
304+
305+
for sent in sentences:
306+
chunk = Chunk(
307+
vec=await self.text_emb.vectorize_chunk(sent),
308+
doc_id=doc.uid,
309+
text=sent,
310+
keyword=None if not enable_keyword_index else Keyword(sent),
297311
)
298-
else:
299-
if request.input_type is InputType.TEXT:
300-
doc.text = request.data.decode("utf-8")
301-
elif self.ocr:
302-
if request.input_type is InputType.PDF:
303-
doc.text = await self.ocr.extract_pdf(request.data)
304-
elif request.input_type is InputType.IMAGE:
305-
doc.text = await self.ocr.extract_image(request.data)
306-
elif self.graph:
307-
fake_chunk.text = base64.b64encode(request.data).decode("utf-8")
308-
img_ents, img_rels = await self.graph.recognize_image(request.data)
312+
chunks.append(chunk)
313+
if self.graph and request.input_type is InputType.TEXT:
314+
chunk_ents, chunk_rels = await self.graph.recognize_with_relations(sent)
309315
conv_ents, conv_rels = self._convert_from_extracted_graph(
310-
fake_chunk.uid, img_ents, img_rels, Entity, Relation
316+
chunk.uid, chunk_ents, chunk_rels, Entity, Relation
311317
)
312318
ents.extend(conv_ents)
313319
rels.extend(conv_rels)
314-
else:
315-
raise RequestError(
316-
f"No OCR or Graph provider for input type: {request.input_type}"
317-
)
318320

319-
if self.chunk:
320-
sentences.extend(await self.chunk.segment(doc.text))
321-
elif doc.text:
322-
sentences.append(doc.text)
323-
for sent in sentences:
324-
chunk = Chunk(
325-
vec=await self.text_emb.vectorize_chunk(sent),
326-
doc_id=doc.uid,
327-
text=sent,
328-
keyword=None if not enable_keyword_index else Keyword(sent),
329-
)
330-
if self.graph and request.input_type is InputType.TEXT:
331-
chunk_ents, chunk_rels = await self.graph.recognize_with_relations(
332-
sent
333-
)
334-
conv_ents, conv_rels = self._convert_from_extracted_graph(
335-
chunk.uid, chunk_ents, chunk_rels, Entity, Relation
336-
)
337-
ents.extend(conv_ents)
338-
rels.extend(conv_rels)
339-
chunks.append(chunk)
321+
if self.graph and request.input_type is not InputType.TEXT and not sentences:
322+
img_ents, img_rels = await self.graph.recognize_image(request.data)
323+
conv_ents, conv_rels = self._convert_from_extracted_graph(
324+
fake_chunk.uid, img_ents, img_rels, Entity, Relation
325+
)
326+
ents.extend(conv_ents)
327+
rels.extend(conv_rels)
328+
if not self.multimodal_emb:
329+
chunks.append(fake_chunk)
340330

341331
await vr.insert(doc)
342332
for chunk in chunks:
343333
await vr.insert(chunk)
344334
if self.index.graph:
345-
if request.input_type is not InputType.TEXT:
346-
# insert the fake chunk for image/pdf
347-
await vr.insert(fake_chunk)
348335
await self.graph_insert(
349336
ents=ents, rels=rels, ent_cls=Entity, rel_cls=Relation, vr=vr
350337
)
@@ -360,6 +347,11 @@ async def graph_insert(
360347
):
361348
"""Insert entities and relations into the graph index."""
362349
ent_map: dict[str, _Entity] = {}
350+
emb_func = (
351+
self.text_emb.vectorize_chunk
352+
if self.text_emb
353+
else self.multimodal_emb.vectorize_multimodal_chunk
354+
)
363355
for ent in ents:
364356
if ent.text not in ent_map:
365357
ent_map[ent.text] = ent
@@ -376,9 +368,7 @@ async def graph_insert(
376368
ent.chunk_uuids.extend(exist.chunk_uuids)
377369
ent.description += f"\n{exist.description}"
378370
await vr.remove_by(ent_cls.partial_init(uid=exist.uid))
379-
ent.vec = await self.text_emb.vectorize_chunk(
380-
f"{ent.text}\n{ent.description}"
381-
)
371+
ent.vec = await emb_func(f"{ent.text}\n{ent.description}")
382372
await vr.insert(ent)
383373

384374
relation_map: dict[str, _Relation] = {}
@@ -397,7 +387,7 @@ async def graph_insert(
397387
exist = exist_rel[0]
398388
rel.description += f"\n{exist.description}"
399389
await vr.remove_by(rel_cls.partial_init(uid=exist.uid))
400-
rel.vec = await self.text_emb.vectorize_chunk(f"{rel.description}")
390+
rel.vec = await emb_func(f"{rel.description}")
401391
await vr.insert(rel)
402392

403393
async def run_search(
@@ -437,7 +427,7 @@ class Relation(_Relation):
437427
if self.multimodal_emb:
438428
indices = await self.rerank.rerank_multimodal(
439429
query=query,
440-
chunks=[chunk.text for chunk in resp],
430+
chunks=[chunk.text for chunk in resp.chunks],
441431
doc_type=resp.chunk_type,
442432
)
443433
else:
@@ -461,11 +451,16 @@ async def graph_search(
461451
vr: "VechordRegistry",
462452
):
463453
ents, rels = await self.graph.recognize_with_relations(query)
454+
emb_func = (
455+
self.text_emb.vectorize_query
456+
if self.text_emb
457+
else self.multimodal_emb.vectorize_multimodal_query
458+
)
464459
if rels:
465460
rel_text = " ".join(rel.description for rel in rels)
466461
similar_rels = await vr.search_by_vector(
467462
rel_cls,
468-
await self.text_emb.vectorize_query(rel_text),
463+
await emb_func(rel_text),
469464
topk=self.search.graph.similar_k,
470465
)
471466
ent_uuids = deduplicate_uid(
@@ -484,7 +479,7 @@ async def graph_search(
484479
ent_text = " ".join(f"{ent.text} {ent.description}" for ent in ents)
485480
similar_ents = await vr.search_by_vector(
486481
ent_cls,
487-
await self.text_emb.vectorize_query(ent_text),
482+
await emb_func(ent_text),
488483
topk=self.search.graph.similar_k,
489484
)
490485
chunk_uuids = deduplicate_uid(

0 commit comments

Comments (0)