Skip to content

Commit f38b7b5

Browse files
committed
feat: switch validation of request body based on request type
1 parent 884ace8 commit f38b7b5

File tree

3 files changed

+43
-33
lines changed

3 files changed

+43
-33
lines changed

pkg/plugins/gateway/gateway_req_body.go

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,23 +31,27 @@ import (
3131
)
3232

3333
func (s *Server) HandleRequestBody(ctx context.Context, requestID string, requestPath string, req *extProcPb.ProcessingRequest,
34-
user utils.User, routingAlgorithm types.RoutingAlgorithm) (*extProcPb.ProcessingResponse, string, *types.RoutingContext, bool, int64) {
34+
user utils.User, routingAlgorithm types.RoutingAlgorithm,
35+
) (*extProcPb.ProcessingResponse, string, *types.RoutingContext, bool, OpenAiRequestType, int64) {
3536
var routingCtx *types.RoutingContext
3637
var term int64 // Identify the trace window
3738

39+
requestType := NewOpenAiRequestTypeFromPath(requestPath)
40+
3841
body := req.Request.(*extProcPb.ProcessingRequest_RequestBody)
39-
model, message, stream, errRes := validateRequestBody(requestID, requestPath, body.RequestBody.GetBody(), user)
42+
model, message, stream, errRes := validateRequestBody(requestID, requestType, body.RequestBody.GetBody(), user)
4043
if errRes != nil {
41-
return errRes, model, routingCtx, stream, term
44+
return errRes, model, routingCtx, stream, requestType, term
4245
}
4346

4447
// early reject the request if model doesn't exist.
4548
if !s.cache.HasModel(model) {
4649
klog.ErrorS(nil, "model doesn't exist in cache, probably wrong model name", "requestID", requestID, "model", model)
4750
return generateErrorResponse(envoyTypePb.StatusCode_BadRequest,
4851
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
49-
Key: HeaderErrorNoModelBackends, RawValue: []byte(model)}}},
50-
fmt.Sprintf("model %s does not exist", model)), model, routingCtx, stream, term
52+
Key: HeaderErrorNoModelBackends, RawValue: []byte(model),
53+
}}},
54+
fmt.Sprintf("model %s does not exist", model)), model, routingCtx, stream, requestType, term
5155
}
5256

5357
// early reject if no pods are ready to accept request for a model
@@ -56,8 +60,9 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, reques
5660
klog.ErrorS(err, "no ready pod available", "requestID", requestID, "model", model)
5761
return generateErrorResponse(envoyTypePb.StatusCode_ServiceUnavailable,
5862
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
59-
Key: HeaderErrorNoModelBackends, RawValue: []byte("true")}}},
60-
fmt.Sprintf("error on getting pods for model %s", model)), model, routingCtx, stream, term
63+
Key: HeaderErrorNoModelBackends, RawValue: []byte("true"),
64+
}}},
65+
fmt.Sprintf("error on getting pods for model %s", model)), model, routingCtx, stream, requestType, term
6166
}
6267

6368
routingCtx = types.NewRoutingContext(ctx, routingAlgorithm, model, message, requestID, user.Name)
@@ -72,8 +77,9 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, reques
7277
return generateErrorResponse(
7378
envoyTypePb.StatusCode_ServiceUnavailable,
7479
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
75-
Key: HeaderErrorRouting, RawValue: []byte("true")}}},
76-
"error on selecting target pod"), model, routingCtx, stream, term
80+
Key: HeaderErrorRouting, RawValue: []byte("true"),
81+
}}},
82+
"error on selecting target pod"), model, routingCtx, stream, requestType, term
7783
}
7884
headers = buildEnvoyProxyHeaders(headers,
7985
HeaderRoutingStrategy, string(routingAlgorithm),
@@ -93,5 +99,5 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, reques
9399
},
94100
},
95101
},
96-
}, model, routingCtx, stream, term
102+
}, model, routingCtx, stream, requestType, term
97103
}

