Skip to content

Commit 972dc45

Browse files
authored
feat: add bm25 keyword search and cohere rerank (#10)
* feat: add bm25 keyword search Signed-off-by: Keming <kemingyang@tensorchord.ai> * add cohere rerank Signed-off-by: Keming <kemingyang@tensorchord.ai> --------- Signed-off-by: Keming <kemingyang@tensorchord.ai>
1 parent dbd4cc8 commit 972dc45

File tree

19 files changed

+457
-59
lines changed

19 files changed

+457
-59
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# Vechord
2+
13
Python RAG framework built on top of PostgreSQL and [VectorChord](https://github.com/tensorchord/VectorChord/).
24

35
## Installation
@@ -16,7 +18,7 @@ pip install vechord
1618
## Development
1719

1820
```bash
19-
docker run --rm -d --name vechord -e POSTGRES_PASSWORD=postgres -p 5432:5432 tensorchord/vchord-postgres:pg17-v0.2.1
21+
docker run --rm -d --name vdb -e POSTGRES_PASSWORD=postgres -p 5432:5432 ghcr.io/tensorchord/vchord_bm25-postgres:pg17-v0.1.1
2022
envd up
2123
# inside the envd env, sync all the dependencies
2224
make sync

docs/source/api.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,14 @@
6262
:show-inheritance:
6363
```
6464

65+
## Rerank
66+
67+
```{eval-rst}
68+
.. automodule:: vechord.rerank
69+
:members:
70+
:show-inheritance:
71+
```
72+
6573
## Service
6674

6775
```{eval-rst}

docs/source/example.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,9 @@
2323
```{include} ../../examples/essay.py
2424
:code: python
2525
```
26+
27+
## Hybrid search with rerank
28+
29+
```{include} ../../examples/hybrid.py
30+
:code: python
31+
```

examples/beir.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def load_corpus(dataset: str, output: Path) -> Iterator[Corpus]:
9090
uid=item["_id"],
9191
text=text,
9292
title=title,
93-
vector=vector,
93+
vector=DenseVector(vector),
9494
)
9595

9696

@@ -115,13 +115,16 @@ def load_query(dataset: str, output: Path) -> Iterator[Query]:
115115
continue
116116
text = item.get("text", "")
117117
yield Query(
118-
uid=uid, cid=table[uid], text=text, vector=emb.vectorize_query(text)
118+
uid=uid,
119+
cid=table[uid],
120+
text=text,
121+
vector=DenseVector(emb.vectorize_query(text)),
119122
)
120123

121124

122125
@vr.inject(input=Query)
123126
def evaluate(cid: str, vector: DenseVector) -> Evaluation:
124-
docs: list[Corpus] = vr.search(Corpus, vector, topk=TOP_K)
127+
docs: list[Corpus] = vr.search_by_vector(Corpus, vector, topk=TOP_K)
125128
score = BaseEvaluator.evaluate_one(cid, [doc.uid for doc in docs])
126129
return Evaluation(
127130
map=score.get("map"),

examples/contextual.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,12 @@ def split_document(uid: int, text: str) -> list[Chunk]:
6767
chunker = RegexChunker(overlap=0)
6868
chunks = chunker.segment(text)
6969
return [
70-
Chunk(doc_uid=uid, seq_id=i, text=chunk, vector=emb.vectorize_chunk(chunk))
70+
Chunk(
71+
doc_uid=uid,
72+
seq_id=i,
73+
text=chunk,
74+
vector=DenseVector(emb.vectorize_chunk(chunk)),
75+
)
7176
for i, chunk in enumerate(chunks)
7277
]
7378

@@ -89,7 +94,9 @@ def context_embedding(uid: int, text: str) -> list[ContextChunk]:
8994
]
9095
return [
9196
ContextChunk(
92-
chunk_uid=chunk_uid, text=augmented, vector=emb.vectorize_chunk(augmented)
97+
chunk_uid=chunk_uid,
98+
text=augmented,
99+
vector=DenseVector(emb.vectorize_chunk(augmented)),
93100
)
94101
for (chunk_uid, augmented) in zip(
95102
[c.uid for c in chunks], context_chunks, strict=False
@@ -99,22 +106,20 @@ def context_embedding(uid: int, text: str) -> list[ContextChunk]:
99106

100107
def query_chunk(query: str) -> list[Chunk]:
101108
vector = emb.vectorize_query(query)
102-
res: list[Chunk] = vr.search(
109+
res: list[Chunk] = vr.search_by_vector(
103110
Chunk,
104111
vector,
105112
topk=5,
106-
return_vector=False,
107113
)
108114
return res
109115

110116

111117
def query_context_chunk(query: str) -> list[ContextChunk]:
112118
vector = emb.vectorize_query(query)
113-
res: list[ContextChunk] = vr.search(
119+
res: list[ContextChunk] = vr.search_by_vector(
114120
ContextChunk,
115121
vector,
116122
topk=5,
117-
return_vector=False,
118123
)
119124
return res
120125

@@ -125,7 +130,7 @@ def evaluate(uid: int, doc_uid: int, text: str):
125130
doc: Document = vr.select_by(Document.partial_init(uid=doc_uid))[0]
126131
query = evaluator.produce_query(doc.text, text)
127132
retrieved = query_chunk(query)
128-
score = evaluator.evaluate_one(uid, [r.uid for r in retrieved])
133+
score = evaluator.evaluate_one(str(uid), [str(r.uid) for r in retrieved])
129134
return score
130135

131136

examples/essay.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,19 +75,22 @@ class Evaluation:
7575
def segment_essay() -> list[Chunk]:
7676
chunker = RegexChunker()
7777
chunks = chunker.segment(doc)
78-
return [Chunk(text=chunk, vector=emb.vectorize_chunk(chunk)) for chunk in chunks]
78+
return [
79+
Chunk(text=chunk, vector=DenseVector(emb.vectorize_chunk(chunk)))
80+
for chunk in chunks
81+
]
7982

8083

8184
@vr.inject(input=Chunk, output=Query)
8285
def create_query(uid: int, text: str) -> Query:
8386
query = evaluator.produce_query(doc, text)
84-
return Query(cid=uid, text=query, vector=emb.vectorize_chunk(query))
87+
return Query(cid=uid, text=query, vector=DenseVector(emb.vectorize_chunk(query)))
8588

8689

8790
@vr.inject(input=Query)
8891
def evaluate(cid: int, vector: DenseVector) -> Evaluation:
89-
chunks: list[Chunk] = vr.search(Chunk, vector, topk=TOP_K)
90-
score = evaluator.evaluate_one(cid, [chunk.uid for chunk in chunks])
92+
chunks: list[Chunk] = vr.search_by_vector(Chunk, vector, topk=TOP_K)
93+
score = evaluator.evaluate_one(str(cid), [str(chunk.uid) for chunk in chunks])
9194
return Evaluation(
9295
map=score["map"], ndcg=score["ndcg"], recall=score[f"recall_{TOP_K}"]
9396
)

examples/hybrid.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from html.parser import HTMLParser
2+
from typing import Annotated
3+
4+
import httpx
5+
6+
from vechord.chunk import RegexChunker
7+
from vechord.embedding import SpacyDenseEmbedding
8+
from vechord.registry import VechordRegistry
9+
from vechord.rerank import CohereReranker
10+
from vechord.spec import ForeignKey, Keyword, PrimaryKeyAutoIncrease, Table, Vector
11+
12+
URL = "https://paulgraham.com/{}.html"
13+
DenseVector = Vector[96]
14+
emb = SpacyDenseEmbedding()
15+
chunker = RegexChunker(size=1024, overlap=0)
16+
reranker = CohereReranker()
17+
18+
19+
class EssayParser(HTMLParser):
20+
def __init__(self, *, convert_charrefs: bool = ...) -> None:
21+
super().__init__(convert_charrefs=convert_charrefs)
22+
self.content = []
23+
self.skip = False
24+
25+
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
26+
if tag in ("script", "style"):
27+
self.skip = True
28+
29+
def handle_endtag(self, tag: str) -> None:
30+
if tag in ("script", "style"):
31+
self.skip = False
32+
33+
def handle_data(self, data: str) -> None:
34+
if not self.skip:
35+
self.content.append(data.strip())
36+
37+
38+
class Document(Table, kw_only=True):
    """One fetched essay, stored as a row in the `hybrid` namespace."""

    # Auto-increment primary key; None until the row is inserted.
    uid: PrimaryKeyAutoIncrease | None = None
    # Essay slug used to build the fetch URL (e.g. "smart").
    title: str = ""
    # Full visible text extracted from the essay page.
    text: str
42+
43+
44+
class Chunk(Table, kw_only=True):
    """One segment of a Document, indexed both by dense vector and by BM25 keyword."""

    # Auto-increment primary key; None until the row is inserted.
    uid: PrimaryKeyAutoIncrease | None = None
    # Owning Document row; cascades on delete via the foreign key.
    doc_id: Annotated[int, ForeignKey[Document.uid]]
    # Raw chunk text (also the source of both indexes below).
    text: str
    # 96-dim dense embedding (DenseVector = Vector[96]).
    vector: DenseVector
    # BM25 keyword representation for `search_by_keyword`.
    keyword: Keyword
50+
51+
52+
# Connect to the local Postgres under the "hybrid" namespace (172.17.0.1 is the
# default Docker bridge address) and create tables/indexes for both models.
vr = VechordRegistry("hybrid", "postgresql://postgres:postgres@172.17.0.1:5432/")
vr.register([Document, Chunk])
54+
55+
56+
@vr.inject(output=Document)
def load_document(title: str) -> Document:
    """Fetch the essay named *title* and return it as a Document row.

    Raises:
        RuntimeError: when the HTTP request does not succeed.
    """
    with httpx.Client() as client:
        resp = client.get(URL.format(title))
    if resp.is_error:
        raise RuntimeError(f"Failed to fetch the document `{title}`")
    parser = EssayParser()
    parser.feed(resp.text)
    # Drop the empty fragments left behind by stripped whitespace runs.
    fragments = [piece for piece in parser.content if piece]
    return Document(title=title, text="\n".join(fragments))
65+
66+
67+
@vr.inject(input=Document, output=Chunk)
def chunk_document(uid: int, text: str) -> list[Chunk]:
    """Segment a document and emit one Chunk (vector + keyword) per segment.

    Args:
        uid: primary key of the source ``Document`` row.
        text: full essay text to segment.
    """
    chunks = chunker.segment(text)
    return [
        Chunk(
            doc_id=uid,
            text=chunk,
            # Wrap in DenseVector so the value carries the declared dimension,
            # consistent with the other examples (essay.py, web.py, contextual.py).
            vector=DenseVector(emb.vectorize_chunk(chunk)),
            keyword=Keyword(chunk),
        )
        for chunk in chunks
    ]
79+
80+
81+
def search_and_rerank(query: str, topk: int) -> list[Chunk]:
    """Hybrid retrieval: union BM25 and vector hits, then rerank with Cohere.

    Args:
        query: free-text search query.
        topk: number of hits per retriever and of final results.
    """
    keyword_hits = vr.search_by_keyword(Chunk, query, topk=topk)
    vector_hits = vr.search_by_vector(Chunk, emb.vectorize_query(query), topk=topk)
    # Deduplicate by uid; a vector hit overwrites a keyword hit with the same
    # uid (last-wins, matching dict-comprehension semantics).
    unique = {}
    for chunk in keyword_hits + vector_hits:
        unique[chunk.uid] = chunk
    candidates = list(unique.values())
    order = reranker.rerank(query, [candidate.text for candidate in candidates])
    return [candidates[i] for i in order[:topk]]
89+
90+
91+
if __name__ == "__main__":
    # Ingest the "smart" essay, then chunk + embed every stored document.
    load_document("smart")
    chunk_document()
    # Hybrid search for the same term, reranked down to the top 3 chunks.
    chunks = search_and_rerank("smart", 3)
    print(chunks)

examples/web.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def load_document(title: str) -> Document:
7474
def chunk_document(uid: int, text: str) -> list[Chunk]:
7575
chunks = chunker.segment(text)
7676
return [
77-
Chunk(doc_id=uid, text=chunk, vector=emb.vectorize_chunk(chunk))
77+
Chunk(doc_id=uid, text=chunk, vector=DenseVector(emb.vectorize_chunk(chunk)))
7878
for chunk in chunks
7979
]
8080

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ spacy = [
3434
wordllama = [
3535
"wordllama>=0.3.8.post20",
3636
]
37+
cohere = [
38+
"cohere>=5.14.0",
39+
]
3740

3841
[build-system]
3942
requires = ["pdm-backend"]

tests/test_spec.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Annotated
33

44
import msgspec
5+
import numpy as np
56
import pytest
67

78
from vechord.spec import ForeignKey, PrimaryKeyAutoIncrease, Table, Vector
@@ -51,3 +52,16 @@ def find_schema_by_name(schema, name):
5152
"REFERENCES {namespace}_document(uid) ON DELETE CASCADE"
5253
in find_schema_by_name(Chunk.table_schema(), "doc_id")
5354
)
55+
56+
57+
def test_vector_type():
    """Vector[dim] rejects wrong-dimension input and accepts list or ndarray."""
    Dense = Vector[128]

    # wrong dimension (list input) must raise
    with pytest.raises(ValueError):
        Dense([0.1] * 100)

    # wrong dimension (ndarray input) must raise
    with pytest.raises(ValueError):
        Dense(np.random.rand(123))

    # same-dimension list and ndarray inputs compare element-wise equal
    assert np.equal(Dense(np.ones(128)), Dense([1.0] * 128)).all()

0 commit comments

Comments
 (0)