-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathbase_1.py
More file actions
147 lines (124 loc) · 5.3 KB
/
base_1.py
File metadata and controls
147 lines (124 loc) · 5.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from abc import ABC
from llama_index import GPTVectorStoreIndex, SimpleDirectoryReader, get_response_synthesizer
from llama_index.retrievers import VectorIndexRetriever
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.postprocessor import SimilarityPostprocessor
from llama_index.node_parser import SimpleNodeParser
from llama_index import download_loader
from llama_index.embeddings import LangchainEmbedding
from llama_index import ServiceContext, StorageContext
from langchain.schema.embeddings import Embeddings
from llama_index.vector_stores import MilvusVectorStore
import os
from llama_index.data_structs import Node
import json
class BaseRetriever(ABC):
    """Milvus-backed document retriever built on llama_index.

    Depending on the constructor flags this either builds a fresh Milvus
    collection from a JSON file of text snippets (``construct_index``),
    appends additional documents to an existing collection (``add_index``),
    or simply reconnects to an existing collection.  In every case it ends
    up with ``self.vector_index`` and a ``self.query_engine`` ready for
    :meth:`search_docs`.
    """

    def __init__(
        self,
        docs_directory: str,
        embed_model: Embeddings,
        embed_dim: int = 768,
        chunk_size: int = 128,
        chunk_overlap: int = 0,
        collection_name: str = "docs",
        construct_index: bool = False,
        add_index: bool = False,
        similarity_top_k: int = 2,
        docs_type: str = "json",
    ):
        """Set up (or connect to) the vector store and build the query engine.

        Args:
            docs_directory: Path to the source data.  For
                ``construct_index`` this must be a JSON file containing a
                list of strings; for ``add_index`` it is either a JSON file
                or a directory of documents depending on ``docs_type``.
            embed_model: Langchain-compatible embedding model.
            embed_dim: Dimensionality of the embedding vectors (must match
                the Milvus collection schema).
            chunk_size: Node chunk size used when parsing documents in
                :meth:`add_index`.
            chunk_overlap: Overlap between consecutive chunks.
            collection_name: Name of the Milvus collection.
            construct_index: If True, (re)build the collection from scratch.
            add_index: If True, append ``docs_directory`` contents to the
                existing collection after loading it.
            similarity_top_k: Number of nearest neighbours returned per query.
            docs_type: Format of ``docs_directory`` for :meth:`add_index`;
                ``"json"`` uses the JSONReader loader, anything else uses
                SimpleDirectoryReader.  (Previously this attribute was read
                but never set, so ``add_index=True`` always raised
                AttributeError.)
        """
        self.docs_directory = docs_directory
        self.embed_model = embed_model
        self.embed_dim = embed_dim
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.collection_name = collection_name
        self.similarity_top_k = similarity_top_k
        self.docs_type = docs_type

        if construct_index:
            self.construct_index()
        else:
            self.load_index_from_milvus()
        if add_index:
            self.add_index()

        # Retrieval-only pipeline: no LLM synthesis, just top-k lookup.
        retriever = VectorIndexRetriever(
            index=self.vector_index,
            similarity_top_k=self.similarity_top_k,
        )
        self.query_engine = RetrieverQueryEngine(
            retriever=retriever,
        )

    def _wrap_embed_model(self):
        """Wrap the raw embedding model for llama_index exactly once.

        Guarding with isinstance prevents double-wrapping when both
        construct_index and add_index run in the same lifetime.
        """
        if not isinstance(self.embed_model, LangchainEmbedding):
            self.embed_model = LangchainEmbedding(self.embed_model)

    def construct_index(self):
        """Build a brand-new Milvus collection from a JSON list of strings.

        ``self.docs_directory`` must point to a JSON file whose top level is
        a list; non-string entries and strings shorter than 10 characters
        are skipped.  The first batch overwrites any existing collection;
        subsequent batches append.
        """
        with open(self.docs_directory, 'r', encoding='utf-8') as file:
            qa_data = json.load(file)
        # Keep only non-trivial string entries.
        nodes = [
            Node(text=item)
            for item in qa_data
            if isinstance(item, str) and len(item) >= 10
        ]

        self._wrap_embed_model()
        service_context = ServiceContext.from_defaults(
            embed_model=self.embed_model, llm=None,
        )
        # overwrite=True drops any previous collection with this name.
        vector_store = MilvusVectorStore(
            dim=self.embed_dim, overwrite=True,
            collection_name=self.collection_name
        )
        storage_context = StorageContext.from_defaults(vector_store=vector_store)

        # Process and index nodes in chunks due to Milvus limitations.
        batch_size = 1000
        for start in range(0, len(nodes), batch_size):
            self.vector_index = GPTVectorStoreIndex(
                nodes[start:start + batch_size],
                service_context=service_context,
                storage_context=storage_context, show_progress=True
            )
            print(f"Indexing of part {start} finished!")
            # Reconnect without overwrite so the next batch appends to the
            # collection instead of wiping it.
            vector_store = MilvusVectorStore(
                overwrite=False,
                collection_name=self.collection_name
            )
            storage_context = StorageContext.from_defaults(vector_store=vector_store)
        print("Indexing finished!")

    def add_index(self):
        """Append the documents in ``self.docs_directory`` to the collection.

        Loads the data with JSONReader when ``self.docs_type == 'json'`` and
        with SimpleDirectoryReader otherwise, chunks it into nodes, and
        indexes the nodes in batches.
        """
        if self.docs_type == 'json':
            JSONReader = download_loader("JSONReader")
            documents = JSONReader().load_data(self.docs_directory)
        else:
            documents = SimpleDirectoryReader(self.docs_directory).load_data()

        node_parser = SimpleNodeParser.from_defaults(
            chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
        )
        nodes = node_parser.get_nodes_from_documents(documents, show_progress=True)

        self._wrap_embed_model()
        service_context = ServiceContext.from_defaults(
            embed_model=self.embed_model, llm=None,
        )
        # overwrite=False: append to the existing collection.
        vector_store = MilvusVectorStore(
            overwrite=False,
            collection_name=self.collection_name
        )
        storage_context = StorageContext.from_defaults(vector_store=vector_store)

        # Process and index nodes in chunks due to Milvus limitations.
        batch_size = 8000
        for start in range(0, len(nodes), batch_size):
            self.vector_index = GPTVectorStoreIndex(
                nodes[start:start + batch_size],
                service_context=service_context,
                storage_context=storage_context, show_progress=True
            )
            print(f"Indexing of part {start} finished!")
        print("Indexing finished!")

    def load_index_from_milvus(self):
        """Attach to an existing Milvus collection without adding documents.

        Creates an (initially empty) index object bound to the stored
        collection; the embeddings already in Milvus back all queries.
        """
        vector_store = MilvusVectorStore(
            overwrite=False, dim=self.embed_dim,
            collection_name=self.collection_name
        )
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        service_context = ServiceContext.from_defaults(embed_model=self.embed_model, llm=None)
        self.vector_index = GPTVectorStoreIndex(
            [], storage_context=storage_context,
            service_context=service_context,
        )

    def search_docs(self, query_text: str):
        """Return the raw response text for ``query_text`` from the query engine."""
        response_vector = self.query_engine.query(query_text)
        return response_vector.response