Commit eeb65e8

refactor to functions
Signed-off-by: greg pereira <[email protected]>
1 parent 8819b65 commit eeb65e8

2 files changed: +58 -19 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ generated
 .idea
 .DS_Store
 milvus/seed/data/*
+milvus/build/volumes/milvus/*data*
 *.venv
 *venv

milvus/seed/seed.py

Lines changed: 57 additions & 19 deletions
@@ -1,38 +1,76 @@
 import os
 from pymilvus import MilvusClient, DataType
+from langchain_community.vectorstores import Milvus
 from langchain_experimental.text_splitter import SemanticChunker
 from langchain_community.document_loaders import PyPDFLoader, WebBaseLoader
 from langchain_community.embeddings import HuggingFaceBgeEmbeddings, HuggingFaceInstructEmbeddings
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain import hub
+from langchain_core.runnables import RunnablePassthrough
+from langchain_core.output_parsers import StrOutputParser
 from tika import parser # pip install tika
 
+
 def log_step(step_num, step_name) -> None:
     print("-----------------------------------------------")
     print(f"{step_num}. {step_name}")
     print("-----------------------------------------------")
 
-# model_name = "ibm/merlinite-7b"
-# model_kwargs = {"device": "cpu"}
-# encode_kwargs = {"normalize_embeddings": True}
+def milvus_init() -> MilvusClient:
+    client = MilvusClient()
+    if not client.has_connection('dnd'):
+        client.drop_connection('dnd')
+    return client
 
-model_name = "ibm/merlinite-7b"
-model_kwargs={"device": "cuda"},
-encode_kwargs = {"device": "cuda", "batch_size": 100, "normalize_embeddings": True}
+def fill_dnd_collection(text_splitter: any, embeddings: any) -> None:
+    # local
+    raw = parser.from_file("data/DnD-5e-Handbook.pdf")
+    print(len(raw['content']))
+    docs = text_splitter.create_documents([raw['content']])
+    vector_store = Milvus.from_documents(
+        docs,
+        embedding=embeddings,
+        connection_args={"host": "localhost", "port": 19530},
+        collection_name="dnd"
+    )
+    # remote
+    # loader = PyPDFLoader('https://orkerhulen.dk/onewebmedia/DnD%205e%20Players%20Handbook%20%28BnW%20OCR%29.pdf')
+    # data = loader.load()
 
-log_step(0, "Generate embeddings")
-embeddings = HuggingFaceBgeEmbeddings(
-    model_name=model_name,
-    model_kwargs=model_kwargs,
-    encode_kwargs=encode_kwargs,
-    query_instruction = "search_query:",
-    embed_instruction = "search_document:"
-)
+def generate_embeddings() -> any:
+    # model_name = "ibm/merlinite-7b"
+    # model_kwargs={"device": "cuda"},
+    # encode_kwargs = {"device": "cuda", "batch_size": 100, "normalize_embeddings": True}
+    model_name = "all-MiniLM-L6-v2"
+    model_kwargs = {"device": "cpu"}
+    encode_kwargs = {"normalize_embeddings": True}
+    embeddings = HuggingFaceBgeEmbeddings(
+        model_name=model_name,
+        # model_kwargs=model_kwargs,
+        encode_kwargs=encode_kwargs,
+        query_instruction = "search_query:",
+        embed_instruction = "search_document:"
+    )
+
+def generate_text_splitter(chunk_size=512, chunk_overlap=50) -> any:
+    # text_splitter = SemanticChunker(embeddings=embeddings) # fails
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=chunk_size,
+        chunk_overlap=chunk_overlap,
+        length_function=len,
+        is_separator_regex=False
+    )
+    return text_splitter
 
+log_step(0, "Generate embeddings")
+embeddings = generate_embeddings()
 log_step(1, "Init text splitter")
-text_splitter = SemanticChunker(embeddings=embeddings)
+text_splitter = generate_text_splitter()
 log_step(2, "Read Raw data from PDF")
-raw = parser.from_file("data/DnD-5e-Handbook.pdf")
 log_step(3, "Text splitting")
-print(len(raw['content']))
-docs = text_splitter.create_documents([raw['content']])
 log_step(4, "Log result")
-print(len(docs))
+fill_dnd_collection(embeddings=embeddings, text_splitter=text_splitter)
+
+
+# retreiver = vector_store.as_retreiver()
+# prompt = hub.pull("rlm/rag-prompt")
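Once the main flow at the bottom of seed.py has run, the handbook chunks should be sitting in the "dnd" collection on the local Milvus instance that fill_dnd_collection() targets (localhost:19530). A quick sanity check along these lines is not part of the commit; it is a sketch that assumes pymilvus's MilvusClient pointed at that same endpoint:

from pymilvus import MilvusClient

# Connect to the same local Milvus endpoint fill_dnd_collection() writes to.
client = MilvusClient(uri="http://localhost:19530")

# The seeding run should have created the "dnd" collection.
print(client.list_collections())
if client.has_collection("dnd"):
    # row_count gives a rough sense of how many chunks were inserted
    print(client.get_collection_stats("dnd"))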

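The two commented-out lines at the end of the diff, together with the newly added hub, RunnablePassthrough, and StrOutputParser imports, hint at a retrieval chain over the seeded collection. A sketch of how that might be wired up, none of which is in the commit: the Ollama model and the format_docs helper are illustrative assumptions, the vector-store method is spelled as_retriever(), and the embeddings are rebuilt inline because generate_embeddings() as committed constructs the HuggingFaceBgeEmbeddings object without returning it.

from langchain import hub
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.llms import Ollama
from langchain_community.vectorstores import Milvus
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

# Mirror the embedding setup in seed.py's generate_embeddings().
embeddings = HuggingFaceBgeEmbeddings(
    model_name="all-MiniLM-L6-v2",
    encode_kwargs={"normalize_embeddings": True},
    query_instruction="search_query:",
    embed_instruction="search_document:",
)

# Attach to the already-seeded "dnd" collection.
vector_store = Milvus(
    embedding_function=embeddings,
    connection_args={"host": "localhost", "port": 19530},
    collection_name="dnd",
)
retriever = vector_store.as_retriever()

# Standard LCEL RAG wiring; the prompt name comes from the diff's comment,
# the Ollama model choice is an illustrative assumption.
prompt = hub.pull("rlm/rag-prompt")
llm = Ollama(model="mistral")

def format_docs(docs):
    # Hypothetical helper: collapse retrieved chunks into one context string.
    return "\n\n".join(doc.page_content for doc in docs)

rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)

print(rag_chain.invoke("How does sneak attack work for a rogue?"))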