crates/common/src/api/open_ai.rs (9 changes: 1 addition & 8 deletions)
@@ -138,7 +138,7 @@ impl From<String> for ParameterType {
             _ => {
                 log::warn!("Unknown parameter type: {}, assuming type str", s);
                 ParameterType::String
-            },
+            }
         }
     }
 }
@@ -205,13 +205,6 @@ pub struct ToolCallState {
 pub enum ArchState {
     ToolCall(Vec<ToolCallState>),
 }
-#[derive(Deserialize, Serialize)]
-#[serde(untagged)]
-pub enum ModelServerResponse {
-    ChatCompletionsResponse(ChatCompletionsResponse),
-    ModelServerErrorResponse(ModelServerErrorResponse),
-}
-
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct ModelServerErrorResponse {
     pub result: String,
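Note on the change above: with the untagged ModelServerResponse enum removed, the gateway no longer dispatches between success and error payloads at deserialization time. Every model-server reply is parsed directly as a ChatCompletionsResponse, and signals such as intent matching ride in its metadata map instead (see the stream_context.rs changes below). A minimal sketch of why the untagged enum becomes unnecessary, assuming a simplified ChatCompletionsResponse with only the fields relevant here (the real struct carries more):

use serde::{Deserialize, Serialize};
use std::collections::HashMap;

// Simplified stand-in for the real type; illustrative only.
#[derive(Debug, Serialize, Deserialize)]
struct ChatCompletionsResponse {
    choices: Vec<serde_json::Value>,
    // Optional bookkeeping map from the model server, e.g.
    // {"function_latency": "12"} when an intent was matched.
    metadata: Option<HashMap<String, String>>,
}

fn main() {
    // A reply that carries metadata parses cleanly...
    let matched = r#"{"choices": [], "metadata": {"function_latency": "12"}}"#;
    // ...and so does one without it; `metadata` simply becomes None,
    // so no untagged-enum dispatch between variants is needed.
    let unmatched = r#"{"choices": []}"#;

    for body in [matched, unmatched] {
        let resp: ChatCompletionsResponse = serde_json::from_str(body).unwrap();
        println!("metadata = {:?}", resp.metadata);
    }
}

A side benefit: #[serde(untagged)] tries variants in declaration order and reports opaque errors when none match, whereas parsing one concrete type keeps deserialization failures legible.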
crates/prompt_gateway/src/stream_context.rs (180 changes: 112 additions & 68 deletions)
@@ -2,7 +2,7 @@ use crate::metrics::Metrics;
 use crate::tools::compute_request_path_body;
 use common::api::open_ai::{
     to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest,
-    ChatCompletionsResponse, Message, ModelServerResponse, ToolCall,
+    ChatCompletionsResponse, Message, ToolCall,
 };
 use common::configuration::{Overrides, PromptTarget, Tracing};
 use common::consts::{
@@ -128,7 +128,7 @@ impl StreamContext {
         debug!("model server response received");
         trace!("response body: {}", body_str);

-        let model_server_response: ModelServerResponse = match serde_json::from_str(&body_str) {
+        let model_server_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) {
             Ok(arch_fc_response) => arch_fc_response,
             Err(e) => {
                 warn!(
@@ -139,77 +139,121 @@ impl StreamContext {
             }
         };

-        let arch_fc_response = match model_server_response {
-            ModelServerResponse::ChatCompletionsResponse(response) => response,
-            ModelServerResponse::ModelServerErrorResponse(response) => {
-                debug!("archgw <= modelserver error response: {}", response.result);
-                if response.result == "No intent matched" {
-                    if let Some(default_prompt_target) = self
-                        .prompt_targets
-                        .values()
-                        .find(|pt| pt.default.unwrap_or(false))
-                    {
-                        debug!("default prompt target found, forwarding request to default prompt target");
-                        let endpoint = default_prompt_target.endpoint.clone().unwrap();
-                        let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
-
-                        let upstream_endpoint = endpoint.name;
-                        let mut params = HashMap::new();
-                        params.insert(
-                            MESSAGES_KEY.to_string(),
-                            callout_context.request_body.messages.clone(),
-                        );
-                        let arch_messages_json = serde_json::to_string(&params).unwrap();
-                        let timeout_str = DEFAULT_TARGET_REQUEST_TIMEOUT_MS.to_string();
-
-                        let mut headers = vec![
-                            (":method", "POST"),
-                            (ARCH_UPSTREAM_HOST_HEADER, &upstream_endpoint),
-                            (":path", &upstream_path),
-                            (":authority", &upstream_endpoint),
-                            ("content-type", "application/json"),
-                            ("x-envoy-max-retries", "3"),
-                            ("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
-                        ];
-
-                        if self.request_id.is_some() {
-                            headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
-                        }
+        // intent was matched if we see function_latency in metadata
+        let intent_matched = model_server_response
+            .metadata
+            .as_ref()
+            .and_then(|metadata| metadata.get("function_latency"))
+            .is_some();
+
+        if !intent_matched {
+            debug!("intent not matched");
+            // check if we have a default prompt target
+            if let Some(default_prompt_target) = self
+                .prompt_targets
+                .values()
+                .find(|pt| pt.default.unwrap_or(false))
+            {
+                debug!("default prompt target found, forwarding request to default prompt target");
+                let endpoint = default_prompt_target.endpoint.clone().unwrap();
+                let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
+
+                let upstream_endpoint = endpoint.name;
+                let mut params = HashMap::new();
+                params.insert(
+                    MESSAGES_KEY.to_string(),
+                    callout_context.request_body.messages.clone(),
+                );
+                let arch_messages_json = serde_json::to_string(&params).unwrap();
+                let timeout_str = DEFAULT_TARGET_REQUEST_TIMEOUT_MS.to_string();
+
+                let mut headers = vec![
+                    (":method", "POST"),
+                    (ARCH_UPSTREAM_HOST_HEADER, &upstream_endpoint),
+                    (":path", &upstream_path),
+                    (":authority", &upstream_endpoint),
+                    ("content-type", "application/json"),
+                    ("x-envoy-max-retries", "3"),
+                    ("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
+                ];
+
+                // if self.trace_arch_internal() && self.traceparent.is_some() {
+                //     headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
+                // }
-
-                        let call_args = CallArgs::new(
-                            ARCH_INTERNAL_CLUSTER_NAME,
-                            &upstream_path,
-                            headers,
-                            Some(arch_messages_json.as_bytes()),
-                            vec![],
-                            Duration::from_secs(5),
-                        );
-                        callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
-                        callout_context.prompt_target_name =
-                            Some(default_prompt_target.name.clone());
-
-                        if let Err(e) = self.http_call(call_args, callout_context) {
-                            warn!("error dispatching default prompt target request: {}", e);
-                            return self.send_server_error(
-                                ServerError::HttpDispatch(e),
-                                Some(StatusCode::BAD_REQUEST),
-                            );
-                        }
-                        return;
+                if self.request_id.is_some() {
+                    headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
+                }
+
+                let call_args = CallArgs::new(
+                    ARCH_INTERNAL_CLUSTER_NAME,
+                    &upstream_path,
+                    headers,
+                    Some(arch_messages_json.as_bytes()),
+                    vec![],
+                    Duration::from_secs(5),
+                );
+                callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
+                callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
+
+                if let Err(e) = self.http_call(call_args, callout_context) {
+                    warn!("error dispatching default prompt target request: {}", e);
+                    return self.send_server_error(
+                        ServerError::HttpDispatch(e),
+                        Some(StatusCode::BAD_REQUEST),
+                    );
+                }
+                return;
+            } else {
+                debug!("no default prompt target found, forwarding request to upstream llm");
+                let mut messages = Vec::new();
+                // add system prompt
+                match self.system_prompt.as_ref() {
+                    None => {}
+                    Some(system_prompt) => {
+                        let system_prompt_message = Message {
+                            role: SYSTEM_ROLE.to_string(),
+                            content: Some(system_prompt.clone()),
+                            model: None,
+                            tool_calls: None,
+                            tool_call_id: None,
+                        };
+                        messages.push(system_prompt_message);
                     }
                 }
+
+                messages.append(
+                    &mut self
+                        .filter_out_arch_messages(callout_context.request_body.messages.as_ref()),
+                );
+
+                let chat_completion_request = ChatCompletionsRequest {
+                    model: self
+                        .chat_completions_request
+                        .as_ref()
+                        .unwrap()
+                        .model
+                        .clone(),
+                    messages,
+                    tools: None,
+                    stream: callout_context.request_body.stream,
+                    stream_options: callout_context.request_body.stream_options,
+                    metadata: None,
+                };
+
+                let chat_completion_request_json =
+                    serde_json::to_string(&chat_completion_request).unwrap();
+                debug!(
+                    "archgw => upstream llm request: {}",
+                    chat_completion_request_json
+                );
+                self.set_http_request_body(
+                    0,
+                    self.request_body_size,
+                    chat_completion_request_json.as_bytes(),
+                );
+                self.resume_http_request();
+                return;
+            }
+        }
-                return self.send_server_error(
-                    ServerError::LogicError(response.result),
-                    Some(StatusCode::BAD_REQUEST),
-                );
-            }
-        };

-        arch_fc_response.choices[0]
+        model_server_response.choices[0]
             .message
             .tool_calls
             .clone_into(&mut self.tool_calls);
@@ -238,7 +282,7 @@ impl StreamContext {
                 ),
                 ChatCompletionStreamResponse::new(
                     Some(
-                        arch_fc_response.choices[0]
+                        model_server_response.choices[0]
                             .message
                             .content
                             .as_ref()
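The crux of the stream_context.rs change is the new intent check: rather than matching a dedicated "No intent matched" error variant, the gateway treats the presence of a function_latency key in the response metadata as evidence that the model server matched an intent. A standalone sketch of that predicate and the resulting three-way routing, with hypothetical trimmed-down types standing in for the real ones:

use std::collections::HashMap;

// Hypothetical stand-in; the real type is ChatCompletionsResponse.
struct ModelServerReply {
    metadata: Option<HashMap<String, String>>,
}

#[derive(Debug, PartialEq)]
enum Route {
    HandleToolCalls,     // intent matched: continue with function calling
    DefaultPromptTarget, // no intent, but a default prompt target exists
    UpstreamLlm,         // no intent and no default: forward to the LLM
}

fn route(reply: &ModelServerReply, has_default_target: bool) -> Route {
    // Mirrors the diff: intent was matched iff the model server recorded
    // a function_latency entry in the metadata map.
    let intent_matched = reply
        .metadata
        .as_ref()
        .and_then(|metadata| metadata.get("function_latency"))
        .is_some();

    match (intent_matched, has_default_target) {
        (true, _) => Route::HandleToolCalls,
        (false, true) => Route::DefaultPromptTarget,
        (false, false) => Route::UpstreamLlm,
    }
}

fn main() {
    let matched = ModelServerReply {
        metadata: Some(HashMap::from([(
            "function_latency".to_string(),
            "12".to_string(),
        )])),
    };
    let unmatched = ModelServerReply { metadata: None };

    assert_eq!(route(&matched, false), Route::HandleToolCalls);
    assert_eq!(route(&unmatched, true), Route::DefaultPromptTarget);
    assert_eq!(route(&unmatched, false), Route::UpstreamLlm);
}

One trade-off of keying off a latency entry: it couples control flow to a bookkeeping field, so if the model server ever emits function_latency for an unmatched intent (or omits it for a matched one), routing silently changes; an explicit flag in metadata would be more self-describing.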