3 changes: 3 additions & 0 deletions .idea/.gitignore


860 changes: 860 additions & 0 deletions .idea/caches/deviceStreaming.xml


9 changes: 9 additions & 0 deletions .idea/genai-stack.iml


5 changes: 5 additions & 0 deletions .idea/misc.xml


8 changes: 8 additions & 0 deletions .idea/modules.xml


6 changes: 6 additions & 0 deletions .idea/vcs.xml


11 changes: 3 additions & 8 deletions api.py
@@ -1,6 +1,6 @@
import os

from langchain_community.graphs import Neo4jGraph
from langchain_neo4j import Neo4jGraph
from dotenv import load_dotenv
from utils import (
create_vector_index,
@@ -128,10 +128,7 @@ def qstream(question: Question = Depends()):
q = Queue()

def cb():
output_function(
{"question": question.text, "chat_history": []},
callbacks=[QueueCallback(q)],
)
output_function.invoke(question.text, config={"callbacks": [QueueCallback(q)]})

def generate():
yield json.dumps({"init": True, "model": llm_name})
@@ -146,9 +143,7 @@ async def ask(question: Question = Depends()):
output_function = llm_chain
if question.rag:
output_function = rag_chain
result = output_function(
{"question": question.text, "chat_history": []}, callbacks=[]
)
result = output_function.invoke(question.text)

return {"result": result["answer"], "model": llm_name}

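Note on the new calling convention in api.py: the refactored chains from chains.py now end with StrOutputParser, so .invoke() returns a plain string, and callbacks are passed through the config argument rather than a callbacks= kwarg. A minimal sketch, assuming the QueueCallback handler already defined in api.py and a chain built by configure_llm_only_chain (the question text below is illustrative, not from this diff):

from queue import Queue

q = Queue()
chain = configure_llm_only_chain(llm)  # now returns chat_prompt | llm | StrOutputParser()
answer = chain.invoke(
    "What is a vector index?",
    config={"callbacks": [QueueCallback(q)]},  # callbacks travel inside config
)
print(answer)  # already a plain string produced by StrOutputParser
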
10 changes: 5 additions & 5 deletions bot.py
@@ -3,7 +3,7 @@
import streamlit as st
from streamlit.logger import get_logger
from langchain.callbacks.base import BaseCallbackHandler
from langchain_community.graphs import Neo4jGraph
from langchain_neo4j import Neo4jGraph
from dotenv import load_dotenv
from utils import (
create_vector_index,
@@ -92,10 +92,10 @@ def chat_input():
with st.chat_message("assistant"):
st.caption(f"RAG: {name}")
stream_handler = StreamHandler(st.empty())
result = output_function(
{"question": user_input, "chat_history": []}, callbacks=[stream_handler]
)["answer"]
output = result
output = output_function.invoke(
user_input, config={"callbacks": [stream_handler]}
)

st.session_state[f"user_input"].append(user_input)
st.session_state[f"generated"].append(output)
st.session_state[f"rag_mode"].append(name)
73 changes: 38 additions & 35 deletions chains.py
@@ -1,4 +1,3 @@

from langchain_openai import OpenAIEmbeddings
from langchain_ollama import OllamaEmbeddings
from langchain_aws import BedrockEmbeddings
@@ -8,21 +7,30 @@
from langchain_ollama import ChatOllama
from langchain_aws import ChatBedrock

from langchain_community.vectorstores import Neo4jVector
from langchain_neo4j import Neo4jVector

from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

from langchain.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate
SystemMessagePromptTemplate,
)

from typing import List, Any
from utils import BaseLogger, extract_title_and_question
from utils import BaseLogger, extract_title_and_question, format_docs
from langchain_google_genai import GoogleGenerativeAIEmbeddings

AWS_MODELS = (
"ai21.jamba-instruct-v1:0",
"amazon.titan",
"anthropic.claude",
"cohere.command",
"meta.llama",
"mistral.mi",
)


def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config={}):
if embedding_model_name == "ollama":
@@ -39,10 +47,8 @@ def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config=
embeddings = BedrockEmbeddings()
dimension = 1536
logger.info("Embedding: Using AWS")
elif embedding_model_name == "google-genai-embedding-001":
embeddings = GoogleGenerativeAIEmbeddings(
model="models/embedding-001"
)
elif embedding_model_name == "google-genai-embedding-001":
embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
dimension = 768
logger.info("Embedding: Using Google Generative AI Embeddings")
else:
@@ -55,9 +61,9 @@ def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config=


def load_llm(llm_name: str, logger=BaseLogger(), config={}):
if llm_name == "gpt-4":
if llm_name in ["gpt-4", "gpt-4o", "gpt-4-turbo"]:
logger.info("LLM: Using GPT-4")
return ChatOpenAI(temperature=0, model_name="gpt-4", streaming=True)
return ChatOpenAI(temperature=0, model_name=llm_name, streaming=True)
elif llm_name == "gpt-3.5":
logger.info("LLM: Using GPT-3.5")
return ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)
@@ -68,6 +74,14 @@ def load_llm(llm_name: str, logger=BaseLogger(), config={}):
model_kwargs={"temperature": 0.0, "max_tokens_to_sample": 1024},
streaming=True,
)
elif llm_name.startswith(AWS_MODELS):
logger.info(f"LLM: {llm_name}")
return ChatBedrock(
model_id=llm_name,
model_kwargs={"temperature": 0.0, "max_tokens_to_sample": 1024},
streaming=True,
)

elif len(llm_name):
logger.info(f"LLM: Using Ollama: {llm_name}")
return ChatOllama(
@@ -96,17 +110,8 @@ def configure_llm_only_chain(llm):
chat_prompt = ChatPromptTemplate.from_messages(
[system_message_prompt, human_message_prompt]
)

def generate_llm_output(
user_input: str, callbacks: List[Any], prompt=chat_prompt
) -> str:
chain = prompt | llm
answer = chain.invoke(
{"question": user_input}, config={"callbacks": callbacks}
).content
return {"answer": answer}

return generate_llm_output
chain = chat_prompt | llm | StrOutputParser()
return chain


def configure_qa_rag_chain(llm, embeddings, embeddings_store_url, username, password):
Expand Down Expand Up @@ -136,12 +141,6 @@ def configure_qa_rag_chain(llm, embeddings, embeddings_store_url, username, pass
]
qa_prompt = ChatPromptTemplate.from_messages(messages)

qa_chain = load_qa_with_sources_chain(
llm,
chain_type="stuff",
prompt=qa_prompt,
)

# Vector + Knowledge Graph response
kg = Neo4jVector.from_existing_index(
embedding=embeddings,
@@ -167,12 +166,16 @@ def configure_qa_rag_chain(llm, embeddings, embeddings_store_url, username, pass
ORDER BY similarity ASC // so that best answers are the last
""",
)

kg_qa = RetrievalQAWithSourcesChain(
combine_documents_chain=qa_chain,
retriever=kg.as_retriever(search_kwargs={"k": 2}),
reduce_k_below_max_tokens=False,
max_tokens_limit=3375,
kg_qa = (
RunnableParallel(
{
"summaries": kg.as_retriever(search_kwargs={"k": 2}) | format_docs,
"question": RunnablePassthrough(),
}
)
| qa_prompt
| llm
| StrOutputParser()
)
return kg_qa

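Note on the RAG chain refactor in chains.py: RetrievalQAWithSourcesChain and load_qa_with_sources_chain are replaced by an LCEL pipeline in which RunnableParallel fills the {summaries} and {question} slots of qa_prompt. Roughly, the data flow is equivalent to the following sketch (the question string is illustrative, not taken from the diff):

question = "How do I create a constraint in Neo4j?"
retriever = kg.as_retriever(search_kwargs={"k": 2})

# What RunnableParallel hands to the prompt: both branches receive the same input
prompt_inputs = {
    "summaries": format_docs(retriever.invoke(question)),  # retrieved chunks joined with blank lines
    "question": question,                                   # RunnablePassthrough leaves the input untouched
}

# The composed chain then runs qa_prompt | llm | StrOutputParser(), so:
answer = kg_qa.invoke(question)  # returns a string, not a dict with "answer"/"sources"
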
2 changes: 1 addition & 1 deletion docker-compose.yml
@@ -31,7 +31,7 @@ services:

database:
user: neo4j:neo4j
image: neo4j:5.23
image: neo4j:5.26
ports:
- 7687:7687
- 7474:7474
2 changes: 1 addition & 1 deletion env.example
@@ -1,7 +1,7 @@
#*****************************************************************
# LLM and Embedding Model
#*****************************************************************
LLM=llama2 #or any Ollama model tag, gpt-4, gpt-3.5, or claudev2
LLM=llama2 #or any Ollama model tag, gpt-4 (o or turbo), gpt-3.5, or any bedrock model
EMBEDDING_MODEL=sentence_transformer #or google-genai-embedding-001 openai, ollama, or aws

#*****************************************************************
2 changes: 2 additions & 0 deletions front-end/.vscode/extensions.json
@@ -1,3 +1,5 @@

{

"recommendations": ["svelte.svelte-vscode"]
}
4 changes: 1 addition & 3 deletions loader.py
@@ -1,7 +1,7 @@
import os
import requests
from dotenv import load_dotenv
from langchain_community.graphs import Neo4jGraph
from langchain_neo4j import Neo4jGraph
import streamlit as st
from streamlit.logger import get_logger
from chains import load_embedding_model
@@ -15,8 +15,6 @@
password = os.getenv("NEO4J_PASSWORD")
ollama_base_url = os.getenv("OLLAMA_BASE_URL")
embedding_model_name = os.getenv("EMBEDDING_MODEL")
# Remapping for Langchain Neo4j integration
os.environ["NEO4J_URL"] = url

logger = get_logger(__name__)

30 changes: 25 additions & 5 deletions pdf_bot.py
@@ -1,16 +1,19 @@
import os

import streamlit as st
from langchain.chains import RetrievalQA
from PyPDF2 import PdfReader
from langchain.callbacks.base import BaseCallbackHandler
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Neo4jVector
from langchain.prompts import ChatPromptTemplate
from langchain_neo4j import Neo4jVector
from streamlit.logger import get_logger
from chains import (
load_embedding_model,
load_llm,
)
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from utils import format_docs

# load api key lib
from dotenv import load_dotenv
@@ -67,6 +70,14 @@ def main():
)

chunks = text_splitter.split_text(text=text)
qa_prompt = ChatPromptTemplate.from_messages(
[
(
"human",
"Based on the provided summary: {summaries} \n Answer the following question:{question}",
)
]
)

# Store the chunks part in db (vector)
vectorstore = Neo4jVector.from_texts(
@@ -79,16 +90,25 @@ def main():
node_label="PdfBotChunk",
pre_delete_collection=True, # Delete existing PDF data
)
qa = RetrievalQA.from_chain_type(
llm=llm, chain_type="stuff", retriever=vectorstore.as_retriever()
qa = (
RunnableParallel(
{
"summaries": vectorstore.as_retriever(search_kwargs={"k": 2})
| format_docs,
"question": RunnablePassthrough(),
}
)
| qa_prompt
| llm
| StrOutputParser()
)

# Accept user questions/query
query = st.text_input("Ask questions about your PDF file")

if query:
stream_handler = StreamHandler(st.empty())
qa.run(query, callbacks=[stream_handler])
qa.invoke(query, {"callbacks": [stream_handler]})


if __name__ == "__main__":
10 changes: 9 additions & 1 deletion pull_model.Dockerfile
@@ -15,7 +15,15 @@ COPY <<EOF pull_model.clj
(let [llm (get (System/getenv) "LLM")
url (get (System/getenv) "OLLAMA_BASE_URL")]
(println (format "pulling ollama model %s using %s" llm url))
(if (and llm url (not (#{"gpt-4" "gpt-3.5" "claudev2"} llm)))
(if (and llm
url
(not (#{"gpt-4" "gpt-3.5" "claudev2" "gpt-4o" "gpt-4-turbo"} llm))
(not (some #(.startsWith llm %) ["ai21.jamba-instruct-v1:0"
"amazon.titan"
"anthropic.claude"
"cohere.command"
"meta.llama"
"mistral.mi"])))

;; ----------------------------------------------------------------------
;; just call `ollama pull` here - create OLLAMA_HOST from OLLAMA_BASE_URL
13 changes: 7 additions & 6 deletions requirements.txt
@@ -12,9 +12,10 @@ sse-starlette
boto3
streamlit==1.32.1
# missing from the langchain base image?
langchain-openai==0.2.4
langchain-community==0.3.3
langchain-google-genai==2.0.3
langchain-ollama==0.2.0
langchain-huggingface==0.1.1
langchain-aws==0.2.4
langchain-openai==0.3.8
langchain-community==0.3.19
langchain-google-genai==2.0.11
langchain-ollama==0.2.3
langchain-huggingface==0.1.2
langchain-aws==0.2.15
langchain-neo4j==0.4.0
8 changes: 7 additions & 1 deletion utils.py
@@ -32,7 +32,9 @@ def create_vector_index(driver) -> None:
driver.query(index_query)
except: # Already exists
pass
index_query = "CREATE VECTOR INDEX top_answers IF NOT EXISTS FOR (m:Answer) ON m.embedding"
index_query = (
"CREATE VECTOR INDEX top_answers IF NOT EXISTS FOR (m:Answer) ON m.embedding"
)
try:
driver.query(index_query)
except: # Already exists
@@ -52,3 +54,7 @@ def create_constraints(driver):
driver.query(
"CREATE CONSTRAINT tag_name IF NOT EXISTS FOR (t:Tag) REQUIRE (t.name) IS UNIQUE"
)


def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)