Skip to content

Commit c13345a

Browse files
refactor: cohere provider refactors
1 parent 5fe52fe commit c13345a

File tree

9 files changed

+146
-153
lines changed

9 files changed

+146
-153
lines changed

core/changelog.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
<!-- The pattern we follow here is to keep the changelog for the latest version -->
22
<!-- Old changelogs are automatically attached to the GitHub releases -->
33

4-
- feat: Adds support for models list for providers
4+
- feat: Use all keys for list models request
5+
- refactor: Cohere provider to use completeRequest and response pooling for all requests
6+
- chore: Added id, object, and model fields to Chat Completion responses from Bedrock and Cohere providers

core/providers/cohere.go

Lines changed: 129 additions & 145 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,27 @@ var cohereResponsePool = sync.Pool{
3030
},
3131
}
3232

33+
// cohereEmbeddingResponsePool is a sync.Pool of reusable *cohere.CohereEmbeddingResponse values; New allocates a zero-valued response when the pool has none to hand out.
34+
var cohereEmbeddingResponsePool = sync.Pool{
35+
New: func() interface{} {
36+
return &cohere.CohereEmbeddingResponse{}
37+
},
38+
}
39+
40+
// acquireCohereEmbeddingResponse takes a *cohere.CohereEmbeddingResponse from cohereEmbeddingResponsePool and overwrites it with the zero value before returning it.
41+
func acquireCohereEmbeddingResponse() *cohere.CohereEmbeddingResponse {
42+
resp := cohereEmbeddingResponsePool.Get().(*cohere.CohereEmbeddingResponse)
43+
*resp = cohere.CohereEmbeddingResponse{} // Reset the struct so no field values from a previous use remain
44+
return resp
45+
}
46+
47+
// releaseCohereEmbeddingResponse puts resp back into cohereEmbeddingResponsePool; a nil resp is ignored.
48+
func releaseCohereEmbeddingResponse(resp *cohere.CohereEmbeddingResponse) {
49+
if resp != nil {
50+
cohereEmbeddingResponsePool.Put(resp)
51+
}
52+
}
53+
3354
// acquireCohereResponse gets a Cohere v2 response from the pool and resets it.
3455
func acquireCohereResponse() *cohere.CohereChatResponse {
3556
resp := cohereResponsePool.Get().(*cohere.CohereChatResponse)
@@ -74,6 +95,7 @@ func NewCohereProvider(config *schemas.ProviderConfig, logger schemas.Logger) *C
7495
// Pre-warm response pools
7596
for i := 0; i < config.ConcurrencyAndBufferSize.Concurrency; i++ {
7697
cohereResponsePool.Put(&cohere.CohereChatResponse{})
98+
cohereEmbeddingResponsePool.Put(&cohere.CohereEmbeddingResponse{})
7799
}
78100

79101
// Set default BaseURL if not provided
@@ -97,6 +119,64 @@ func (provider *CohereProvider) GetProviderKey() schemas.ModelProvider {
97119
return getProviderName(schemas.Cohere, provider.customProviderConfig)
98120
}
99121

122+
// completeRequest marshals requestBody to JSON and POSTs it to the given Cohere API url.
123+
// The caller supplies the full url; this method sets the configured extra headers, the JSON content type, and a Bearer Authorization header built from key.
124+
// Returns a copy of the response body and the request latency, or a BifrostError if marshaling fails, the request fails, or the API responds with a non-200 status.
125+
func (provider *CohereProvider) completeRequest(ctx context.Context, requestBody interface{}, url string, key string) ([]byte, time.Duration, *schemas.BifrostError) {
126+
// Marshal the request body; on failure nothing is sent and the returned latency is 0
127+
jsonData, err := sonic.Marshal(requestBody)
128+
if err != nil {
129+
return nil, 0, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, provider.GetProviderKey())
130+
}
131+
132+
// Acquire fasthttp request/response objects; the deferred calls release them when this function returns
133+
req := fasthttp.AcquireRequest()
134+
resp := fasthttp.AcquireResponse()
135+
defer fasthttp.ReleaseRequest(req)
136+
defer fasthttp.ReleaseResponse(resp)
137+
138+
// Set any extra headers from network config
139+
setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil)
140+
141+
req.SetRequestURI(url)
142+
req.Header.SetMethod(http.MethodPost)
143+
req.Header.SetContentType("application/json")
144+
req.Header.Set("Authorization", "Bearer "+key)
145+
146+
req.SetBody(jsonData)
147+
148+
// Send the request; the measured latency is passed back to the caller even when the request fails
149+
latency, bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp)
150+
if bifrostErr != nil {
151+
return nil, latency, bifrostErr
152+
}
153+
154+
// Handle error response: any status other than 200 is logged at debug level and returned as a BifrostError carrying the type, message, and (when present) code from errorResp
155+
if resp.StatusCode() != fasthttp.StatusOK {
156+
provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", provider.GetProviderKey(), string(resp.Body())))
157+
158+
var errorResp cohere.CohereError
159+
160+
bifrostErr := handleProviderAPIError(resp, &errorResp)
161+
bifrostErr.Type = &errorResp.Type // NOTE(review): assumes handleProviderAPIError never returns nil — confirm
162+
if bifrostErr.Error == nil {
163+
bifrostErr.Error = &schemas.ErrorField{}
164+
}
165+
bifrostErr.Error.Message = errorResp.Message
166+
if errorResp.Code != nil {
167+
bifrostErr.Error.Code = errorResp.Code
168+
}
169+
170+
return nil, latency, bifrostErr
171+
}
172+
173+
// Copy the response body before the deferred ReleaseResponse runs,
174+
// to avoid use-after-free since resp.Body() references fasthttp's internal buffer
175+
bodyCopy := append([]byte(nil), resp.Body()...)
176+
177+
return bodyCopy, latency, nil
178+
}
179+
100180
// listModelsByKey performs a list models request for a single key.
101181
// Returns the response and latency, or an error if the request fails.
102182
func (provider *CohereProvider) listModelsByKey(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
@@ -210,102 +290,34 @@ func (provider *CohereProvider) ChatCompletion(ctx context.Context, key schemas.
210290
return nil, newBifrostOperationError("chat completion input is not provided", nil, providerName)
211291
}
212292

213-
cohereResponse, rawResponse, latency, err := provider.handleCohereChatCompletionRequest(ctx, reqBody, key)
293+
responseBody, latency, err := provider.completeRequest(ctx, reqBody, provider.networkConfig.BaseURL+"/v2/chat", key.Value)
214294
if err != nil {
215295
return nil, err
216296
}
217297

218-
// Convert Cohere v2 response to Bifrost response
219-
response := cohereResponse.ToBifrostChatResponse()
220-
221-
response.Model = request.Model
222-
response.ExtraFields.Provider = providerName
223-
response.ExtraFields.ModelRequested = request.Model
224-
response.ExtraFields.RequestType = schemas.ChatCompletionRequest
225-
response.ExtraFields.Latency = latency.Milliseconds()
226-
227-
if provider.sendBackRawResponse {
228-
response.ExtraFields.RawResponse = rawResponse
229-
}
230-
231-
return response, nil
232-
}
233-
234-
func (provider *CohereProvider) handleCohereChatCompletionRequest(ctx context.Context, reqBody *cohere.CohereChatRequest, key schemas.Key) (*cohere.CohereChatResponse, interface{}, time.Duration, *schemas.BifrostError) {
235-
providerName := provider.GetProviderKey()
236-
237-
// Marshal request body
238-
jsonBody, err := sonic.Marshal(reqBody)
239-
if err != nil {
240-
return nil, nil, time.Duration(0), &schemas.BifrostError{
241-
IsBifrostError: true,
242-
Error: &schemas.ErrorField{
243-
Message: schemas.ErrProviderJSONMarshaling,
244-
Error: err,
245-
},
246-
}
247-
}
248-
249-
// Create request
250-
req := fasthttp.AcquireRequest()
251-
resp := fasthttp.AcquireResponse()
252-
defer fasthttp.ReleaseRequest(req)
253-
defer fasthttp.ReleaseResponse(resp)
254-
255-
// Set any extra headers from network config
256-
setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil)
298+
// Create response object from pool
299+
response := acquireCohereResponse()
300+
defer releaseCohereResponse(response)
257301

