
Commit 1cd71d2

Enable kb-free chat sessions (#156)
* update python and java to allow for no kb
* "start enabling no-kb chats"
* wip on enabling ui to allow for chatting against no kb
* "wip ui"
* "wip on no-kb-chats"
* "add kb form item to chat settings"
* "formatting nodes wip"
* "add dsid to node"
* "invalidate suggested questions on session update"
* "new prompt for suggested q"
* "do suggested questions on a no-kb chat"
* fix azure env vars
* minor ui changes, adding claude 3.7
* remove claude 3.7
* "wip on empty chat cleanup"
* remove errant `bin` directory
* wip on fixing tests
* fix broken ui tests
* simplify mocking
* move vi mocks up to the top level so we don't fool ourselves into thinking they are scoped

Co-authored-by: Elijah Williams <[email protected]>
1 parent 16b215a commit 1cd71d2

23 files changed: +231 -93 lines

backend/src/main/java/com/cloudera/cai/rag/sessions/SessionRepository.java

Lines changed: 6 additions & 1 deletion
@@ -194,7 +194,7 @@ public void delete(Long id) {
   public void update(Types.Session input) {
     var updatedInput = input.withTimeUpdated(Instant.now());
     String json = serializeQueryConfiguration(input);
-    jdbi.useHandle(
+    jdbi.useTransaction(
         handle -> {
           var sql =
               """
@@ -208,6 +208,11 @@ public void update(Types.Session input) {
               .bind("queryConfiguration", json)
               .bindMethods(updatedInput)
               .execute();
+          handle
+              .createUpdate("DELETE FROM CHAT_SESSION_DATA_SOURCE WHERE CHAT_SESSION_ID = :id")
+              .bind("id", input.id())
+              .execute();
+          insertSessionDataSources(handle, input.id(), input.dataSourceIds());
         });
   }
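Note on this hunk: switching from jdbi.useHandle to jdbi.useTransaction makes the session UPDATE, the DELETE of old CHAT_SESSION_DATA_SOURCE links, and the re-insert of the new links one atomic unit, so a failure mid-update can't leave a session pointing at a stale or empty set of data sources. As a rough cross-language illustration of the same delete-then-reinsert pattern (a Python/sqlite3 sketch, not the project's JDBI code; the function name and the DATA_SOURCE_ID column are assumptions):

    import sqlite3

    def replace_session_data_sources(conn: sqlite3.Connection,
                                     session_id: int,
                                     data_source_ids: list[int]) -> None:
        # Sketch only: mirrors the repository's delete-then-reinsert, but the
        # real code goes through JDBI inside jdbi.useTransaction.
        with conn:  # one transaction: commit on success, rollback on error
            conn.execute(
                "DELETE FROM CHAT_SESSION_DATA_SOURCE WHERE CHAT_SESSION_ID = ?",
                (session_id,),
            )
            conn.executemany(
                "INSERT INTO CHAT_SESSION_DATA_SOURCE (CHAT_SESSION_ID, DATA_SOURCE_ID)"
                " VALUES (?, ?)",
                [(session_id, ds) for ds in data_source_ids],
            )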

backend/src/test/java/com/cloudera/cai/rag/TestData.java

Lines changed: 6 additions & 1 deletion
@@ -61,9 +61,14 @@ public static Types.Session createTestSessionInstance(String sessionName) {
   }

   public static Types.CreateSession createSessionInstance(String sessionName) {
+    return createSessionInstance(sessionName, List.of(1L, 2L, 3L));
+  }
+
+  public static Types.CreateSession createSessionInstance(
+      String sessionName, List<Long> dataSourceIds) {
     return new Types.CreateSession(
         sessionName,
-        List.of(1L, 2L, 3L),
+        dataSourceIds,
         "test-model",
         "test-rerank-model",
         3,

backend/src/test/java/com/cloudera/cai/rag/sessions/SessionControllerTest.java

Lines changed: 24 additions & 0 deletions
@@ -46,6 +46,7 @@
 import com.cloudera.cai.rag.util.UserTokenCookieDecoderTest;
 import com.cloudera.cai.util.exceptions.NotFound;
 import com.fasterxml.jackson.core.JsonProcessingException;
+import java.util.List;
 import org.junit.jupiter.api.Test;
 import org.springframework.mock.web.MockCookie;
 import org.springframework.mock.web.MockHttpServletRequest;
@@ -74,6 +75,29 @@ void create() throws JsonProcessingException {
     assertThat(result.queryConfiguration()).isNotNull();
   }

+  @Test
+  void create_noDataSource() throws JsonProcessingException {
+    SessionController sessionController = new SessionController(SessionService.createNull());
+    var request = new MockHttpServletRequest();
+    request.setCookies(
+        new MockCookie("_basusertoken", UserTokenCookieDecoderTest.encodeCookie("test-user")));
+    var sessionName = "test";
+    Types.CreateSession input = TestData.createSessionInstance(sessionName, List.of());
+    Types.Session result = sessionController.create(input, request);
+    assertThat(result.id()).isNotNull();
+    assertThat(result.name()).isEqualTo(sessionName);
+    assertThat(result.inferenceModel()).isEqualTo(input.inferenceModel());
+    assertThat(result.rerankModel()).isEqualTo(input.rerankModel());
+    assertThat(result.responseChunks()).isEqualTo(input.responseChunks());
+    assertThat(result.dataSourceIds()).isEmpty();
+    assertThat(result.timeCreated()).isNotNull();
+    assertThat(result.timeUpdated()).isNotNull();
+    assertThat(result.createdById()).isEqualTo("test-user");
+    assertThat(result.updatedById()).isEqualTo("test-user");
+    assertThat(result.lastInteractionTime()).isNull();
+    assertThat(result.queryConfiguration()).isNotNull();
+  }
+
   @Test
   void get() {
     SessionController sessionController = new SessionController(SessionService.createNull());

backend/src/test/java/com/cloudera/cai/rag/sessions/SessionServiceTest.java

Lines changed: 3 additions & 1 deletion
@@ -42,6 +42,7 @@

 import com.cloudera.cai.rag.TestData;
 import com.cloudera.cai.rag.Types;
+import java.util.List;
 import org.junit.jupiter.api.Test;

 class SessionServiceTest {
@@ -77,9 +78,10 @@ void update() {
         TestData.createTestSessionInstance("test")
             .withCreatedById("abc")
             .withUpdatedById("abc"));
-    var updated = result.withRerankModel("");
+    var updated = result.withRerankModel("").withDataSourceIds(List.of(4L));
     var updatedResult = sessionService.update(updated);
     assertThat(updatedResult.rerankModel()).isNull();
+    assertThat(updatedResult.dataSourceIds()).containsExactly(4L);
   }

   @Test

cacerts.jks

Binary file (3.38 KB) not shown.

llm-service/app/routers/index/sessions/__init__.py

Lines changed: 5 additions & 3 deletions
@@ -47,6 +47,7 @@
 from ....rag_types import RagPredictConfiguration
 from ....services.chat import generate_suggested_questions, v2_chat, direct_llm_chat
 from ....services.chat_store import ChatHistoryManager, RagStudioChatMessage
+from ....services.metadata_apis import session_metadata_api
 from ....services.mlflow import rating_mlflow_log_metric, feedback_mlflow_log_table

 logger = logging.getLogger(__name__)
@@ -141,11 +142,12 @@ def chat(
     _basusertoken: Annotated[str | None, Cookie()] = None,
 ) -> RagStudioChatMessage:
     user_name = parse_jwt_cookie(_basusertoken)
+    session = session_metadata_api.get_session(session_id)

     configuration = request.configuration or RagPredictConfiguration()
-    if configuration.exclude_knowledge_base:
-        return direct_llm_chat(session_id, request.query, user_name)
-    return v2_chat(session_id, request.query, configuration, user_name)
+    if configuration.exclude_knowledge_base or len(session.data_source_ids) == 0:
+        return direct_llm_chat(session, request.query, user_name)
+    return v2_chat(session, request.query, configuration, user_name)


 class RagSuggestedQuestionsResponse(BaseModel):
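The net effect of this hunk: a chat now bypasses RAG either when the caller explicitly sets exclude_knowledge_base or when the session simply has no data sources attached. A minimal sketch of the dispatch rule (the helper name is invented for illustration; the router inlines this condition):

    def bypass_knowledge_base(exclude_knowledge_base: bool,
                              data_source_ids: list[int]) -> bool:
        # Route to direct_llm_chat when RAG is explicitly excluded or the
        # session has no knowledge base; otherwise use v2_chat.
        return exclude_knowledge_base or len(data_source_ids) == 0

    assert bypass_knowledge_base(False, []) is True    # kb-free session
    assert bypass_knowledge_base(True, [1]) is True    # explicit opt-out
    assert bypass_knowledge_base(False, [1]) is False  # normal RAG chat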

llm-service/app/services/chat.py

Lines changed: 32 additions & 9 deletions
@@ -62,9 +62,8 @@


 def v2_chat(
-    session_id: int, query: str, configuration: RagPredictConfiguration, user_name: str
+    session: Session, query: str, configuration: RagPredictConfiguration, user_name: str
 ) -> RagStudioChatMessage:
-    session = session_metadata_api.get_session(session_id)
     query_configuration = QueryConfiguration(
         top_k=session.response_chunks,
         model_name=session.inference_model,
@@ -80,7 +79,7 @@ def v2_chat(
         session, response_id, query, query_configuration, user_name
     )

-    ChatHistoryManager().append_to_history(session_id, [new_chat_message])
+    ChatHistoryManager().append_to_history(session.id, [new_chat_message])
     return new_chat_message


@@ -121,7 +120,7 @@ def _run_chat(
     relevance, faithfulness = evaluators.evaluate_response(
         query, response, session.inference_model
     )
-    response_source_nodes = format_source_nodes(response)
+    response_source_nodes = format_source_nodes(response, data_source_id)
     new_chat_message = RagStudioChatMessage(
         id=response_id,
         source_nodes=response_source_nodes,
@@ -159,7 +158,9 @@ def retrieve_chat_history(session_id: int) -> List[RagContext]:
     return history


-def format_source_nodes(response: AgentChatResponse) -> List[RagPredictSourceNode]:
+def format_source_nodes(
+    response: AgentChatResponse, data_source_id: int
+) -> List[RagPredictSourceNode]:
     response_source_nodes = []
     for source_node in response.source_nodes:
         doc_id = source_node.node.metadata.get("document_id", source_node.node.node_id)
@@ -169,6 +170,7 @@ def format_source_nodes(response: AgentChatResponse) -> List[RagPredictSourceNod
                 doc_id=doc_id,
                 source_file_name=source_node.node.metadata["file_name"],
                 score=source_node.score or 0.0,
+                dataSourceId=data_source_id,
             )
         )
     response_source_nodes = sorted(
@@ -177,10 +179,32 @@ def format_source_nodes(response: AgentChatResponse) -> List[RagPredictSourceNod
     return response_source_nodes


+def generate_suggested_questions_direct_llm(session: Session) -> List[str]:
+    chat_history = retrieve_chat_history(session.id)
+    if not chat_history:
+        return []
+    query_str = (
+        " Give me a list of possible follow-up questions."
+        " Each question should be on a new line."
+        " There should be no more than four (4) questions."
+        " Each question should be no longer than fifteen (15) words."
+        " The response should be a bulleted list, using an asterisk (*) to denote the bullet item."
+        " Do not start like this - `Here are four questions that I can answer based on the context information`"
+        " Only return the list."
+    )
+    chat_response = llm_completion.completion(
+        session.id, query_str, session.inference_model
+    )
+    suggested_questions = process_response(chat_response.message.content)
+    return suggested_questions
+
+
 def generate_suggested_questions(
     session_id: int,
 ) -> List[str]:
     session = session_metadata_api.get_session(session_id)
+    if len(session.data_source_ids) == 0:
+        return generate_suggested_questions_direct_llm(session)
     if len(session.data_source_ids) != 1:
         raise HTTPException(
             status_code=400,
@@ -256,14 +280,13 @@ def process_response(response: str | None) -> list[str]:


 def direct_llm_chat(
-    session_id: int, query: str, user_name: str
+    session: Session, query: str, user_name: str
 ) -> RagStudioChatMessage:
-    session = session_metadata_api.get_session(session_id)
     response_id = str(uuid.uuid4())
     record_direct_llm_mlflow_run(response_id, session, user_name)

     chat_response = llm_completion.completion(
-        session_id, query, session.inference_model
+        session.id, query, session.inference_model
     )
     new_chat_message = RagStudioChatMessage(
         id=response_id,
@@ -277,5 +300,5 @@ def direct_llm_chat(
         timestamp=time.time(),
         condensed_question=None,
     )
-    ChatHistoryManager().append_to_history(session_id, [new_chat_message])
+    ChatHistoryManager().append_to_history(session.id, [new_chat_message])
     return new_chat_message
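The new generate_suggested_questions_direct_llm asks the model for an asterisk-bulleted list and hands the raw completion to process_response, whose body is not part of this diff. A plausible sketch of that style of parser, assuming it simply keeps the bulleted lines and strips the markers (the function name here is a hypothetical stand-in):

    def parse_bulleted_questions(response: str | None) -> list[str]:
        # Hypothetical stand-in for process_response (not shown in this diff):
        # keep lines that look like "* question" and strip the bullet marker.
        if response is None:
            return []
        return [
            line.lstrip("* ").strip()
            for line in response.splitlines()
            if line.strip().startswith("*")
        ]

    sample = "* What is RAG?\n* How are source nodes ranked?"
    assert parse_bulleted_questions(sample) == [
        "What is RAG?",
        "How are source nodes ranked?",
    ]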

llm-service/app/services/chat_store.py

Lines changed: 2 additions & 1 deletion
@@ -51,6 +51,7 @@ class RagPredictSourceNode(BaseModel):
     doc_id: str
     source_file_name: str
     score: float
+    dataSourceId: Optional[int] = None


 class Evaluation(BaseModel):
@@ -119,7 +120,7 @@ def retrieve_chat_history(self, session_id: int) -> List[RagStudioChatMessage]:
                         "evaluations", []
                     ),
                     timestamp=assistant_message.additional_kwargs.get("timestamp", 0.0),
-                    condensed_question=None
+                    condensed_question=None,
                 )
             )
             i += 2
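Because dataSourceId defaults to None, chat messages persisted before this change (which lack the key entirely) still deserialize cleanly. A quick sketch with a trimmed-down model (the field subset is invented for brevity):

    from typing import Optional
    from pydantic import BaseModel

    class SourceNode(BaseModel):
        # Trimmed stand-in for RagPredictSourceNode, just enough to show
        # that the new optional field is backwards-compatible.
        doc_id: str
        score: float
        dataSourceId: Optional[int] = None

    # A record stored before the field existed still validates:
    old = SourceNode(doc_id="abc", score=0.9)
    assert old.dataSourceId is None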

llm-service/app/services/models/__init__.py

Lines changed: 7 additions & 7 deletions
@@ -66,13 +66,13 @@
 from ..query.simple_reranker import SimpleReranker

 __all__ = [
-    'CAIIModelProvider',
-    'ModelType',
-    'Embedding',
-    'LLM',
-    'Reranking',
-    'ModelSource',
-    'BedrockModelProvider'
+    "CAIIModelProvider",
+    "ModelType",
+    "Embedding",
+    "LLM",
+    "Reranking",
+    "ModelSource",
+    "BedrockModelProvider",
 ]

 T = TypeVar("T", bound=BaseComponent)

llm-service/app/services/models/_azure.py

Lines changed: 1 addition & 2 deletions
@@ -45,8 +45,7 @@
 class AzureModelProvider(ModelProvider):
     @staticmethod
     def get_env_var_names() -> set[str]:
-        return {"AZURE_OPENAI_API_KEY" "AZURE_OPENAI_ENDPOINT" "OPENAI_API_VERSION"}
-
+        return {"AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "OPENAI_API_VERSION"}
     @staticmethod
     def get_llm_models() -> List[ModelResponse]:
         return [
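The one-line Azure fix ("fix azure env vars" in the commit message) is easy to misread: the original literal was missing commas, and Python concatenates adjacent string literals, so the "set" held a single merged key and none of the three environment variable names could ever match. A quick demonstration of the pitfall:

    # Missing commas: adjacent string literals concatenate into ONE element.
    broken = {"AZURE_OPENAI_API_KEY" "AZURE_OPENAI_ENDPOINT" "OPENAI_API_VERSION"}
    assert broken == {"AZURE_OPENAI_API_KEYAZURE_OPENAI_ENDPOINTOPENAI_API_VERSION"}
    assert len(broken) == 1

    # With commas, the set has the three intended members.
    fixed = {"AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "OPENAI_API_VERSION"}
    assert len(fixed) == 3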
