diff --git a/docker-compose.yml b/docker-compose.yml index d8a8ca75..f36517ef 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,19 +1,18 @@ services: db: - build: - context: ./db - dockerfile: Dockerfile - volumes: - - postgres_data:/var/lib/postgresql/data/ - environment: - - POSTGRES_USER=balancer - - POSTGRES_PASSWORD=balancer - - POSTGRES_DB=balancer_dev - ports: + image: pgvector/pgvector:pg15 + volumes: + - postgres_data:/var/lib/postgresql/data/ + - ./init-vector-extension.sql:/docker-entrypoint-initdb.d/init-vector-extension.sql + environment: + - POSTGRES_USER=balancer + - POSTGRES_PASSWORD=balancer + - POSTGRES_DB=balancer_dev + ports: - "5433:5432" - networks: - app_net: - ipv4_address: 192.168.0.2 + networks: + app_net: + ipv4_address: 192.168.0.2 pgadmin: container_name: pgadmin4 image: dpage/pgadmin4 @@ -52,13 +51,13 @@ services: args: - IMAGE_NAME=balancer-frontend ports: - - "3000:3000" + - "3000:3000" environment: - - CHOKIDAR_USEPOLLING=true - # - VITE_API_BASE_URL=https://balancertestsite.com/ + - CHOKIDAR_USEPOLLING=true + # - VITE_API_BASE_URL=https://balancertestsite.com/ volumes: - - "./frontend:/usr/src/app:delegated" - - "/usr/src/app/node_modules/" + - "./frontend:/usr/src/app:delegated" + - "/usr/src/app/node_modules/" depends_on: - backend networks: @@ -72,4 +71,4 @@ networks: driver: default config: - subnet: "192.168.0.0/24" - gateway: 192.168.0.1 \ No newline at end of file + gateway: 192.168.0.1 diff --git a/frontend/src/api/apiClient.ts b/frontend/src/api/apiClient.ts index 73b74caf..3e672f1e 100644 --- a/frontend/src/api/apiClient.ts +++ b/frontend/src/api/apiClient.ts @@ -267,6 +267,20 @@ const updateConversationTitle = async ( } }; +// Assistant API functions +const sendAssistantMessage = async (message: string, previousResponseId?: string) => { + try { + const response = await api.post(`/v1/api/assistant`, { + message, + previous_response_id: previousResponseId, + }); + return response.data; + } catch (error) { + console.error("Error(s) during sendAssistantMessage: ", error); + throw error; + } +}; + export { handleSubmitFeedback, handleSendDrugSummary, @@ -279,5 +293,6 @@ export { updateConversationTitle, handleSendDrugSummaryStream, handleSendDrugSummaryStreamLegacy, - fetchRiskDataWithSources + fetchRiskDataWithSources, + sendAssistantMessage }; \ No newline at end of file diff --git a/frontend/src/components/Header/Chat.tsx b/frontend/src/components/Header/Chat.tsx index df11f68b..2f4f92ab 100644 --- a/frontend/src/components/Header/Chat.tsx +++ b/frontend/src/components/Header/Chat.tsx @@ -4,12 +4,9 @@ import "../../components/Header/chat.css"; import { useState, useEffect, useRef } from "react"; import TypingAnimation from "./components/TypingAnimation"; import ErrorMessage from "../ErrorMessage"; -import ConversationList from "./ConversationList"; -import { extractContentFromDOM } from "../../services/domExtraction"; import axios from "axios"; import { FaPlus, - FaMinus, FaTimes, FaComment, FaComments, @@ -19,20 +16,15 @@ import { FaExpandAlt, FaExpandArrowsAlt, } from "react-icons/fa"; -import { - fetchConversations, - continueConversation, - newConversation, - updateConversationTitle, - deleteConversation, -} from "../../api/apiClient"; - -interface ChatLogItem { +import { sendAssistantMessage } from "../../api/apiClient"; + +export interface ChatLogItem { is_user: boolean; content: string; timestamp: string; // EX: 2025-01-16T16:21:14.981090Z } +// Keep interface for backward compatibility with existing imports export interface Conversation { title: string; messages: ChatLogItem[]; @@ -47,14 +39,35 @@ interface ChatDropDownProps { const Chat: React.FC = ({ showChat, setShowChat }) => { const CHATBOT_NAME = "JJ"; const [inputValue, setInputValue] = useState(""); - const [chatLog, setChatLog] = useState([]); // Specify the type as ChatLogItem[] + const [currentMessages, setCurrentMessages] = useState([]); + const [currentResponseId, setCurrentResponseId] = useState< + string | undefined + >(undefined); const [isLoading, setIsLoading] = useState(false); - const [showConversationList, setShowConversationList] = useState(false); - const [conversations, setConversations] = useState([]); - const [activeConversation, setActiveConversation] = - useState(null); const [error, setError] = useState(null); + // Session storage functions for conversation management + const saveConversationToStorage = (messages: ChatLogItem[], responseId?: string) => { + const conversationData = { + messages, + responseId, + timestamp: new Date().toISOString(), + }; + sessionStorage.setItem('currentConversation', JSON.stringify(conversationData)); + }; + + const loadConversationFromStorage = () => { + const stored = sessionStorage.getItem('currentConversation'); + if (stored) { + try { + return JSON.parse(stored); + } catch (error) { + console.error('Error parsing stored conversation:', error); + } + } + return null; + }; + const suggestionPrompts = [ "What are the side effects of Latuda?", "Why is cariprazine better than valproate for a pregnant patient?", @@ -63,27 +76,30 @@ const Chat: React.FC = ({ showChat, setShowChat }) => { "Risks associated with Lithium.", "What medications could cause liver issues?", ]; - const [pageContent, setPageContent] = useState(""); const chatContainerRef = useRef(null); + const [bottom, setBottom] = useState(false); + + // Load conversation from sessionStorage on component mount useEffect(() => { - const observer = new MutationObserver(() => { - const content = extractContentFromDOM(); - setPageContent(content); - }); - - observer.observe(document.body, { - childList: true, - subtree: true, - characterData: true, - }); - - const extractedContent = extractContentFromDOM(); - // console.log(extractedContent); - setPageContent(extractedContent); + const storedConversation = loadConversationFromStorage(); + if (storedConversation) { + setCurrentMessages(storedConversation.messages || []); + setCurrentResponseId(storedConversation.responseId); + } }, []); - const [bottom, setBottom] = useState(false); + + // Save conversation to sessionStorage when component unmounts + useEffect(() => { + return () => { + // Only save if the user hasn't logged out + const isLoggingOut = !localStorage.getItem("access"); + if (!isLoggingOut && currentMessages.length > 0) { + saveConversationToStorage(currentMessages, currentResponseId); + } + }; + }, [currentMessages, currentResponseId]); const handleScroll = (event: React.UIEvent) => { const target = event.target as HTMLElement; @@ -96,34 +112,24 @@ const Chat: React.FC = ({ showChat, setShowChat }) => { const [expandChat, setExpandChat] = useState(false); useEffect(() => { - if (chatContainerRef.current && activeConversation) { + if (chatContainerRef.current) { const chatContainer = chatContainerRef.current; // Use setTimeout to ensure the new message has been rendered setTimeout(() => { chatContainer.scrollTop = chatContainer.scrollHeight; setBottom( chatContainer.scrollHeight - chatContainer.scrollTop === - chatContainer.clientHeight + chatContainer.clientHeight, ); }, 0); } - }, [activeConversation?.messages]); - - const loadConversations = async () => { - try { - const data = await fetchConversations(); - setConversations(data); - // setLoading(false); - } catch (error) { - console.error("Error loading conversations: ", error); - } - }; + }, [currentMessages]); const scrollToBottom = (element: HTMLElement) => element.scroll({ top: element.scrollHeight, behavior: "smooth" }); const handleScrollDown = ( - event: React.MouseEvent + event: React.MouseEvent, ) => { event.preventDefault(); const element = document.getElementById("inside_chat"); @@ -137,72 +143,52 @@ const Chat: React.FC = ({ showChat, setShowChat }) => { event: | React.FormEvent | React.MouseEvent, - suggestion?: string + suggestion?: string, ) => { event.preventDefault(); + const messageContent = (inputValue || suggestion) ?? ""; + if (!messageContent.trim()) return; + const newMessage = { - content: (inputValue || suggestion) ?? "", + content: messageContent, is_user: true, timestamp: new Date().toISOString(), }; - const newMessages = [...chatLog, newMessage]; - - setChatLog(newMessages); - - // sendMessage(newMessages); try { - let conversation = activeConversation; - let conversationCreated = false; - - // Create a new conversation if none exists - if (!conversation) { - conversation = await newConversation(); - setActiveConversation(conversation); - setShowConversationList(false); - conversationCreated = true; - } - - // Update the conversation with the new user message - const updatedMessages = [...conversation.messages, newMessage]; - setActiveConversation({ - ...conversation, - title: "Asking JJ...", - messages: updatedMessages, - }); - setIsLoading(true); + setError(null); - // Continue the conversation and update with the bot's response - const data = await continueConversation( - conversation.id, - newMessage.content, - pageContent + // Add user message to current conversation + const updatedMessages = [...currentMessages, newMessage]; + setCurrentMessages(updatedMessages); + + // Save user message immediately to prevent loss + saveConversationToStorage(updatedMessages, currentResponseId); + + // Call assistant API with previous response ID for continuity + const data = await sendAssistantMessage( + messageContent, + currentResponseId, ); - // Update the ConversationList component after previous function creates a title - if (conversationCreated) loadConversations(); // Note: no 'await' so this can occur in the background - - setActiveConversation((prevConversation: any) => { - if (!prevConversation) return null; - - return { - ...prevConversation, - messages: [ - ...prevConversation.messages, - { - is_user: false, - content: data.response, - timestamp: new Date().toISOString(), - }, - ], - title: data.title, - }; - }); - setError(null); + // Create assistant response message + const assistantMessage = { + content: data.response_output_text, + is_user: false, + timestamp: new Date().toISOString(), + }; + + // Update messages and store new response ID for next message + const finalMessages = [...updatedMessages, assistantMessage]; + setCurrentMessages(finalMessages); + setCurrentResponseId(data.final_response_id); + + // Save conversation to sessionStorage + saveConversationToStorage(finalMessages, data.final_response_id); } catch (error) { - console.error("Error(s) handling conversation:", error); + console.error("Error handling message:", error); let errorMessage = "Error submitting message"; if (error instanceof Error) { errorMessage = error.message; @@ -222,25 +208,8 @@ const Chat: React.FC = ({ showChat, setShowChat }) => { } }; - const handleSelectConversation = (id: Conversation["id"]) => { - const selectedConversation = conversations.find( - (conversation: any) => conversation.id === id - ); - - if (selectedConversation) { - setActiveConversation(selectedConversation); - setShowConversationList(false); - } - }; - - const handleNewConversation = () => { - setActiveConversation(null); - setShowConversationList(false); - }; - useEffect(() => { if (showChat) { - loadConversations(); const resizeObserver = new ResizeObserver((entries) => { if (!entries || entries.length === 0) return; @@ -278,53 +247,35 @@ const Chat: React.FC = ({ showChat, setShowChat }) => { className=" mx-auto flex h-full flex-col overflow-auto rounded " >
- - -
- {activeConversation !== null && !showConversationList ? ( - activeConversation.title - ) : ( - <> - - Ask {CHATBOT_NAME} - - )} -
+
+ + Ask {CHATBOT_NAME}
+ + + - {showConversationList ? ( -
- -
- ) : ( -
- {activeConversation === null || - activeConversation.messages.length === 0 ? ( - <> -
-
Hi there, I'm {CHATBOT_NAME}!
-

- You can ask me all your bipolar disorder treatment - questions. -

- - Learn more about my sources. - -
-
-
- -
Explore a medication
-
-
    - {suggestionPrompts.map((suggestion, index) => ( -
  • - -
  • - ))} -
+
+ {currentMessages.length === 0 ? ( + <> +
+
Hi there, I'm {CHATBOT_NAME}!
+

+ You can ask me questions about your uploaded documents. + I'll search through them to provide accurate, cited + answers. +

+ + Learn more about my sources. + +
+
+
+ ⚠️ IMPORTANT NOTICE +
+

+ Balancer is NOT configured for use with Protected Health Information (PHI) as defined under HIPAA. + You must NOT enter any patient-identifiable information including names, addresses, dates of birth, + medical record numbers, or any other identifying information. Your queries may be processed by + third-party AI services that retain data for up to 30 days for abuse monitoring. By using Balancer, + you certify that you understand these restrictions and will not enter any PHI. +

+
+
+
+ +
Explore a medication
-
-
- -
Refresh your memory
-
-
    - {refreshPrompts.map((suggestion, index) => ( -
  • - -
  • - ))} -
+
    + {suggestionPrompts.map((suggestion, index) => ( +
  • + +
  • + ))} +
+
+
+
+ +
Refresh your memory
- - ) : ( - activeConversation.messages - .slice() - .sort( - (a, b) => - new Date(a.timestamp).getTime() - - new Date(b.timestamp).getTime() - ) - .map((message, index) => ( -
-
- {message.is_user ? ( - - ) : ( - - )} -
-
-
+                        {refreshPrompts.map((suggestion, index) => (
+                          
  • +
  • -
    + {suggestion} + + + ))} + +
    + + ) : ( + currentMessages + .slice() + .sort( + (a, b) => + new Date(a.timestamp).getTime() - + new Date(b.timestamp).getTime(), + ) + .map((message, index) => ( +
    +
    + +
    +
    +
    +                            {message.content}
    +                          
    - )) - )} - {isLoading && ( -
    -
    -
    + )) + )} + {isLoading && ( +
    +
    +
    - )} - {error && } -
    - )} +
    + )} + {error && } +
    = ({ showChat, setShowChat }) => {
    ) : ( -
    setShowChat(true)} - className="chat_button no-print" - > +
    setShowChat(true)} className="chat_button no-print">
    )} diff --git a/frontend/src/services/actions/auth.tsx b/frontend/src/services/actions/auth.tsx index 3a29bc38..2573c223 100644 --- a/frontend/src/services/actions/auth.tsx +++ b/frontend/src/services/actions/auth.tsx @@ -169,6 +169,9 @@ export const login = }; export const logout = () => async (dispatch: AppDispatch) => { + // Clear chat conversation data on logout for security + sessionStorage.removeItem('currentConversation'); + dispatch({ type: LOGOUT, }); diff --git a/server/api/services/conversions_services.py b/server/api/services/conversions_services.py index d134ff49..71931f17 100644 --- a/server/api/services/conversions_services.py +++ b/server/api/services/conversions_services.py @@ -2,6 +2,23 @@ def convert_uuids(data): + """ + Recursively convert UUID objects to strings in nested data structures. + + Traverses dictionaries, lists, and other data structures to find UUID objects + and converts them to their string representation for serialization. + + Parameters + ---------- + data : any + The data structure to process (dict, list, UUID, or any other type) + + Returns + ------- + any + The data structure with all UUID objects converted to strings. + Structure and types are preserved except for UUID -> str conversion. + """ if isinstance(data, dict): return {key: convert_uuids(value) for key, value in data.items()} elif isinstance(data, list): diff --git a/server/api/services/embedding_services.py b/server/api/services/embedding_services.py index 5aacab38..6fd34d35 100644 --- a/server/api/services/embedding_services.py +++ b/server/api/services/embedding_services.py @@ -1,29 +1,63 @@ # services/embedding_services.py + +from pgvector.django import L2Distance + from .sentencetTransformer_model import TransformerModel + # Adjust import path as needed from ..models.model_embeddings import Embeddings -from pgvector.django import L2Distance -def get_closest_embeddings(user, message_data, document_name=None, guid=None, num_results=10): +def get_closest_embeddings( + user, message_data, document_name=None, guid=None, num_results=10 +): + """ + Find the closest embeddings to a given message for a specific user. + + Parameters + ---------- + user : User + The user whose uploaded documents will be searched + message_data : str + The input message to find similar embeddings for + document_name : str, optional + Filter results to a specific document name + guid : str, optional + Filter results to a specific document GUID (takes precedence over document_name) + num_results : int, default 10 + Maximum number of results to return + + Returns + ------- + list[dict] + List of dictionaries containing embedding results with keys: + - name: document name + - text: embedded text content + - page_number: page number in source document + - chunk_number: chunk number within the document + - distance: L2 distance from query embedding + - file_id: GUID of the source file + """ + # transformerModel = TransformerModel.get_instance().model embedding_message = transformerModel.encode(message_data) # Start building the query based on the message's embedding - closest_embeddings_query = Embeddings.objects.filter( - upload_file__uploaded_by=user - ).annotate( - distance=L2Distance( - 'embedding_sentence_transformers', embedding_message) - ).order_by('distance') + closest_embeddings_query = ( + Embeddings.objects.filter(upload_file__uploaded_by=user) + .annotate( + distance=L2Distance("embedding_sentence_transformers", embedding_message) + ) + .order_by("distance") + ) # Filter by GUID if provided, otherwise filter by document name if provided if guid: closest_embeddings_query = closest_embeddings_query.filter( - upload_file__guid=guid) + upload_file__guid=guid + ) elif document_name: - closest_embeddings_query = closest_embeddings_query.filter( - name=document_name) + closest_embeddings_query = closest_embeddings_query.filter(name=document_name) # Slice the results to limit to num_results closest_embeddings_query = closest_embeddings_query[:num_results] diff --git a/server/api/views/assistant/urls.py b/server/api/views/assistant/urls.py new file mode 100644 index 00000000..4c68f952 --- /dev/null +++ b/server/api/views/assistant/urls.py @@ -0,0 +1,5 @@ +from django.urls import path + +from .views import Assistant + +urlpatterns = [path("v1/api/assistant", Assistant.as_view(), name="assistant")] diff --git a/server/api/views/assistant/views.py b/server/api/views/assistant/views.py new file mode 100644 index 00000000..ca65f335 --- /dev/null +++ b/server/api/views/assistant/views.py @@ -0,0 +1,324 @@ +import os +import json +import logging +import time +from typing import Callable + +from rest_framework.views import APIView +from rest_framework.response import Response +from rest_framework import status +from rest_framework.permissions import IsAuthenticated +from django.utils.decorators import method_decorator +from django.views.decorators.csrf import csrf_exempt + +from openai import OpenAI + +from ...services.embedding_services import get_closest_embeddings +from ...services.conversions_services import convert_uuids + +# Configure logging +logger = logging.getLogger(__name__) + +GPT_5_NANO_PRICING_DOLLARS_PER_MILLION_TOKENS = {"input": 0.05, "output": 0.40} + + +def calculate_cost_metrics(token_usage: dict, pricing: dict) -> dict: + """ + Calculate cost metrics based on token usage and pricing + + Args: + token_usage: Dictionary containing input_tokens and output_tokens + pricing: Dictionary containing input and output pricing per million tokens + + Returns: + Dictionary containing input_cost, output_cost, and total_cost in USD + """ + TOKENS_PER_MILLION = 1_000_000 + + # Pricing is in dollars per million tokens + input_cost_dollars = (pricing["input"] / TOKENS_PER_MILLION) * token_usage.get( + "input_tokens", 0 + ) + output_cost_dollars = (pricing["output"] / TOKENS_PER_MILLION) * token_usage.get( + "output_tokens", 0 + ) + total_cost_dollars = input_cost_dollars + output_cost_dollars + + return { + "input_cost": input_cost_dollars, + "output_cost": output_cost_dollars, + "total_cost": total_cost_dollars, + } + + +# Open AI Cookbook: Handling Function Calls with Reasoning Models +# https://cookbook.openai.com/examples/reasoning_function_calls +def invoke_functions_from_response( + response, tool_mapping: dict[str, Callable] +) -> list[dict]: + """Extract all function calls from the response, look up the corresponding tool function(s) and execute them. + (This would be a good place to handle asynchroneous tool calls, or ones that take a while to execute.) + This returns a list of messages to be added to the conversation history. + + Parameters + ---------- + response : OpenAI Response + The response object from OpenAI containing output items that may include function calls + tool_mapping : dict[str, Callable] + A dictionary mapping function names (as strings) to their corresponding Python functions. + Keys should match the function names defined in the tools schema. + + Returns + ------- + list[dict] + List of function call output messages formatted for the OpenAI conversation. + Each message contains: + - type: "function_call_output" + - call_id: The unique identifier for the function call + - output: The result returned by the executed function (string or error message) + """ + intermediate_messages = [] + for response_item in response.output: + if response_item.type == "function_call": + target_tool = tool_mapping.get(response_item.name) + if target_tool: + try: + arguments = json.loads(response_item.arguments) + logger.info( + f"Invoking tool: {response_item.name} with arguments: {arguments}" + ) + tool_output = target_tool(**arguments) + logger.info(f"Tool {response_item.name} completed successfully") + except Exception as e: + msg = f"Error executing function call: {response_item.name}: {e}" + tool_output = msg + logger.error(msg, exc_info=True) + else: + msg = f"ERROR - No tool registered for function call: {response_item.name}" + tool_output = msg + logger.error(msg) + intermediate_messages.append( + { + "type": "function_call_output", + "call_id": response_item.call_id, + "output": tool_output, + } + ) + elif response_item.type == "reasoning": + logger.info(f"Reasoning step: {response_item.summary}") + return intermediate_messages + + +@method_decorator(csrf_exempt, name="dispatch") +class Assistant(APIView): + permission_classes = [IsAuthenticated] + + def post(self, request): + try: + user = request.user + + client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) + + TOOL_DESCRIPTION = """ + Search the user's uploaded documents for information relevant to answering their question. + Call this function when you need to find specific information from the user's documents + to provide an accurate, citation-backed response. Always search before answering questions + about document content. + """ + + TOOL_PROPERTY_DESCRIPTION = """ + A specific search query to find relevant information in the user's documents. + Use keywords, phrases, or questions related to what the user is asking about. + Be specific rather than generic - use terms that would appear in the relevant documents. + """ + + tools = [ + { + "type": "function", + "name": "search_documents", + "description": TOOL_DESCRIPTION, + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": TOOL_PROPERTY_DESCRIPTION, + } + }, + "required": ["query"], + }, + } + ] + + def search_documents(query: str, user=user) -> str: + """ + Search through user's uploaded documents using semantic similarity. + + This function performs vector similarity search against the user's document corpus + and returns formatted results with context information for the LLM to use. + + Parameters + ---------- + query : str + The search query string + user : User + The authenticated user whose documents to search + + Returns + ------- + str + Formatted search results containing document excerpts with metadata + + Raises + ------ + Exception + If embedding search fails + """ + + try: + embeddings_results = get_closest_embeddings( + user=user, message_data=query.strip() + ) + embeddings_results = convert_uuids(embeddings_results) + + if not embeddings_results: + return "No relevant documents found for your query. Please try different search terms or upload documents first." + + # Format results with clear structure and metadata + prompt_texts = [ + f"[Document {i + 1} - File: {obj['file_id']}, Name: {obj['name']}, Page: {obj['page_number']}, Chunk: {obj['chunk_number']}, Similarity: {1 - obj['distance']:.3f}]\n{obj['text']}\n[End Document {i + 1}]" + for i, obj in enumerate(embeddings_results) + ] + + return "\n\n".join(prompt_texts) + + except Exception as e: + return f"Error searching documents: {str(e)}. Please try again if the issue persists." + + INSTRUCTIONS = """ + You are an AI assistant that helps users find and understand information about bipolar disorder + from their uploaded bipolar disorder research documents using semantic search. + + SEMANTIC SEARCH STRATEGY: + - Always perform semantic search using the search_documents function when users ask questions + - Use conceptually related terms and synonyms, not just exact keyword matches + - Search for the meaning and context of the user's question, not just literal words + - Consider medical terminology, lay terms, and related conditions when searching + + FUNCTION USAGE: + - When a user asks about information that might be in their documents ALWAYS use the search_documents function first + - Perform semantic searches using concepts, symptoms, treatments, and related terms from the user's question + - Only provide answers based on information found through document searches + + RESPONSE FORMAT: + After gathering information through semantic searches, provide responses that: + 1. Answer the user's question directly using only the found information + 2. Structure responses with clear sections and paragraphs + 3. Include citations using this exact format: ***[Name {name}, Page {page_number}]*** + 4. Only cite information that directly supports your statements + + If no relevant information is found in the documents, clearly state that the information is not available in the uploaded documents. + """ + + MODEL_DEFAULTS = { + "instructions": INSTRUCTIONS, + "model": "gpt-5-nano", # 400,000 token context window + "reasoning": {"effort": "low", "summary": "auto"}, + "tools": tools, + } + + # We fetch a response and then kick off a loop to handle the response + + message = request.data.get("message", None) + previous_response_id = request.data.get("previous_response_id", None) + + # Track total duration and cost metrics + start_time = time.time() + total_token_usage = {"input_tokens": 0, "output_tokens": 0} + + if not previous_response_id: + response = client.responses.create( + input=[ + {"type": "message", "role": "user", "content": str(message)} + ], + **MODEL_DEFAULTS, + ) + else: + response = client.responses.create( + input=[ + {"type": "message", "role": "user", "content": str(message)} + ], + previous_response_id=str(previous_response_id), + **MODEL_DEFAULTS, + ) + + # Accumulate token usage from initial response + if hasattr(response, "usage"): + total_token_usage["input_tokens"] += getattr( + response.usage, "input_tokens", 0 + ) + total_token_usage["output_tokens"] += getattr( + response.usage, "output_tokens", 0 + ) + + # Open AI Cookbook: Handling Function Calls with Reasoning Models + # https://cookbook.openai.com/examples/reasoning_function_calls + while True: + # Mapping of the tool names we tell the model about and the functions that implement them + function_responses = invoke_functions_from_response( + response, tool_mapping={"search_documents": search_documents} + ) + if len(function_responses) == 0: # We're done reasoning + logger.info("Reasoning completed") + final_response_output_text = response.output_text + final_response_id = response.id + logger.info(f"Final response: {final_response_output_text}") + break + else: + logger.info("More reasoning required, continuing...") + response = client.responses.create( + input=function_responses, + previous_response_id=response.id, + **MODEL_DEFAULTS, + ) + # Accumulate token usage from reasoning iterations + if hasattr(response, "usage"): + total_token_usage["input_tokens"] += getattr( + response.usage, "input_tokens", 0 + ) + total_token_usage["output_tokens"] += getattr( + response.usage, "output_tokens", 0 + ) + + # Calculate total duration and cost metrics + total_duration = time.time() - start_time + cost_metrics = calculate_cost_metrics( + total_token_usage, GPT_5_NANO_PRICING_DOLLARS_PER_MILLION_TOKENS + ) + + # Log cost and duration metrics + logger.info( + f"Request completed: " + f"Duration: {total_duration:.2f}s, " + f"Input tokens: {total_token_usage['input_tokens']}, " + f"Output tokens: {total_token_usage['output_tokens']}, " + f"Total cost: ${cost_metrics['total_cost']:.6f}" + ) + + return Response( + { + "response_output_text": final_response_output_text, + "final_response_id": final_response_id, + }, + status=status.HTTP_200_OK, + ) + + except Exception as e: + logger.error( + f"Unexpected error in Assistant view for user {request.user.id if hasattr(request, 'user') else 'unknown'}: {e}", + exc_info=True, + ) + return Response( + {"error": "An unexpected error occurred. Please try again later."}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) diff --git a/server/balancer_backend/settings.py b/server/balancer_backend/settings.py index 91a53a8b..175ca6ab 100644 --- a/server/balancer_backend/settings.py +++ b/server/balancer_backend/settings.py @@ -29,57 +29,56 @@ # Fetching the value from the environment and splitting to list if necessary. # Fallback to '*' if the environment variable is not set. -ALLOWED_HOSTS = os.environ.get('DJANGO_ALLOWED_HOSTS', '*').split() +ALLOWED_HOSTS = os.environ.get("DJANGO_ALLOWED_HOSTS", "*").split() # If the environment variable contains '*', the split method would create a list with an empty string. # So you need to check for this case and adjust accordingly. -if ALLOWED_HOSTS == ['*'] or ALLOWED_HOSTS == ['']: - ALLOWED_HOSTS = ['*'] +if ALLOWED_HOSTS == ["*"] or ALLOWED_HOSTS == [""]: + ALLOWED_HOSTS = ["*"] # Application definition INSTALLED_APPS = [ - 'django.contrib.admin', - 'django.contrib.auth', - 'django.contrib.contenttypes', - 'django.contrib.sessions', - 'django.contrib.messages', - 'django.contrib.staticfiles', - 'balancer_backend', - 'api', - 'corsheaders', - 'rest_framework', - 'djoser', + "django.contrib.admin", + "django.contrib.auth", + "django.contrib.contenttypes", + "django.contrib.sessions", + "django.contrib.messages", + "django.contrib.staticfiles", + "balancer_backend", + "api", + "corsheaders", + "rest_framework", + "djoser", ] MIDDLEWARE = [ - 'django.middleware.security.SecurityMiddleware', - 'django.contrib.sessions.middleware.SessionMiddleware', - 'django.middleware.common.CommonMiddleware', - 'django.middleware.csrf.CsrfViewMiddleware', - 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'django.contrib.messages.middleware.MessageMiddleware', - 'django.middleware.clickjacking.XFrameOptionsMiddleware', - 'corsheaders.middleware.CorsMiddleware', - + "django.middleware.security.SecurityMiddleware", + "django.contrib.sessions.middleware.SessionMiddleware", + "django.middleware.common.CommonMiddleware", + "django.middleware.csrf.CsrfViewMiddleware", + "django.contrib.auth.middleware.AuthenticationMiddleware", + "django.contrib.messages.middleware.MessageMiddleware", + "django.middleware.clickjacking.XFrameOptionsMiddleware", + "corsheaders.middleware.CorsMiddleware", ] -ROOT_URLCONF = 'balancer_backend.urls' +ROOT_URLCONF = "balancer_backend.urls" CORS_ALLOW_ALL_ORIGINS = True TEMPLATES = [ { - 'BACKEND': 'django.template.backends.django.DjangoTemplates', - 'DIRS': [os.path.join(BASE_DIR, 'build')], - 'APP_DIRS': True, - 'OPTIONS': { - 'context_processors': [ - 'django.template.context_processors.debug', - 'django.template.context_processors.request', - 'django.contrib.auth.context_processors.auth', - 'django.contrib.messages.context_processors.messages', + "BACKEND": "django.template.backends.django.DjangoTemplates", + "DIRS": [os.path.join(BASE_DIR, "build")], + "APP_DIRS": True, + "OPTIONS": { + "context_processors": [ + "django.template.context_processors.debug", + "django.template.context_processors.request", + "django.contrib.auth.context_processors.auth", + "django.contrib.messages.context_processors.messages", ], }, }, @@ -89,7 +88,7 @@ # Change this to your desired URL LOGIN_REDIRECT_URL = os.environ.get("LOGIN_REDIRECT_URL") -WSGI_APPLICATION = 'balancer_backend.wsgi.application' +WSGI_APPLICATION = "balancer_backend.wsgi.application" # Database @@ -106,8 +105,8 @@ } } -EMAIL_BACKEND = 'django.core.mail.backends.smtp.EmailBackend' -EMAIL_HOST = 'smtp.gmail.com' +EMAIL_BACKEND = "django.core.mail.backends.smtp.EmailBackend" +EMAIL_HOST = "smtp.gmail.com" EMAIL_PORT = 587 EMAIL_HOST_USER = os.environ.get("EMAIL_HOST_USER", "") EMAIL_HOST_PASSWORD = os.environ.get("EMAIL_HOST_PASSWORD", "") @@ -119,25 +118,25 @@ AUTH_PASSWORD_VALIDATORS = [ { - 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', + "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator", }, { - 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', + "NAME": "django.contrib.auth.password_validation.MinimumLengthValidator", }, { - 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', + "NAME": "django.contrib.auth.password_validation.CommonPasswordValidator", }, { - 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', + "NAME": "django.contrib.auth.password_validation.NumericPasswordValidator", }, ] # Internationalization # https://docs.djangoproject.com/en/4.2/topics/i18n/ -LANGUAGE_CODE = 'en-us' +LANGUAGE_CODE = "en-us" -TIME_ZONE = 'UTC' +TIME_ZONE = "UTC" USE_I18N = True @@ -147,64 +146,89 @@ # Static files (CSS, JavaScript, Images) # https://docs.djangoproject.com/en/4.2/howto/static-files/ -STATIC_URL = '/static/' +STATIC_URL = "/static/" STATICFILES_DIRS = [ - os.path.join(BASE_DIR, 'build/static'), + os.path.join(BASE_DIR, "build/static"), ] -STATIC_ROOT = os.path.join(BASE_DIR, 'static') +STATIC_ROOT = os.path.join(BASE_DIR, "static") AUTHENTICATION_BACKENDS = [ - 'django.contrib.auth.backends.ModelBackend', + "django.contrib.auth.backends.ModelBackend", ] REST_FRAMEWORK = { - 'DEFAULT_PERMISSION_CLASSES': [ - 'rest_framework.permissions.IsAuthenticated' - ], - 'DEFAULT_AUTHENTICATION_CLASSES': ( - 'rest_framework_simplejwt.authentication.JWTAuthentication', + "DEFAULT_PERMISSION_CLASSES": ["rest_framework.permissions.IsAuthenticated"], + "DEFAULT_AUTHENTICATION_CLASSES": ( + "rest_framework_simplejwt.authentication.JWTAuthentication", ), } SIMPLE_JWT = { - 'AUTH_HEADER_TYPES': ('JWT',), - 'ACCESS_TOKEN_LIFETIME': timedelta(minutes=60), - 'REFRESH_TOKEN_LIFETIME': timedelta(days=1), - 'TOKEN_OBTAIN_SERIALIZER': 'api.models.TokenObtainPairSerializer.MyTokenObtainPairSerializer', - 'AUTH_TOKEN_CLASSES': ( - 'rest_framework_simplejwt.tokens.AccessToken', - ) + "AUTH_HEADER_TYPES": ("JWT",), + "ACCESS_TOKEN_LIFETIME": timedelta(minutes=60), + "REFRESH_TOKEN_LIFETIME": timedelta(days=1), + "TOKEN_OBTAIN_SERIALIZER": "api.models.TokenObtainPairSerializer.MyTokenObtainPairSerializer", + "AUTH_TOKEN_CLASSES": ("rest_framework_simplejwt.tokens.AccessToken",), } DJOSER = { - 'LOGIN_FIELD': 'email', - 'USER_CREATE_PASSWORD_RETYPE': True, - 'USERNAME_CHANGED_EMAIL_CONFIRMATION': True, - 'PASSWORD_CHANGED_EMAIL_CONFIRMATION': True, - 'SEND_CONFIRMATION_EMAIL': True, - 'SET_USERNAME_RETYPE': True, - 'SET_PASSWORD_RETYPE': True, - 'PASSWORD_RESET_CONFIRM_URL': 'password/reset/confirm/{uid}/{token}', - 'USERNAME_RESET_CONFIRM_URL': 'email/reset/confirm/{uid}/{token}', - 'ACTIVATION_URL': 'activate/{uid}/{token}', - 'SEND_ACTIVATION_EMAIL': True, - 'SOCIAL_AUTH_TOKEN_STRATEGY': 'djoser.social.token.jwt.TokenStrategy', - 'SOCIAL_AUTH_ALLOWED_REDIRECT_URIS': ['http://localhost:8000/google', 'http://localhost:8000/facebook'], - 'SERIALIZERS': { - 'user_create': 'api.models.serializers.UserCreateSerializer', - 'user': 'api.models.serializers.UserCreateSerializer', - 'current_user': 'api.models.serializers.UserCreateSerializer', - 'user_delete': 'djoser.serializers.UserDeleteSerializer', - } + "LOGIN_FIELD": "email", + "USER_CREATE_PASSWORD_RETYPE": True, + "USERNAME_CHANGED_EMAIL_CONFIRMATION": True, + "PASSWORD_CHANGED_EMAIL_CONFIRMATION": True, + "SEND_CONFIRMATION_EMAIL": True, + "SET_USERNAME_RETYPE": True, + "SET_PASSWORD_RETYPE": True, + "PASSWORD_RESET_CONFIRM_URL": "password/reset/confirm/{uid}/{token}", + "USERNAME_RESET_CONFIRM_URL": "email/reset/confirm/{uid}/{token}", + "ACTIVATION_URL": "activate/{uid}/{token}", + "SEND_ACTIVATION_EMAIL": True, + "SOCIAL_AUTH_TOKEN_STRATEGY": "djoser.social.token.jwt.TokenStrategy", + "SOCIAL_AUTH_ALLOWED_REDIRECT_URIS": [ + "http://localhost:8000/google", + "http://localhost:8000/facebook", + ], + "SERIALIZERS": { + "user_create": "api.models.serializers.UserCreateSerializer", + "user": "api.models.serializers.UserCreateSerializer", + "current_user": "api.models.serializers.UserCreateSerializer", + "user_delete": "djoser.serializers.UserDeleteSerializer", + }, } # Default primary key field type # https://docs.djangoproject.com/en/4.2/ref/settings/#default-auto-field -DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField' +DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField" + +AUTH_USER_MODEL = "api.UserAccount" -AUTH_USER_MODEL = 'api.UserAccount' +# Logging configuration +LOGGING = { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "verbose": { + "format": "{levelname} {asctime} {module} {process:d} {thread:d} {message}", + "style": "{", + }, + "simple": { + "format": "{levelname} {message}", + "style": "{", + }, + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "formatter": "verbose", + }, + }, + "root": { + "handlers": ["console"], + "level": "INFO", + }, +} diff --git a/server/balancer_backend/urls.py b/server/balancer_backend/urls.py index 1c5bad8b..56f307e4 100644 --- a/server/balancer_backend/urls.py +++ b/server/balancer_backend/urls.py @@ -1,6 +1,8 @@ from django.contrib import admin # Import Django's admin interface module + # Import functions for URL routing and including other URL configs from django.urls import path, include, re_path + # Import TemplateView for rendering templates from django.views.generic import TemplateView import importlib # Import the importlib module for dynamic module importing @@ -10,25 +12,37 @@ # Map 'admin/' URL to the Django admin interface path("admin/", admin.site.urls), # Include Djoser's URL patterns under 'auth/' for basic auth - path('auth/', include('djoser.urls')), + path("auth/", include("djoser.urls")), # Include Djoser's JWT auth URL patterns under 'auth/' - path('auth/', include('djoser.urls.jwt')), + path("auth/", include("djoser.urls.jwt")), # Include Djoser's social auth URL patterns under 'auth/' - path('auth/', include('djoser.social.urls')), + path("auth/", include("djoser.social.urls")), ] # List of application names for which URL patterns will be dynamically added -urls = ['conversations', 'feedback', 'listMeds', 'risk', - 'uploadFile', 'ai_promptStorage', 'ai_settings', 'embeddings', 'medRules', 'text_extraction'] +urls = [ + "conversations", + "feedback", + "listMeds", + "risk", + "uploadFile", + "ai_promptStorage", + "ai_settings", + "embeddings", + "medRules", + "text_extraction", + "assistant", +] # Loop through each application name and dynamically import and add its URL patterns for url in urls: # Dynamically import the URL module for each app - url_module = importlib.import_module(f'api.views.{url}.urls') + url_module = importlib.import_module(f"api.views.{url}.urls") # Append the URL patterns from each imported module - urlpatterns += getattr(url_module, 'urlpatterns', []) + urlpatterns += getattr(url_module, "urlpatterns", []) # Add a catch-all URL pattern for handling SPA (Single Page Application) routing # Serve 'index.html' for any unmatched URL urlpatterns += [ - re_path(r'^.*$', TemplateView.as_view(template_name='index.html')),] + re_path(r"^.*$", TemplateView.as_view(template_name="index.html")), +]