Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions safe_infer_chatbot_app/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
streamlit>=1.28.0
requests>=2.31.0
aiohttp>=3.8.0
openai>=1.60.0
180 changes: 39 additions & 141 deletions safe_infer_chatbot_app/safe_infer_chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import json
from typing import Dict, Any
import time
from openai import OpenAI



# Page configuration
st.set_page_config(
page_title="SafeInfer LLM Chatbot",
page_title="Finance Ops Chatbot",
page_icon="🛡️",
layout="wide",
initial_sidebar_state="expanded"
Expand Down Expand Up @@ -59,14 +60,17 @@
</style>
""", unsafe_allow_html=True)

from utils import get_available_models
from utils import convert_to_dict, get_available_models

# API Configuration
API_KEY = os.getenv("PEBBLO_API_KEY", "")
API_BASE_URL = os.getenv("PROXIMA_HOST", "http://localhost")
RESPONSE_API_ENDPOINT = f"{API_BASE_URL}/safe_infer/llm/v1/responses"
USER_EMAIL = os.getenv("USER_EMAIL", "User")
USER_TEAM = os.getenv("USER_TEAM", "Finance Ops")
RESPONSE_API_ENDPOINT = f"{API_BASE_URL}/safe_infer/llm/v1/"
LLM_PROVIDER_API_ENDPOINT = f"{API_BASE_URL}/api/llm/provider"
AVAILABLE_MODELS, DEFAULT_MODEL = get_available_models()
AVAILABLE_MODELS.append("None")

# Initialize session state
if 'chat_history' not in st.session_state:
Expand All @@ -89,75 +93,21 @@ def test_api_connection() -> Dict[str, Any]:
except Exception as e:
return {"status": "error", "message": f"Error: {str(e)}"}

def call_safe_infer_api(message: str, model: str, api_key: str = "") -> Dict[str, Any]:
"""Call the SafeInfer API"""
headers = {
"Content-Type": "application/json"
}

if api_key:
headers["Authorization"] = f"Bearer {api_key}"

payload = {
"model": model,
"input": message
}

def call_open_ai(message: str, model: str, api_key: str = "") -> Dict[str, Any]:
try:
response = requests.post(
RESPONSE_API_ENDPOINT,
json=payload,
headers=headers,
timeout=30
client = OpenAI(
base_url=RESPONSE_API_ENDPOINT,
api_key=api_key
)

if response.status_code == 200:
return {"status": "success", "data": response.json()}
else:
return {
"status": "error",
"message": f"API Error {response.status_code}: {response.text}"
}
except requests.exceptions.Timeout:
return {"status": "error", "message": "Request timed out"}
except requests.exceptions.ConnectionError:
return {"status": "error", "message": "Cannot connect to API"}
except Exception as e:
return {"status": "error", "message": f"Error: {str(e)}"}

def extract_response_content(api_response: Dict[str, Any]) -> str:
"""Extract the response content from the API response"""
try:
# Handle different response formats
if 'response' in api_response:
response_data = api_response['response']
if isinstance(response_data, dict):
if 'message' in response_data:
if isinstance(response_data['message'], str):
return response_data['message']
elif isinstance(response_data['message'], dict) and 'content' in response_data['message']:
return response_data['message']['content']
elif 'content' in response_data:
return response_data['content']
elif isinstance(response_data, str):
return response_data

# Check for direct content
elif 'content' in api_response:
return api_response['content']

# Check for message field
elif 'message' in api_response:
if isinstance(api_response['message'], str):
return api_response['message']
elif isinstance(api_response['message'], dict) and 'content' in api_response['message']:
return api_response['message']['content']

# If none of the above, return the full response as JSON
return json.dumps(api_response, indent=2)
response = client.responses.create(
model=model,
input=message
)

return {"status": "success", "data": convert_to_dict(response)}
except Exception as e:
return f"Error parsing response: {str(e)}"
return {"status": "error", "message": f"Error: {str(e)}"}

def display_chat_message(role: str, content: str, model: str = "", timestamp: str = ""):
"""Display a chat message with proper styling"""
Expand All @@ -182,8 +132,8 @@ def display_chat_message(role: str, content: str, model: str = "", timestamp: st
# Main header
st.markdown("""
<div class="main-header">
<h1>🛡️ SafeInfer LLM Chatbot</h1>
<p>Secure and intelligent conversations powered by SafeInfer API</p>
<h1>🛡️ Finance Ops Chatbot</h1>
<p>Helpful assistant for Finance Ops team</p>
</div>
""", unsafe_allow_html=True)

Expand Down Expand Up @@ -237,6 +187,15 @@ def display_chat_message(role: str, content: str, model: str = "", timestamp: st
st.metric("Messages", len(st.session_state.chat_history))
st.metric("Current Model", st.session_state.selected_model)

# Welcome message
st.markdown(f"""
<div class="chat-message bot-message">
<strong>🤖 AI Assistant:</strong><br>
Welcome {USER_EMAIL}. {USER_TEAM} team!
</div>
""", unsafe_allow_html=True)


# Main chat interface
st.subheader("💬 Chat Interface")

Expand All @@ -261,73 +220,6 @@ def display_chat_message(role: str, content: str, model: str = "", timestamp: st
col1, col2 = st.columns([1, 4])
with col1:
send_button = st.button("🚀 Send", type="primary")
with col2:
regenerate_button = st.button("🔄 Regenerate Last Response")
if regenerate_button and st.session_state.chat_history:
# Remove the last bot response and regenerate
while st.session_state.chat_history and st.session_state.chat_history[-1]["role"] == "assistant":
st.session_state.chat_history.pop()
if st.session_state.chat_history:
# Store the last user message for regeneration
last_user_message = st.session_state.chat_history[-1]["content"]
# Process the regeneration
if last_user_message.strip():
# Add user message to history
st.session_state.chat_history.append({
"role": "user",
"content": last_user_message,
"timestamp": time.strftime("%H:%M:%S")
})

# Display user message
display_chat_message("user", last_user_message)

# Get AI response
with st.spinner("🤖 AI is thinking..."):
result = call_safe_infer_api(
message=last_user_message,
model=st.session_state.selected_model,
api_key=st.session_state.api_key
)

if result["status"] == "success":
# Extract response content
response_content = extract_response_content(result["data"])

# Add bot response to history
st.session_state.chat_history.append({
"role": "assistant",
"content": response_content,
"model": st.session_state.selected_model,
"timestamp": time.strftime("%H:%M:%S")
})

# Display bot response
display_chat_message(
"assistant",
response_content,
st.session_state.selected_model,
time.strftime("%H:%M:%S")
)

# Show classification info if available
if 'response' in result["data"] and isinstance(result["data"]["response"], dict):
response_data = result["data"]["response"]
if 'classification' in response_data:
classification = response_data['classification']
with st.expander("🔍 Response Analysis"):
st.json(classification)

else:
error_message = f"❌ Error: {result['message']}"
st.error(error_message)
st.session_state.chat_history.append({
"role": "assistant",
"content": error_message,
"timestamp": time.strftime("%H:%M:%S")
})

st.rerun()

# Process user input
if send_button and user_input.strip():
Expand All @@ -343,16 +235,22 @@ def display_chat_message(role: str, content: str, model: str = "", timestamp: st

# Get AI response
with st.spinner("🤖 AI is thinking..."):
result = call_safe_infer_api(
if st.session_state.selected_model == "None":
model = DEFAULT_MODEL
else:
model = st.session_state.selected_model
result = call_open_ai(
message=user_input,
model=st.session_state.selected_model,
model=model,
api_key=st.session_state.api_key
)

result = {"status": "success", "data": result}
if result["status"] == "success":
# Extract response content
response_content = extract_response_content(result["data"])

response = result['data']
response_content = response.get('data', {}).get('response', {}).get('message', '')
if 'content' in response_content:
response_content = response_content['content']
# Add bot response to history
st.session_state.chat_history.append({
"role": "assistant",
Expand Down
56 changes: 56 additions & 0 deletions safe_infer_chatbot_app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,62 @@

import requests

def convert_to_dict(obj):
    """
    Recursively convert any object (nested dicts, lists/tuples, Pydantic
    models, namedtuples, plain objects, iterables) into plain Python data
    structures: dicts, lists, and scalars.

    Conversion strategies are tried in order of reliability. If one strategy
    raises, the next one is attempted — unlike the naive elif-chain approach,
    where a failing strategy would fall off the end of the chain and the
    function would implicitly return None, silently discarding the object.
    Objects that cannot be converted fall back to their string
    representation, or an "<TypeName object>" placeholder as a last resort.
    """
    if obj is None:
        return None
    if isinstance(obj, (str, int, float, bool)):
        return obj
    if isinstance(obj, dict):
        return {key: convert_to_dict(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_to_dict(item) for item in obj]

    # Ordered extraction strategies: most reliable / most specific first.
    strategies = (
        ('model_dump', lambda o: o.model_dump()),  # Pydantic v2 models
        ('_asdict', lambda o: o._asdict()),        # namedtuples
        ('dict', lambda o: o.dict()),              # Pydantic v1 / dict() providers
        ('__dict__', lambda o: o.__dict__),        # ordinary objects
    )
    for attr_name, extract in strategies:
        if hasattr(obj, attr_name):
            try:
                return convert_to_dict(extract(obj))
            except Exception as e:
                # Log and fall through to the next strategy instead of
                # returning None (the bug in the elif-chain version).
                print(f"Error with {attr_name}: {e}")

    # Generic iterables (sets, generators, ...) — but never re-iterate text.
    if hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes)):
        try:
            return [convert_to_dict(item) for item in obj]
        except Exception as e:
            print(f"Error with iteration: {e}")

    # Last-resort fallbacks: slot-based objects, then string representation.
    try:
        if hasattr(obj, '__slots__'):
            return {slot: convert_to_dict(getattr(obj, slot, None))
                    for slot in obj.__slots__}
        return str(obj)
    except Exception as e:
        print(f"Error converting object {type(obj)}: {e}")
        return f"<{type(obj).__name__} object>"

def get_available_models():
try:
Expand Down