@@ -2,9 +2,9 @@
 import os
 from dotenv import load_dotenv
 
-from langchain.callbacks.base import AsyncCallbackManager
+from langchain.callbacks.manager import AsyncCallbackManager
 from langchain.callbacks.tracers import LangChainTracer
-from langchain.chains import ChatVectorDBChain
+from langchain.chains import ConversationalRetrievalChain
 from langchain.chains.chat_vector_db.prompts import (CONDENSE_QUESTION_PROMPT,
                                                      QA_PROMPT)
 from langchain.chains.llm import LLMChain
@@ -20,9 +20,9 @@
 
 def get_chain(
     vectorstore: VectorStore, question_handler, stream_handler, tracing: bool = False
-) -> ChatVectorDBChain:
-    """Create a ChatVectorDBChain for question/answering."""
-    # Construct a ChatVectorDBChain with a streaming llm for combine docs
+) -> ConversationalRetrievalChain:
+    """Create a ConversationalRetrievalChain for question/answering."""
+    # Construct a ConversationalRetrievalChain with a streaming llm for combine docs
     # and a separate, non-streaming llm for question generation
     manager = AsyncCallbackManager([])
     question_manager = AsyncCallbackManager([question_handler])
@@ -53,8 +53,8 @@ def get_chain(
         streaming_llm, chain_type="stuff", prompt=QA_PROMPT, callback_manager=manager
     )
 
-    qa = ChatVectorDBChain(
-        vectorstore=vectorstore,
+    qa = ConversationalRetrievalChain(
+        retriever=vectorstore.as_retriever(),
         combine_docs_chain=doc_chain,
         question_generator=question_generator,
         callback_manager=manager,
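For reference, the question_handler and stream_handler arguments threaded through get_chain are async callback handlers wired into the two callback managers above; this diff only moves the AsyncCallbackManager import, it does not change how the handlers behave. A minimal sketch of such a handler, assuming the same-era AsyncCallbackHandler base class from langchain.callbacks.base (the class name and the print-based streaming are illustrative assumptions, not part of this commit):

    from langchain.callbacks.base import AsyncCallbackHandler

    class StreamingLLMCallbackHandler(AsyncCallbackHandler):
        """Illustrative handler that forwards tokens as the LLM emits them."""

        async def on_llm_new_token(self, token: str, **kwargs) -> None:
            # A real app would push the token to a client (e.g. over a
            # websocket); printing stands in for that here.
            print(token, end="", flush=True)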
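To see the migration end to end, here is a sketch of building and calling the new chain outside the websocket app. It mirrors the constructor arguments in the diff, but the FAISS store, the OpenAI LLM, and the sample text are assumptions made to keep the example self-contained (it needs OPENAI_API_KEY set and faiss installed), and the import paths assume the same LangChain version this commit targets:

    from langchain.chains import ConversationalRetrievalChain
    from langchain.chains.chat_vector_db.prompts import (CONDENSE_QUESTION_PROMPT,
                                                         QA_PROMPT)
    from langchain.chains.llm import LLMChain
    from langchain.chains.question_answering import load_qa_chain
    from langchain.embeddings import OpenAIEmbeddings
    from langchain.llms import OpenAI
    from langchain.vectorstores import FAISS

    # Stand-in vector store; the real app loads a pre-built index.
    vectorstore = FAISS.from_texts(
        ["LangChain is a framework for building LLM applications."],
        OpenAIEmbeddings(),
    )

    llm = OpenAI(temperature=0)  # one non-streaming LLM keeps the sketch simple
    question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
    doc_chain = load_qa_chain(llm, chain_type="stuff", prompt=QA_PROMPT)

    qa = ConversationalRetrievalChain(
        retriever=vectorstore.as_retriever(),  # was: vectorstore=vectorstore
        combine_docs_chain=doc_chain,
        question_generator=question_generator,
    )

    # The chain condenses any follow-up into a standalone question, retrieves
    # documents, and answers; the reply lives under the "answer" key.
    result = qa({"question": "What is LangChain?", "chat_history": []})
    print(result["answer"])

The key behavioral difference is that ConversationalRetrievalChain takes any retriever rather than a vector store directly, which is why the diff swaps vectorstore=vectorstore for retriever=vectorstore.as_retriever().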