258-
req.SetRequestURI(provider.networkConfig.BaseURL + "/v2/chat")
259-
req.Header.SetMethod(http.MethodPost)
260-
req.Header.SetContentType("application/json")
261-
req.Header.Set("Authorization", "Bearer "+key.Value)
262-
263-
req.SetBody(jsonBody)
264-
265-
// Make request
266-
latency, bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp)
302+
rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse)
267303
if bifrostErr != nil {
268-
return nil, nil, latency, bifrostErr
304+
return nil, bifrostErr
269305
}
270306

271-
// Handle error response
272-
if resp.StatusCode() != fasthttp.StatusOK {
273-
provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body())))
274-
275-
var errorResp cohere.CohereError
276-
bifrostErr := handleProviderAPIError(resp, &errorResp)
277-
bifrostErr.Error.Message = errorResp.Message
278-
279-
return nil, nil, latency, bifrostErr
280-
}
307+
bifrostResponse := response.ToBifrostChatResponse(request.Model)
281308

282-
// Parse Cohere v2 response
283-
var cohereResponse cohere.CohereChatResponse
284-
if err := sonic.Unmarshal(resp.Body(), &cohereResponse); err != nil {
285-
return nil, nil, latency, &schemas.BifrostError{
286-
IsBifrostError: true,
287-
Error: &schemas.ErrorField{
288-
Message: "error parsing Cohere v2 response",
289-
Error: err,
290-
},
291-
}
292-
}
309+
// Set ExtraFields
310+
bifrostResponse.ExtraFields.Provider = provider.GetProviderKey()
311+
bifrostResponse.ExtraFields.ModelRequested = request.Model
312+
bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest
313+
bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
293314

