Skip to content

Commit 1b91b85

Browse files
Instantiate chat history manager at runtime instead of import-time (#312)
* Call get_chat_history_manager() at runtime
* Simplify S3ChatHistoryManager attributes
* Facilitate monkey patching
* Clarify docstrings
* Satisfy mypy
1 parent (0a6b9ce); commit 1b91b85

File tree

8 files changed: 40 additions and 31 deletions.

llm-service/app/routers/index/sessions/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,7 @@
5959
from ....services.chat.suggested_questions import generate_suggested_questions
6060
from ....services.chat_history.chat_history_manager import (
6161
RagStudioChatMessage,
62-
chat_history_manager,
62+
get_chat_history_manager,
6363
)
6464
from ....services.chat_history.paginator import paginate
6565
from ....services.metadata_apis import session_metadata_api
@@ -142,7 +142,7 @@ class RagStudioChatHistoryResponse(BaseModel):
142142
def chat_history(
143143
session_id: int, limit: Optional[int] = None, offset: Optional[int] = None
144144
) -> RagStudioChatHistoryResponse:
145-
results = chat_history_manager.retrieve_chat_history(session_id=session_id)
145+
results = get_chat_history_manager().retrieve_chat_history(session_id=session_id)
146146

147147
paginated_results, previous_id, next_id = paginate(results, limit, offset)
148148
return RagStudioChatHistoryResponse(
@@ -158,8 +158,8 @@ def chat_history(
158158
)
159159
@exceptions.propagates
160160
def get_message_by_id(session_id: int, message_id: str) -> RagStudioChatMessage:
161-
results: list[RagStudioChatMessage] = chat_history_manager.retrieve_chat_history(
162-
session_id=session_id
161+
results: list[RagStudioChatMessage] = (
162+
get_chat_history_manager().retrieve_chat_history(session_id=session_id)
163163
)
164164
for message in results:
165165
if message.id == message_id:
@@ -175,14 +175,14 @@ def get_message_by_id(session_id: int, message_id: str) -> RagStudioChatMessage:
175175
)
176176
@exceptions.propagates
177177
def clear_chat_history(session_id: int) -> str:
178-
chat_history_manager.clear_chat_history(session_id=session_id)
178+
get_chat_history_manager().clear_chat_history(session_id=session_id)
179179
return "Chat history cleared."
180180

181181

182182
@router.delete("", summary="Deletes the requested session.")
183183
@exceptions.propagates
184184
def delete_session(session_id: int) -> str:
185-
chat_history_manager.delete_chat_history(session_id=session_id)
185+
get_chat_history_manager().delete_chat_history(session_id=session_id)
186186
return "Chat history deleted."
187187

188188

llm-service/app/services/chat/chat.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,7 @@
5050
Evaluation,
5151
RagMessage,
5252
RagStudioChatMessage,
53-
chat_history_manager,
53+
get_chat_history_manager,
5454
)
5555
from app.services.metadata_apis.session_metadata_api import Session
5656
from app.services.mlflow import record_rag_mlflow_run, record_direct_llm_mlflow_run
@@ -172,7 +172,7 @@ def finalize_response(
172172
record_rag_mlflow_run(
173173
new_chat_message, query_configuration, response_id, session, user_name
174174
)
175-
chat_history_manager.append_to_history(session.id, [new_chat_message])
175+
get_chat_history_manager().append_to_history(session.id, [new_chat_message])
176176

177177
return new_chat_message
178178

@@ -198,5 +198,5 @@ def direct_llm_chat(
198198
timestamp=time.time(),
199199
condensed_question=None,
200200
)
201-
chat_history_manager.append_to_history(session.id, [new_chat_message])
201+
get_chat_history_manager().append_to_history(session.id, [new_chat_message])
202202
return new_chat_message

llm-service/app/services/chat/streaming_chat.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,7 @@
5353
from app.services.chat_history.chat_history_manager import (
5454
RagStudioChatMessage,
5555
RagMessage,
56-
chat_history_manager,
56+
get_chat_history_manager,
5757
)
5858
from app.services.metadata_apis.session_metadata_api import Session
5959
from app.services.mlflow import record_direct_llm_mlflow_run
@@ -217,4 +217,4 @@ def _stream_direct_llm_chat(
217217
timestamp=time.time(),
218218
condensed_question=None,
219219
)
220-
chat_history_manager.append_to_history(session.id, [new_chat_message])
220+
get_chat_history_manager().append_to_history(session.id, [new_chat_message])

llm-service/app/services/chat/utils.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@
4343
from pydantic import BaseModel
4444

4545
from app.services.chat_history.chat_history_manager import (
46-
chat_history_manager,
46+
get_chat_history_manager,
4747
RagPredictSourceNode,
4848
)
4949

@@ -54,7 +54,7 @@ class RagContext(BaseModel):
5454

5555

5656
def retrieve_chat_history(session_id: int) -> List[RagContext]:
57-
chat_history = chat_history_manager.retrieve_chat_history(session_id)[-10:]
57+
chat_history = get_chat_history_manager().retrieve_chat_history(session_id)[-10:]
5858
history: List[RagContext] = []
5959
for message in chat_history:
6060
history.append(

llm-service/app/services/chat_history/chat_history_manager.py

Lines changed: 14 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@
2727
# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
2828
# DATA.
2929
#
30-
30+
import functools
3131
from abc import ABCMeta, abstractmethod
3232
from typing import Optional, Literal
3333

@@ -64,6 +64,9 @@ class RagStudioChatMessage(BaseModel):
6464

6565

6666
class ChatHistoryManager(metaclass=ABCMeta):
67+
def __init__(self) -> None:
68+
pass
69+
6770
@abstractmethod
6871
def retrieve_chat_history(self, session_id: int) -> list[RagStudioChatMessage]:
6972
pass
@@ -85,7 +88,13 @@ def append_to_history(
8588
pass
8689

8790

88-
def _create_chat_history_manager() -> ChatHistoryManager:
91+
@functools.cache
92+
def _get_chat_history_manager() -> ChatHistoryManager:
93+
"""Create a ChatHistoryManager the first time this function is called, and return it.
94+
95+
This helper function can be monkey-patched for testing purposes.
96+
97+
"""
8998
from app.services.chat_history.simple_chat_history_manager import (
9099
SimpleChatHistoryManager,
91100
)
@@ -101,4 +110,6 @@ def _create_chat_history_manager() -> ChatHistoryManager:
101110
return SimpleChatHistoryManager()
102111

103112

104-
chat_history_manager = _create_chat_history_manager()
113+
def get_chat_history_manager() -> ChatHistoryManager:
114+
"""Return a ChatHistoryManager based on the app's chat store config."""
115+
return _get_chat_history_manager()

llm-service/app/services/chat_history/s3_chat_history_manager.py

Lines changed: 8 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -35,10 +35,10 @@
3535
# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
3636
# DATA.
3737
#
38-
38+
import functools
3939
import json
4040
import logging
41-
from typing import List
41+
from typing import List, cast
4242

4343
import boto3
4444
from boto3 import Session
@@ -57,18 +57,16 @@
5757
class S3ChatHistoryManager(ChatHistoryManager):
5858
"""Chat history manager that uses S3 for storage."""
5959

60-
def __init__(self, bucket_name: str = settings.document_bucket):
61-
self.bucket_name = bucket_name
60+
def __init__(self) -> None:
61+
super().__init__()
62+
self.bucket_name = settings.document_bucket
6263
self.bucket_prefix = settings.document_bucket_prefix
63-
self._s3_client: S3Client | None = None
6464

65-
@property
65+
@functools.cached_property
6666
def s3_client(self) -> S3Client:
6767
"""Lazy initialization of S3 client."""
68-
if self._s3_client is None:
69-
session: Session = boto3.session.Session()
70-
self._s3_client = session.client("s3")
71-
return self._s3_client
68+
session: Session = boto3.session.Session()
69+
return cast(S3Client, session.client("s3"))
7270

7371
def _get_s3_key(self, session_id: int) -> str:
7472
"""Build the S3 key for a session's chat history."""

llm-service/app/services/llm_completion.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,7 @@
4747
from . import models
4848
from .chat_history.chat_history_manager import (
4949
RagStudioChatMessage,
50-
chat_history_manager,
50+
get_chat_history_manager,
5151
)
5252
from .query.query_configuration import QueryConfiguration
5353

@@ -60,7 +60,7 @@ def make_chat_messages(x: RagStudioChatMessage) -> list[ChatMessage]:
6060

6161
def completion(session_id: int, question: str, model_name: str) -> ChatResponse:
6262
model = models.LLM.get(model_name)
63-
chat_history = chat_history_manager.retrieve_chat_history(session_id)[:10]
63+
chat_history = get_chat_history_manager().retrieve_chat_history(session_id)[:10]
6464
messages = list(
6565
itertools.chain.from_iterable(
6666
map(lambda x: make_chat_messages(x), chat_history)
@@ -78,7 +78,7 @@ def stream_completion(
7878
Returns a generator that yields ChatResponse objects as they become available.
7979
"""
8080
model = models.LLM.get(model_name)
81-
chat_history = chat_history_manager.retrieve_chat_history(session_id)[:10]
81+
chat_history = get_chat_history_manager().retrieve_chat_history(session_id)[:10]
8282
messages = list(
8383
itertools.chain.from_iterable(
8484
map(lambda x: make_chat_messages(x), chat_history)

llm-service/app/services/session.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@
4040

4141
from . import models
4242
from .chat_history.chat_history_manager import (
43-
chat_history_manager,
43+
get_chat_history_manager,
4444
RagStudioChatMessage,
4545
)
4646
from .metadata_apis import session_metadata_api
@@ -87,7 +87,7 @@
8787

8888
def rename_session(session_id: int, user_name: Optional[str]) -> str:
8989
chat_history: list[RagStudioChatMessage] = (
90-
chat_history_manager.retrieve_chat_history(session_id=session_id)
90+
get_chat_history_manager().retrieve_chat_history(session_id=session_id)
9191
)
9292
if not chat_history:
9393
logger.info("No chat history found for session ID %s", session_id)

0 commit comments

Comments
 (0)