pkg/plugins/gateway/util.go

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,13 @@ import (
3434
// see https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-input
3535
var maxEmbeddingInputArraySize = 2048
3636

37-
// validateRequestBody validates input by unmarshaling request body into respective openai-golang struct based on requestpath.
37+
// validateRequestBody validates input by unmarshaling request body into respective openai-golang struct based on requestType.
3838
// nolint:nakedret
39-
func validateRequestBody(requestID, requestPath string, requestBody []byte, user utils.User) (model, message string, stream bool, errRes *extProcPb.ProcessingResponse) {
39+
func validateRequestBody(requestID string, requestType OpenAiRequestType, requestBody []byte, user utils.User) (model, message string, stream bool, errRes *extProcPb.ProcessingResponse) {
4040
var streamOptions openai.ChatCompletionStreamOptionsParam
41-
if requestPath == "/v1/chat/completions" {
41+
switch requestType {
42+
43+
case OpenAiRequestChatCompletionsType:
4244
var jsonMap map[string]json.RawMessage
4345
if err := json.Unmarshal(requestBody, &jsonMap); err != nil {
4446
klog.ErrorS(err, "error to unmarshal request body", "requestID", requestID, "requestBody", string(requestBody))
@@ -59,7 +61,8 @@ func validateRequestBody(requestID, requestPath string, requestBody []byte, user
5961
if errRes = validateStreamOptions(requestID, user, &stream, streamOptions, jsonMap); errRes != nil {
6062
return
6163
}
62-
} else if requestPath == "/v1/completions" {
64+
65+
case OpenAiRequestCompletionsType:
6366
// openai.CompletionsNewParams does not support json unmarshal for CompletionNewParamsPromptUnion in release v0.1.0-beta.10
6467
// once supported, input request will be directly unmarshal into openai.CompletionsNewParams
6568
type Completion struct {
@@ -75,7 +78,8 @@ func validateRequestBody(requestID, requestPath string, requestBody []byte, user
7578
}
7679
model = completionObj.Model
7780
message = completionObj.Prompt
78-
} else if requestPath == "/v1/embeddings" {
81+
82+
case OpenAiRequestEmbeddingsType:
7983
message = "" // prefix_cache algorithms are not relevant for embeddings
8084
var jsonMap map[string]json.RawMessage
8185
if err := json.Unmarshal(requestBody, &jsonMap); err != nil {
@@ -93,12 +97,12 @@ func validateRequestBody(requestID, requestPath string, requestBody []byte, user
9397
if errRes = checkEmbeddingInputSequenceLen(requestID, embeddingObj); errRes != nil {
9498
return
9599
}
96-
} else {
100+
case OpenAiRequestUnknownType:
97101
errRes = buildErrorResponse(envoyTypePb.StatusCode_NotImplemented, "unknown request path", HeaderErrorRequestBodyProcessing, "true")
98102
return
99103
}
100104

101-
klog.V(4).InfoS("validateRequestBody", "requestID", requestID, "requestPath", requestPath, "model", model, "message", message, "stream", stream, "streamOptions", streamOptions)
105+
klog.V(4).InfoS("validateRequestBody", "requestID", requestID, "requestType", requestType, "model", model, "message", message, "stream", stream, "streamOptions", streamOptions)
102106
return
103107
}
104108

pkg/plugins/gateway/util_test.go

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import (
2828
func Test_ValidateRequestBody(t *testing.T) {
2929
testCases := []struct {
3030
message string
31-
requestPath string
31+
requestType OpenAiRequestType
3232
requestBody []byte
3333
model string
3434
messages string
@@ -38,95 +38,95 @@ func Test_ValidateRequestBody(t *testing.T) {
3838
}{
3939
{
4040
message: "unknown path",
41-
requestPath: "/v1/unknown",
41+
requestType: OpenAiRequestUnknownType,
4242
statusCode: envoyTypePb.StatusCode_NotImplemented,
4343
},
4444
{
4545
message: "/v1/chat/completions json unmarshal error",
46-
requestPath: "/v1/chat/completions",
46+
requestType: OpenAiRequestChatCompletionsType,
4747
requestBody: []byte("bad_request"),
4848
statusCode: envoyTypePb.StatusCode_BadRequest,
4949
},
5050
{
5151
message: "/v1/chat/completions json unmarshal ChatCompletionsNewParams",
52-
requestPath: "/v1/chat/completions",
52+
requestType: OpenAiRequestChatCompletionsType,
5353
requestBody: []byte(`{"model": 1}`),
5454
statusCode: envoyTypePb.StatusCode_BadRequest,
5555
},
5656
{
5757
message: "/v1/chat/completions json unmarshal no messages",
58-
requestPath: "/v1/chat/completions",
58+
requestType: OpenAiRequestChatCompletionsType,
5959
requestBody: []byte(`{"model": "llama2-7b"}`),
6060
statusCode: envoyTypePb.StatusCode_BadRequest,
6161
},
6262
{
6363
message: "/v1/chat/completions json unmarshal valid messages",
64-
requestPath: "/v1/chat/completions",
64+
requestType: OpenAiRequestChatCompletionsType,
6565
requestBody: []byte(`{"model": "llama2-7b", "messages": [{"role": "system", "content": "this is system"},{"role": "user", "content": "say this is test"}]}`),
6666
model: "llama2-7b",
6767
messages: "this is system say this is test",
6868
statusCode: envoyTypePb.StatusCode_OK,
6969
},
7070
{
7171
message: "/v1/chat/completions json unmarshal invalid messages with complex content",
72-
requestPath: "/v1/chat/completions",
72+
requestType: OpenAiRequestChatCompletionsType,
7373
requestBody: []byte(`{"model": "llama2-7b", "messages": [{"role": "system", "content": "this is system"},{"role": "user", "content": {"type": "text", "text": "say this is test", "complex": make(chan int)}}]}`),
7474
statusCode: envoyTypePb.StatusCode_BadRequest,
7575
},
7676
{
7777
message: "/v1/chat/completions json unmarshal valid messages with complex content",
78-
requestPath: "/v1/chat/completions",
78+
requestType: OpenAiRequestChatCompletionsType,
7979
requestBody: []byte(`{"model": "llama2-7b", "messages": [{"role": "system", "content": "this is system"},{"role": "user", "content": [{"type": "text", "text": "say this is test"}, {"type": "text", "text": "say this is test"}]}]}`),
8080
model: "llama2-7b",
8181
messages: "this is system [{\"text\":\"say this is test\",\"type\":\"text\"},{\"text\":\"say this is test\",\"type\":\"text\"}]",
8282
statusCode: envoyTypePb.StatusCode_OK,
8383
},
8484
{
8585
message: "/v1/chat/completions json unmarshal valid messages with stop string param",
86-
requestPath: "/v1/chat/completions",
86+
requestType: OpenAiRequestChatCompletionsType,
8787
requestBody: []byte(`{"model": "llama2-7b", "messages": [{"role": "system", "content": "this is system"},{"role": "user", "content": "say this is test"}], "stop": "stop"}`),
8888
model: "llama2-7b",
8989
messages: "this is system say this is test",
9090
statusCode: envoyTypePb.StatusCode_OK,
9191
},
9292
{
9393
message: "/v1/chat/completions json unmarshal valid messages with stop array param",
94-
requestPath: "/v1/chat/completions",
94+
requestType: OpenAiRequestChatCompletionsType,
9595
requestBody: []byte(`{"model": "llama2-7b", "messages": [{"role": "system", "content": "this is system"},{"role": "user", "content": "say this is test"}], "stop": ["stop"]}`),
9696
model: "llama2-7b",
9797
messages: "this is system say this is test",
9898
statusCode: envoyTypePb.StatusCode_OK,
9999
},
100100
{
101101
message: "/v1/chat/completions json unmarshal invalid stream bool",
102-
requestPath: "/v1/chat/completions",
102+
requestType: OpenAiRequestChatCompletionsType,
103103
requestBody: []byte(`{"model": "llama2-7b", "stream": "true", "messages": [{"role": "system", "content": "this is system"}]}`),
104104
statusCode: envoyTypePb.StatusCode_BadRequest,
105105
},
106106
{
107107
message: "/v1/chat/completions json unmarshal stream options is null",
108-
requestPath: "/v1/chat/completions",
108+
requestType: OpenAiRequestChatCompletionsType,
109109
user: utils.User{Tpm: 1},
110110
requestBody: []byte(`{"model": "llama2-7b", "stream": true, "messages": [{"role": "system", "content": "this is system"}]}`),
111111
statusCode: envoyTypePb.StatusCode_BadRequest,
112112
},
113113
{
114114
message: "/v1/chat/completions stream_options.include_usage == false with user.TPM >= 1 is NOT OK",
115115
user: utils.User{Tpm: 1},
116-
requestPath: "/v1/chat/completions",
116+
requestType: OpenAiRequestChatCompletionsType,
117117
requestBody: []byte(`{"model": "llama2-7b", "stream": true, "stream_options": {"include_usage": false}, "messages": [{"role": "system", "content": "this is system"}]}`),
118118
statusCode: envoyTypePb.StatusCode_BadRequest,
119119
},
120120
{
121121
message: "/v1/chat/completions stream_options.include_usage == false with user.TPM == 0 is OK",
122-
requestPath: "/v1/chat/completions",
122+
requestType: OpenAiRequestChatCompletionsType,
123123
requestBody: []byte(`{"model": "llama2-7b", "stream": true, "stream_options": {"include_usage": false}, "messages": [{"role": "system", "content": "this is system"}]}`),
124124
statusCode: envoyTypePb.StatusCode_OK,
125125
},
126126
{
127127
message: "/v1/chat/completions valid request body",
128128
user: utils.User{Tpm: 1},
129-
requestPath: "/v1/chat/completions",
129+
requestType: OpenAiRequestChatCompletionsType,
130130
requestBody: []byte(`{"model": "llama2-7b", "stream": true, "stream_options": {"include_usage": true}, "messages": [{"role": "system", "content": "this is system"},{"role": "user", "content": "say this is test"}]}`),
131131
stream: true,
132132
model: "llama2-7b",
@@ -136,7 +136,7 @@ func Test_ValidateRequestBody(t *testing.T) {
136136
}
137137

138138
for _, tt := range testCases {
139-
model, messages, stream, errRes := validateRequestBody("1", tt.requestPath, tt.requestBody, tt.user)
139+
model, messages, stream, errRes := validateRequestBody("1", tt.requestType, tt.requestBody, tt.user)
140140

141141
if tt.statusCode == 200 {
142142
assert.Equal(t, (*extProcPb.ProcessingResponse)(nil), errRes, tt.message)

0 commit comments

Comments
 (0)