294-
// Parse raw response for sendBackRawResponse
295-
var rawResponse interface{}
315+
// Set raw response if enabled
296316
if provider.sendBackRawResponse {
297-
if err := sonic.Unmarshal(resp.Body(), &rawResponse); err != nil {
298-
return nil, nil, latency, &schemas.BifrostError{
299-
IsBifrostError: true,
300-
Error: &schemas.ErrorField{
301-
Message: "error parsing raw response",
302-
Error: err,
303-
},
304-
}
305-
}
317+
bifrostResponse.ExtraFields.RawResponse = rawResponse
306318
}
307319

308-
return &cohereResponse, rawResponse, latency, nil
320+
return bifrostResponse, nil
309321
}
310322

311323
// ChatCompletionStream performs a streaming chat completion request to the Cohere API.
@@ -572,24 +584,34 @@ func (provider *CohereProvider) Responses(ctx context.Context, key schemas.Key,
572584
return nil, newBifrostOperationError("responses input is not provided", nil, providerName)
573585
}
574586

575-
cohereResponse, rawResponse, latency, err := provider.handleCohereChatCompletionRequest(ctx, reqBody, key)
587+
responseBody, latency, err := provider.completeRequest(ctx, reqBody, provider.networkConfig.BaseURL+"/v2/chat", key.Value)
576588
if err != nil {
577589
return nil, err
578590
}
579591

580-
// Convert Cohere v2 response to Bifrost response
581-
response := cohereResponse.ToResponsesBifrostResponsesResponse()
592+
// Create response object from pool
593+
response := acquireCohereResponse()
594+
defer releaseCohereResponse(response)
582595

583-
response.ExtraFields.Provider = providerName
584-
response.ExtraFields.ModelRequested = request.Model
585-
response.ExtraFields.RequestType = schemas.ResponsesRequest
586-
response.ExtraFields.Latency = latency.Milliseconds()
596+
rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse)
597+
if bifrostErr != nil {
598+
return nil, bifrostErr
599+
}
600+
601+
bifrostResponse := response.ToBifrostResponsesResponse()
587602

603+
// Set ExtraFields
604+
bifrostResponse.ExtraFields.Provider = provider.GetProviderKey()
605+
bifrostResponse.ExtraFields.ModelRequested = request.Model
606+
bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest
607+
bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
608+
609+
// Set raw response if enabled
588610
if provider.sendBackRawResponse {
589-
response.ExtraFields.RawResponse = rawResponse
611+
bifrostResponse.ExtraFields.RawResponse = rawResponse
590612
}
591613

592-
return response, nil
614+
return bifrostResponse, nil
593615
}
594616

595617
// ResponsesStream performs a streaming responses request to the Cohere API.
@@ -792,72 +814,34 @@ func (provider *CohereProvider) Embedding(ctx context.Context, key schemas.Key,
792814
return nil, newBifrostOperationError("embedding input is not provided", nil, providerName)
793815
}
794816

795-
// Marshal request body
796-
jsonBody, err := sonic.Marshal(reqBody)
817+
responseBody, latency, err := provider.completeRequest(ctx, reqBody, provider.networkConfig.BaseURL+"/v2/embed", key.Value)
797818
if err != nil {
798-
return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, providerName)
819+
return nil, err
799820
}
800821

801-
// Create request
802-
req := fasthttp.AcquireRequest()
803-
resp := fasthttp.AcquireResponse()
804-
defer fasthttp.ReleaseRequest(req)
805-
defer fasthttp.ReleaseResponse(resp)
806-
807-
// Set any extra headers from network config
808-
setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil)
822+
// Create response object from pool
823+
response := acquireCohereEmbeddingResponse()
824+
defer releaseCohereEmbeddingResponse(response)
809825

