@@ -30,6 +30,27 @@ var cohereResponsePool = sync.Pool{
3030 },
3131}
3232
33+ // cohereEmbeddingResponsePool provides a pool for Cohere embedding response objects.
34+ var cohereEmbeddingResponsePool = sync.Pool {
35+ New : func () interface {} {
36+ return & cohere.CohereEmbeddingResponse {}
37+ },
38+ }
39+
40+ // acquireCohereEmbeddingResponse gets a Cohere embedding response from the pool and resets it.
41+ func acquireCohereEmbeddingResponse () * cohere.CohereEmbeddingResponse {
42+ resp := cohereEmbeddingResponsePool .Get ().(* cohere.CohereEmbeddingResponse )
43+ * resp = cohere.CohereEmbeddingResponse {} // Reset the struct
44+ return resp
45+ }
46+
47+ // releaseCohereEmbeddingResponse returns a Cohere embedding response to the pool.
48+ func releaseCohereEmbeddingResponse (resp * cohere.CohereEmbeddingResponse ) {
49+ if resp != nil {
50+ cohereEmbeddingResponsePool .Put (resp )
51+ }
52+ }
53+
3354// acquireCohereResponse gets a Cohere v2 response from the pool and resets it.
3455func acquireCohereResponse () * cohere.CohereChatResponse {
3556 resp := cohereResponsePool .Get ().(* cohere.CohereChatResponse )
@@ -74,6 +95,7 @@ func NewCohereProvider(config *schemas.ProviderConfig, logger schemas.Logger) *C
7495 // Pre-warm response pools
7596 for i := 0 ; i < config .ConcurrencyAndBufferSize .Concurrency ; i ++ {
7697 cohereResponsePool .Put (& cohere.CohereChatResponse {})
98+ cohereEmbeddingResponsePool .Put (& cohere.CohereEmbeddingResponse {})
7799 }
78100
79101 // Set default BaseURL if not provided
@@ -97,6 +119,64 @@ func (provider *CohereProvider) GetProviderKey() schemas.ModelProvider {
97119 return getProviderName (schemas .Cohere , provider .customProviderConfig )
98120}
99121
122+ // completeRequest sends a request to Cohere's API and handles the response.
123+ // It constructs the API URL, sets up authentication, and processes the response.
124+ // Returns the response body or an error if the request fails.
125+ func (provider * CohereProvider ) completeRequest (ctx context.Context , requestBody interface {}, url string , key string ) ([]byte , time.Duration , * schemas.BifrostError ) {
126+ // Marshal the request body
127+ jsonData , err := sonic .Marshal (requestBody )
128+ if err != nil {
129+ return nil , 0 , newBifrostOperationError (schemas .ErrProviderJSONMarshaling , err , provider .GetProviderKey ())
130+ }
131+
132+ // Create the request with the JSON body
133+ req := fasthttp .AcquireRequest ()
134+ resp := fasthttp .AcquireResponse ()
135+ defer fasthttp .ReleaseRequest (req )
136+ defer fasthttp .ReleaseResponse (resp )
137+
138+ // Set any extra headers from network config
139+ setExtraHeaders (req , provider .networkConfig .ExtraHeaders , nil )
140+
141+ req .SetRequestURI (url )
142+ req .Header .SetMethod (http .MethodPost )
143+ req .Header .SetContentType ("application/json" )
144+ req .Header .Set ("Authorization" , "Bearer " + key )
145+
146+ req .SetBody (jsonData )
147+
148+ // Send the request
149+ latency , bifrostErr := makeRequestWithContext (ctx , provider .client , req , resp )
150+ if bifrostErr != nil {
151+ return nil , latency , bifrostErr
152+ }
153+
154+ // Handle error response
155+ if resp .StatusCode () != fasthttp .StatusOK {
156+ provider .logger .Debug (fmt .Sprintf ("error from %s provider: %s" , provider .GetProviderKey (), string (resp .Body ())))
157+
158+ var errorResp cohere.CohereError
159+
160+ bifrostErr := handleProviderAPIError (resp , & errorResp )
161+ bifrostErr .Type = & errorResp .Type
162+ if bifrostErr .Error == nil {
163+ bifrostErr .Error = & schemas.ErrorField {}
164+ }
165+ bifrostErr .Error .Message = errorResp .Message
166+ if errorResp .Code != nil {
167+ bifrostErr .Error .Code = errorResp .Code
168+ }
169+
170+ return nil , latency , bifrostErr
171+ }
172+
173+ // Read the response body and copy it before releasing the response
174+ // to avoid use-after-free since resp.Body() references fasthttp's internal buffer
175+ bodyCopy := append ([]byte (nil ), resp .Body ()... )
176+
177+ return bodyCopy , latency , nil
178+ }
179+
100180// listModelsByKey performs a list models request for a single key.
101181// Returns the response and latency, or an error if the request fails.
102182func (provider * CohereProvider ) listModelsByKey (ctx context.Context , key schemas.Key , request * schemas.BifrostListModelsRequest ) (* schemas.BifrostListModelsResponse , * schemas.BifrostError ) {
@@ -210,102 +290,34 @@ func (provider *CohereProvider) ChatCompletion(ctx context.Context, key schemas.
210290 return nil , newBifrostOperationError ("chat completion input is not provided" , nil , providerName )
211291 }
212292
213- cohereResponse , rawResponse , latency , err := provider .handleCohereChatCompletionRequest (ctx , reqBody , key )
293+ responseBody , latency , err := provider .completeRequest (ctx , reqBody , provider . networkConfig . BaseURL + "/v2/chat" , key . Value )
214294 if err != nil {
215295 return nil , err
216296 }
217297
218- // Convert Cohere v2 response to Bifrost response
219- response := cohereResponse .ToBifrostChatResponse ()
220-
221- response .Model = request .Model
222- response .ExtraFields .Provider = providerName
223- response .ExtraFields .ModelRequested = request .Model
224- response .ExtraFields .RequestType = schemas .ChatCompletionRequest
225- response .ExtraFields .Latency = latency .Milliseconds ()
226-
227- if provider .sendBackRawResponse {
228- response .ExtraFields .RawResponse = rawResponse
229- }
230-
231- return response , nil
232- }
233-
234- func (provider * CohereProvider ) handleCohereChatCompletionRequest (ctx context.Context , reqBody * cohere.CohereChatRequest , key schemas.Key ) (* cohere.CohereChatResponse , interface {}, time.Duration , * schemas.BifrostError ) {
235- providerName := provider .GetProviderKey ()
236-
237- // Marshal request body
238- jsonBody , err := sonic .Marshal (reqBody )
239- if err != nil {
240- return nil , nil , time .Duration (0 ), & schemas.BifrostError {
241- IsBifrostError : true ,
242- Error : & schemas.ErrorField {
243- Message : schemas .ErrProviderJSONMarshaling ,
244- Error : err ,
245- },
246- }
247- }
248-
249- // Create request
250- req := fasthttp .AcquireRequest ()
251- resp := fasthttp .AcquireResponse ()
252- defer fasthttp .ReleaseRequest (req )
253- defer fasthttp .ReleaseResponse (resp )
254-
255- // Set any extra headers from network config
256- setExtraHeaders (req , provider .networkConfig .ExtraHeaders , nil )
298+ // Create response object from pool
299+ response := acquireCohereResponse ()
300+ defer releaseCohereResponse (response )
257301
258- req .SetRequestURI (provider .networkConfig .BaseURL + "/v2/chat" )
259- req .Header .SetMethod (http .MethodPost )
260- req .Header .SetContentType ("application/json" )
261- req .Header .Set ("Authorization" , "Bearer " + key .Value )
262-
263- req .SetBody (jsonBody )
264-
265- // Make request
266- latency , bifrostErr := makeRequestWithContext (ctx , provider .client , req , resp )
302+ rawResponse , bifrostErr := handleProviderResponse (responseBody , response , provider .sendBackRawResponse )
267303 if bifrostErr != nil {
268- return nil , nil , latency , bifrostErr
304+ return nil , bifrostErr
269305 }
270306
271- // Handle error response
272- if resp .StatusCode () != fasthttp .StatusOK {
273- provider .logger .Debug (fmt .Sprintf ("error from %s provider: %s" , providerName , string (resp .Body ())))
274-
275- var errorResp cohere.CohereError
276- bifrostErr := handleProviderAPIError (resp , & errorResp )
277- bifrostErr .Error .Message = errorResp .Message
278-
279- return nil , nil , latency , bifrostErr
280- }
307+ bifrostResponse := response .ToBifrostChatResponse (request .Model )
281308
282- // Parse Cohere v2 response
283- var cohereResponse cohere.CohereChatResponse
284- if err := sonic .Unmarshal (resp .Body (), & cohereResponse ); err != nil {
285- return nil , nil , latency , & schemas.BifrostError {
286- IsBifrostError : true ,
287- Error : & schemas.ErrorField {
288- Message : "error parsing Cohere v2 response" ,
289- Error : err ,
290- },
291- }
292- }
309+ // Set ExtraFields
310+ bifrostResponse .ExtraFields .Provider = provider .GetProviderKey ()
311+ bifrostResponse .ExtraFields .ModelRequested = request .Model
312+ bifrostResponse .ExtraFields .RequestType = schemas .ChatCompletionRequest
313+ bifrostResponse .ExtraFields .Latency = latency .Milliseconds ()
293314
294- // Parse raw response for sendBackRawResponse
295- var rawResponse interface {}
315+ // Set raw response if enabled
296316 if provider .sendBackRawResponse {
297- if err := sonic .Unmarshal (resp .Body (), & rawResponse ); err != nil {
298- return nil , nil , latency , & schemas.BifrostError {
299- IsBifrostError : true ,
300- Error : & schemas.ErrorField {
301- Message : "error parsing raw response" ,
302- Error : err ,
303- },
304- }
305- }
317+ bifrostResponse .ExtraFields .RawResponse = rawResponse
306318 }
307319
308- return & cohereResponse , rawResponse , latency , nil
320+ return bifrostResponse , nil
309321}
310322
311323// ChatCompletionStream performs a streaming chat completion request to the Cohere API.
@@ -572,24 +584,34 @@ func (provider *CohereProvider) Responses(ctx context.Context, key schemas.Key,
572584 return nil , newBifrostOperationError ("responses input is not provided" , nil , providerName )
573585 }
574586
575- cohereResponse , rawResponse , latency , err := provider .handleCohereChatCompletionRequest (ctx , reqBody , key )
587+ responseBody , latency , err := provider .completeRequest (ctx , reqBody , provider . networkConfig . BaseURL + "/v2/chat" , key . Value )
576588 if err != nil {
577589 return nil , err
578590 }
579591
580- // Convert Cohere v2 response to Bifrost response
581- response := cohereResponse .ToResponsesBifrostResponsesResponse ()
592+ // Create response object from pool
593+ response := acquireCohereResponse ()
594+ defer releaseCohereResponse (response )
582595
583- response .ExtraFields .Provider = providerName
584- response .ExtraFields .ModelRequested = request .Model
585- response .ExtraFields .RequestType = schemas .ResponsesRequest
586- response .ExtraFields .Latency = latency .Milliseconds ()
596+ rawResponse , bifrostErr := handleProviderResponse (responseBody , response , provider .sendBackRawResponse )
597+ if bifrostErr != nil {
598+ return nil , bifrostErr
599+ }
600+
601+ bifrostResponse := response .ToBifrostResponsesResponse ()
587602
603+ // Set ExtraFields
604+ bifrostResponse .ExtraFields .Provider = provider .GetProviderKey ()
605+ bifrostResponse .ExtraFields .ModelRequested = request .Model
606+ bifrostResponse .ExtraFields .RequestType = schemas .ResponsesRequest
607+ bifrostResponse .ExtraFields .Latency = latency .Milliseconds ()
608+
609+ // Set raw response if enabled
588610 if provider .sendBackRawResponse {
589- response .ExtraFields .RawResponse = rawResponse
611+ bifrostResponse .ExtraFields .RawResponse = rawResponse
590612 }
591613
592- return response , nil
614+ return bifrostResponse , nil
593615}
594616
595617// ResponsesStream performs a streaming responses request to the Cohere API.
@@ -792,72 +814,34 @@ func (provider *CohereProvider) Embedding(ctx context.Context, key schemas.Key,
792814 return nil , newBifrostOperationError ("embedding input is not provided" , nil , providerName )
793815 }
794816
795- // Marshal request body
796- jsonBody , err := sonic .Marshal (reqBody )
817+ responseBody , latency , err := provider .completeRequest (ctx , reqBody , provider .networkConfig .BaseURL + "/v2/embed" , key .Value )
797818 if err != nil {
798- return nil , newBifrostOperationError ( schemas . ErrProviderJSONMarshaling , err , providerName )
819+ return nil , err
799820 }
800821
801- // Create request
802- req := fasthttp .AcquireRequest ()
803- resp := fasthttp .AcquireResponse ()
804- defer fasthttp .ReleaseRequest (req )
805- defer fasthttp .ReleaseResponse (resp )
806-
807- // Set any extra headers from network config
808- setExtraHeaders (req , provider .networkConfig .ExtraHeaders , nil )
822+ // Create response object from pool
823+ response := acquireCohereEmbeddingResponse ()
824+ defer releaseCohereEmbeddingResponse (response )
809825
810- req .SetRequestURI (provider .networkConfig .BaseURL + "/v2/embed" )
811- req .Header .SetMethod (http .MethodPost )
812- req .Header .SetContentType ("application/json" )
813- req .Header .Set ("Authorization" , "Bearer " + key .Value )
814-
815- req .SetBody (jsonBody )
816-
817- // Make request
818- latency , bifrostErr := makeRequestWithContext (ctx , provider .client , req , resp )
826+ rawResponse , bifrostErr := handleProviderResponse (responseBody , response , provider .sendBackRawResponse )
819827 if bifrostErr != nil {
820828 return nil , bifrostErr
821829 }
822830
823- // Handle error response
824- if resp .StatusCode () != fasthttp .StatusOK {
825- provider .logger .Debug (fmt .Sprintf ("error from %s provider: %s" , providerName , string (resp .Body ())))
826-
827- var errorResp cohere.CohereError
828- bifrostErr := handleProviderAPIError (resp , & errorResp )
829- bifrostErr .Error .Message = errorResp .Message
830-
831- return nil , bifrostErr
832- }
833-
834- // Parse response
835- var cohereResp cohere.CohereEmbeddingResponse
836- if err := sonic .Unmarshal (resp .Body (), & cohereResp ); err != nil {
837- return nil , newBifrostOperationError ("error parsing embedding response" , err , providerName )
838- }
839-
840- // Parse raw response for consistent format
841- var rawResponse interface {}
842- if err := sonic .Unmarshal (resp .Body (), & rawResponse ); err != nil {
843- return nil , newBifrostOperationError ("error parsing raw response for embedding" , err , providerName )
844- }
831+ bifrostResponse := response .ToBifrostEmbeddingResponse ()
845832
846- // Create BifrostResponse
847- response := cohereResp .ToBifrostEmbeddingResponse ()
833+ // Set ExtraFields
834+ bifrostResponse .ExtraFields .Provider = provider .GetProviderKey ()
835+ bifrostResponse .ExtraFields .ModelRequested = request .Model
836+ bifrostResponse .ExtraFields .RequestType = schemas .EmbeddingRequest
837+ bifrostResponse .ExtraFields .Latency = latency .Milliseconds ()
848838
849- response .Model = request .Model
850- response .ExtraFields .Provider = providerName
851- response .ExtraFields .ModelRequested = request .Model
852- response .ExtraFields .RequestType = schemas .EmbeddingRequest
853- response .ExtraFields .Latency = latency .Milliseconds ()
854-
855- // Only include RawResponse if sendBackRawResponse is enabled
839+ // Set raw response if enabled
856840 if provider .sendBackRawResponse {
857- response .ExtraFields .RawResponse = rawResponse
841+ bifrostResponse .ExtraFields .RawResponse = rawResponse
858842 }
859843
860- return response , nil
844+ return bifrostResponse , nil
861845}
862846
863847// Speech is not supported by the Cohere provider.
0 commit comments