810-
req.SetRequestURI(provider.networkConfig.BaseURL + "/v2/embed")
811-
req.Header.SetMethod(http.MethodPost)
812-
req.Header.SetContentType("application/json")
813-
req.Header.Set("Authorization", "Bearer "+key.Value)
814-
815-
req.SetBody(jsonBody)
816-
817-
// Make request
818-
latency, bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp)
826+
rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse)
819827
if bifrostErr != nil {
820828
return nil, bifrostErr
821829
}
822830

823-
// Handle error response
824-
if resp.StatusCode() != fasthttp.StatusOK {
825-
provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body())))
826-
827-
var errorResp cohere.CohereError
828-
bifrostErr := handleProviderAPIError(resp, &errorResp)
829-
bifrostErr.Error.Message = errorResp.Message
830-
831-
return nil, bifrostErr
832-
}
833-
834-
// Parse response
835-
var cohereResp cohere.CohereEmbeddingResponse
836-
if err := sonic.Unmarshal(resp.Body(), &cohereResp); err != nil {
837-
return nil, newBifrostOperationError("error parsing embedding response", err, providerName)
838-
}
839-
840-
// Parse raw response for consistent format
841-
var rawResponse interface{}
842-
if err := sonic.Unmarshal(resp.Body(), &rawResponse); err != nil {
843-
return nil, newBifrostOperationError("error parsing raw response for embedding", err, providerName)
844-
}
831+
bifrostResponse := response.ToBifrostEmbeddingResponse()
845832

846-
// Create BifrostResponse
847-
response := cohereResp.ToBifrostEmbeddingResponse()
833+
// Set ExtraFields
834+
bifrostResponse.ExtraFields.Provider = provider.GetProviderKey()
835+
bifrostResponse.ExtraFields.ModelRequested = request.Model
836+
bifrostResponse.ExtraFields.RequestType = schemas.EmbeddingRequest
837+
bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
848838

849-
response.Model = request.Model
850-
response.ExtraFields.Provider = providerName
851-
response.ExtraFields.ModelRequested = request.Model
852-
response.ExtraFields.RequestType = schemas.EmbeddingRequest
853-
response.ExtraFields.Latency = latency.Milliseconds()
854-
855-
// Only include RawResponse if sendBackRawResponse is enabled
839+
// Set raw response if enabled
856840
if provider.sendBackRawResponse {
857-
response.ExtraFields.RawResponse = rawResponse
841+
bifrostResponse.ExtraFields.RawResponse = rawResponse
858842
}
859843

860-
return response, nil
844+
return bifrostResponse, nil
861845
}
862846

863847
// Speech is not supported by the Cohere provider.

core/providers/utils.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -766,6 +766,8 @@ func extractSuccessfulListModelsResponses(
766766
}
767767

768768
// handleMultipleListModelsRequests handles multiple list models requests concurrently for different keys.
769+
// It launches concurrent requests for all keys and waits for all goroutines to complete.
770+
// It returns the aggregated response or an error if the request fails.
769771
func handleMultipleListModelsRequests(
770772
ctx context.Context,
771773
keys []schemas.Key,

core/schemas/providers/anthropic/responses.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ func ToAnthropicResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) *A
190190
return anthropicReq
191191
}
192192

193-
// ToResponsesBifrostResponse converts an Anthropic response to BifrostResponse with Responses structure
193+
// ToBifrostResponsesResponse converts an Anthropic message response to a BifrostResponsesResponse
194194
func (response *AnthropicMessageResponse) ToBifrostResponsesResponse() *schemas.BifrostResponsesResponse {
195195
if response == nil {
196196
return nil

core/schemas/providers/cohere/chat.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,13 +178,14 @@ func ToCohereChatCompletionRequest(bifrostReq *schemas.BifrostChatRequest) *Cohe
178178
}
179179

180180
// ToBifrostChatResponse converts a Cohere v2 response to Bifrost format
181-
func (response *CohereChatResponse) ToBifrostChatResponse() *schemas.BifrostChatResponse {
181+
func (response *CohereChatResponse) ToBifrostChatResponse(model string) *schemas.BifrostChatResponse {
182182
if response == nil {
183183
return nil
184184
}
185185

186186
bifrostResponse := &schemas.BifrostChatResponse{
187187
ID: response.ID,
188+
Model: model,
188189
Object: "chat.completion",
189190
Choices: []schemas.BifrostResponseChoice{
190191
{

0 commit comments

Comments
 (0)