diff --git a/core/changelog.md b/core/changelog.md index e05637912..e6d310008 100644 --- a/core/changelog.md +++ b/core/changelog.md @@ -1,4 +1,6 @@ -- feat: Adds support for models list for providers +- feat: Use all keys for list models request +- refactor: Cohere provider to use completeRequest and response pooling for all requests +- chore: Added id, object, and model fields to Chat Completion responses from Bedrock and Cohere providers \ No newline at end of file diff --git a/core/providers/cohere.go b/core/providers/cohere.go index f50062098..2db09c8b4 100644 --- a/core/providers/cohere.go +++ b/core/providers/cohere.go @@ -30,6 +30,27 @@ var cohereResponsePool = sync.Pool{ }, } +// cohereEmbeddingResponsePool provides a pool for Cohere embedding response objects. +var cohereEmbeddingResponsePool = sync.Pool{ + New: func() interface{} { + return &cohere.CohereEmbeddingResponse{} + }, +} + +// acquireCohereEmbeddingResponse gets a Cohere embedding response from the pool and resets it. +func acquireCohereEmbeddingResponse() *cohere.CohereEmbeddingResponse { + resp := cohereEmbeddingResponsePool.Get().(*cohere.CohereEmbeddingResponse) + *resp = cohere.CohereEmbeddingResponse{} // Reset the struct + return resp +} + +// releaseCohereEmbeddingResponse returns a Cohere embedding response to the pool. +func releaseCohereEmbeddingResponse(resp *cohere.CohereEmbeddingResponse) { + if resp != nil { + cohereEmbeddingResponsePool.Put(resp) + } +} + // acquireCohereResponse gets a Cohere v2 response from the pool and resets it. 
func acquireCohereResponse() *cohere.CohereChatResponse { resp := cohereResponsePool.Get().(*cohere.CohereChatResponse) @@ -74,6 +95,7 @@ func NewCohereProvider(config *schemas.ProviderConfig, logger schemas.Logger) *C // Pre-warm response pools for i := 0; i < config.ConcurrencyAndBufferSize.Concurrency; i++ { cohereResponsePool.Put(&cohere.CohereChatResponse{}) + cohereEmbeddingResponsePool.Put(&cohere.CohereEmbeddingResponse{}) } // Set default BaseURL if not provided @@ -97,6 +119,64 @@ func (provider *CohereProvider) GetProviderKey() schemas.ModelProvider { return getProviderName(schemas.Cohere, provider.customProviderConfig) } +// completeRequest sends a request to Cohere's API and handles the response. +// It constructs the API URL, sets up authentication, and processes the response. +// Returns the response body or an error if the request fails. +func (provider *CohereProvider) completeRequest(ctx context.Context, requestBody interface{}, url string, key string) ([]byte, time.Duration, *schemas.BifrostError) { + // Marshal the request body + jsonData, err := sonic.Marshal(requestBody) + if err != nil { + return nil, 0, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, provider.GetProviderKey()) + } + + // Create the request with the JSON body + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + req.SetRequestURI(url) + req.Header.SetMethod(http.MethodPost) + req.Header.SetContentType("application/json") + req.Header.Set("Authorization", "Bearer "+key) + + req.SetBody(jsonData) + + // Send the request + latency, bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, latency, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + 
provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", provider.GetProviderKey(), string(resp.Body()))) + + var errorResp cohere.CohereError + + bifrostErr := handleProviderAPIError(resp, &errorResp) + bifrostErr.Type = &errorResp.Type + if bifrostErr.Error == nil { + bifrostErr.Error = &schemas.ErrorField{} + } + bifrostErr.Error.Message = errorResp.Message + if errorResp.Code != nil { + bifrostErr.Error.Code = errorResp.Code + } + + return nil, latency, bifrostErr + } + + // Read the response body and copy it before releasing the response + // to avoid use-after-free since resp.Body() references fasthttp's internal buffer + bodyCopy := append([]byte(nil), resp.Body()...) + + return bodyCopy, latency, nil +} + // listModelsByKey performs a list models request for a single key. // Returns the response and latency, or an error if the request fails. func (provider *CohereProvider) listModelsByKey(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { @@ -210,102 +290,34 @@ func (provider *CohereProvider) ChatCompletion(ctx context.Context, key schemas. 
return nil, newBifrostOperationError("chat completion input is not provided", nil, providerName) } - cohereResponse, rawResponse, latency, err := provider.handleCohereChatCompletionRequest(ctx, reqBody, key) + responseBody, latency, err := provider.completeRequest(ctx, reqBody, provider.networkConfig.BaseURL+"/v2/chat", key.Value) if err != nil { return nil, err } - // Convert Cohere v2 response to Bifrost response - response := cohereResponse.ToBifrostChatResponse() - - response.Model = request.Model - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.ChatCompletionRequest - response.ExtraFields.Latency = latency.Milliseconds() - - if provider.sendBackRawResponse { - response.ExtraFields.RawResponse = rawResponse - } - - return response, nil -} - -func (provider *CohereProvider) handleCohereChatCompletionRequest(ctx context.Context, reqBody *cohere.CohereChatRequest, key schemas.Key) (*cohere.CohereChatResponse, interface{}, time.Duration, *schemas.BifrostError) { - providerName := provider.GetProviderKey() + // Create response object from pool + response := acquireCohereResponse() + defer releaseCohereResponse(response) - // Marshal request body - jsonBody, err := sonic.Marshal(reqBody) - if err != nil { - return nil, nil, time.Duration(0), &schemas.BifrostError{ - IsBifrostError: true, - Error: &schemas.ErrorField{ - Message: schemas.ErrProviderJSONMarshaling, - Error: err, - }, - } - } - - // Create request - req := fasthttp.AcquireRequest() - resp := fasthttp.AcquireResponse() - defer fasthttp.ReleaseRequest(req) - defer fasthttp.ReleaseResponse(resp) - - // Set any extra headers from network config - setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) - - req.SetRequestURI(provider.networkConfig.BaseURL + "/v2/chat") - req.Header.SetMethod(http.MethodPost) - req.Header.SetContentType("application/json") - req.Header.Set("Authorization", "Bearer 
"+key.Value) - - req.SetBody(jsonBody) - - // Make request - latency, bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) if bifrostErr != nil { - return nil, nil, latency, bifrostErr + return nil, bifrostErr } - // Handle error response - if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) - - var errorResp cohere.CohereError - bifrostErr := handleProviderAPIError(resp, &errorResp) - bifrostErr.Error.Message = errorResp.Message - - return nil, nil, latency, bifrostErr - } + bifrostResponse := response.ToBifrostChatResponse(request.Model) - // Parse Cohere v2 response - var cohereResponse cohere.CohereChatResponse - if err := sonic.Unmarshal(resp.Body(), &cohereResponse); err != nil { - return nil, nil, latency, &schemas.BifrostError{ - IsBifrostError: true, - Error: &schemas.ErrorField{ - Message: "error parsing Cohere v2 response", - Error: err, - }, - } - } + // Set ExtraFields + bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() + bifrostResponse.ExtraFields.ModelRequested = request.Model + bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest + bifrostResponse.ExtraFields.Latency = latency.Milliseconds() - // Parse raw response for sendBackRawResponse - var rawResponse interface{} + // Set raw response if enabled if provider.sendBackRawResponse { - if err := sonic.Unmarshal(resp.Body(), &rawResponse); err != nil { - return nil, nil, latency, &schemas.BifrostError{ - IsBifrostError: true, - Error: &schemas.ErrorField{ - Message: "error parsing raw response", - Error: err, - }, - } - } + bifrostResponse.ExtraFields.RawResponse = rawResponse } - return &cohereResponse, rawResponse, latency, nil + return bifrostResponse, nil } // ChatCompletionStream performs a streaming chat completion request to the Cohere API. 
@@ -393,18 +405,19 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo // Create response channel responseChan := make(chan *schemas.BifrostStream, schemas.DefaultStreamBufferSize) - chunkIndex := -1 - // Start streaming in a goroutine go func() { defer close(responseChan) defer resp.Body.Close() scanner := bufio.NewScanner(resp.Body) - var responseID string + chunkIndex := 0 + startTime := time.Now() lastChunkTime := startTime + var responseID string + for scanner.Scan() { line := scanner.Text() @@ -437,113 +450,37 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo responseID = *event.ID } - // Create base response with current responseID - response := &schemas.BifrostChatResponse{ - ID: responseID, - Object: "chat.completion.chunk", - Model: request.Model, - Choices: []schemas.BifrostResponseChoice{ - { - Index: 0, - ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{ - Delta: &schemas.ChatStreamResponseChoiceDelta{}, - }, - }, - }, - ExtraFields: schemas.BifrostResponseExtraFields{ + response, bifrostErr, isLastChunk := event.ToBifrostChatCompletionStream() + if response != nil { + response.ID = responseID + response.ExtraFields = schemas.BifrostResponseExtraFields{ RequestType: schemas.ChatCompletionStreamRequest, Provider: providerName, ModelRequested: request.Model, ChunkIndex: chunkIndex, Latency: time.Since(lastChunkTime).Milliseconds(), - }, - } - lastChunkTime = time.Now() - - switch event.Type { - case cohere.StreamEventMessageStart: - if event.Delta != nil && event.Delta.Message != nil && event.Delta.Message.Role != nil { - response.Choices[0].ChatStreamResponseChoice.Delta.Role = event.Delta.Message.Role } + lastChunkTime = time.Now() + chunkIndex++ - case cohere.StreamEventContentDelta: - if event.Delta != nil && event.Delta.Message != nil && event.Delta.Message.Content != nil && event.Delta.Message.Content.CohereStreamContentObject != nil && 
event.Delta.Message.Content.CohereStreamContentObject.Text != nil { - // Try to cast content to CohereStreamContent - response.Choices[0].ChatStreamResponseChoice.Delta.Content = event.Delta.Message.Content.CohereStreamContentObject.Text + if provider.sendBackRawResponse { + response.ExtraFields.RawResponse = jsonData } - case cohere.StreamEventToolPlanDelta: - if event.Delta != nil && event.Delta.Message != nil && event.Delta.Message.ToolPlan != nil { - response.Choices[0].ChatStreamResponseChoice.Delta.Thought = event.Delta.Message.ToolPlan + processAndSendResponse(ctx, postHookRunner, getBifrostResponseForStreamResponse(nil, response, nil, nil, nil), responseChan, provider.logger) + if isLastChunk { + break } - - case cohere.StreamEventContentStart: - // Content start event - just continue, actual content comes in content-delta - - case cohere.StreamEventToolCallStart, cohere.StreamEventToolCallDelta: - if event.Delta != nil && event.Delta.Message != nil && event.Delta.Message.ToolCalls != nil && event.Delta.Message.ToolCalls.CohereToolCallObject != nil { - // Handle single tool call object (tool-call-start/delta events) - cohereToolCall := event.Delta.Message.ToolCalls.CohereToolCallObject - toolCall := schemas.ChatAssistantMessageToolCall{} - - if cohereToolCall.ID != nil { - toolCall.ID = cohereToolCall.ID - } - - if cohereToolCall.Function != nil { - if cohereToolCall.Function.Name != nil { - toolCall.Function.Name = cohereToolCall.Function.Name - } - toolCall.Function.Arguments = cohereToolCall.Function.Arguments - } - - response.Choices[0].ChatStreamResponseChoice.Delta.ToolCalls = []schemas.ChatAssistantMessageToolCall{toolCall} - } - - case cohere.StreamEventMessageEnd: - if event.Delta != nil { - // Set finish reason - if event.Delta.FinishReason != nil { - finishReason := string(*event.Delta.FinishReason) - response.Choices[0].FinishReason = &finishReason - } - - // Set usage information - if event.Delta.Usage != nil { - usage := 
&schemas.BifrostLLMUsage{} - if event.Delta.Usage.Tokens != nil { - if event.Delta.Usage.Tokens.InputTokens != nil { - usage.PromptTokens = int(*event.Delta.Usage.Tokens.InputTokens) - } - if event.Delta.Usage.Tokens.OutputTokens != nil { - usage.CompletionTokens = int(*event.Delta.Usage.Tokens.OutputTokens) - } - usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens - } - response.Usage = usage - } - - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) - response.ExtraFields.Latency = time.Since(startTime).Milliseconds() - } - - case cohere.StreamEventToolCallEnd, cohere.StreamEventContentEnd: - // These events just signal completion, no additional data needed - - default: - provider.logger.Debug(fmt.Sprintf("Unknown v2 stream event type: %s", event.Type)) - continue } + if bifrostErr != nil { + bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ + RequestType: schemas.ChatCompletionStreamRequest, + Provider: providerName, + ModelRequested: request.Model, + } - if provider.sendBackRawResponse { - response.ExtraFields.RawResponse = jsonData - } - - processAndSendResponse(ctx, postHookRunner, getBifrostResponseForStreamResponse(nil, response, nil, nil, nil), responseChan, provider.logger) - - // End stream after message-end - if event.Type == cohere.StreamEventMessageEnd { - return + processAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) + break } } } @@ -572,24 +509,34 @@ func (provider *CohereProvider) Responses(ctx context.Context, key schemas.Key, return nil, newBifrostOperationError("responses input is not provided", nil, providerName) } - cohereResponse, rawResponse, latency, err := provider.handleCohereChatCompletionRequest(ctx, reqBody, key) + responseBody, latency, err := provider.completeRequest(ctx, reqBody, provider.networkConfig.BaseURL+"/v2/chat", key.Value) if err != nil { return nil, err } - // Convert Cohere v2 response to Bifrost response - response := 
cohereResponse.ToResponsesBifrostResponsesResponse() + // Create response object from pool + response := acquireCohereResponse() + defer releaseCohereResponse(response) - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Latency = latency.Milliseconds() + rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + bifrostResponse := response.ToBifrostResponsesResponse() + + // Set ExtraFields + bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() + bifrostResponse.ExtraFields.ModelRequested = request.Model + bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest + bifrostResponse.ExtraFields.Latency = latency.Milliseconds() + // Set raw response if enabled if provider.sendBackRawResponse { - response.ExtraFields.RawResponse = rawResponse + bifrostResponse.ExtraFields.RawResponse = rawResponse } - return response, nil + return bifrostResponse, nil } // ResponsesStream performs a streaming responses request to the Cohere API. 
@@ -792,72 +739,34 @@ func (provider *CohereProvider) Embedding(ctx context.Context, key schemas.Key, return nil, newBifrostOperationError("embedding input is not provided", nil, providerName) } - // Marshal request body - jsonBody, err := sonic.Marshal(reqBody) + responseBody, latency, err := provider.completeRequest(ctx, reqBody, provider.networkConfig.BaseURL+"/v2/embed", key.Value) if err != nil { - return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, providerName) + return nil, err } - // Create request - req := fasthttp.AcquireRequest() - resp := fasthttp.AcquireResponse() - defer fasthttp.ReleaseRequest(req) - defer fasthttp.ReleaseResponse(resp) + // Create response object from pool + response := acquireCohereEmbeddingResponse() + defer releaseCohereEmbeddingResponse(response) - // Set any extra headers from network config - setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) - - req.SetRequestURI(provider.networkConfig.BaseURL + "/v2/embed") - req.Header.SetMethod(http.MethodPost) - req.Header.SetContentType("application/json") - req.Header.Set("Authorization", "Bearer "+key.Value) - - req.SetBody(jsonBody) - - // Make request - latency, bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) if bifrostErr != nil { return nil, bifrostErr } - // Handle error response - if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) + bifrostResponse := response.ToBifrostEmbeddingResponse() - var errorResp cohere.CohereError - bifrostErr := handleProviderAPIError(resp, &errorResp) - bifrostErr.Error.Message = errorResp.Message + // Set ExtraFields + bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() + bifrostResponse.ExtraFields.ModelRequested = request.Model + bifrostResponse.ExtraFields.RequestType = 
schemas.EmbeddingRequest + bifrostResponse.ExtraFields.Latency = latency.Milliseconds() - return nil, bifrostErr - } - - // Parse response - var cohereResp cohere.CohereEmbeddingResponse - if err := sonic.Unmarshal(resp.Body(), &cohereResp); err != nil { - return nil, newBifrostOperationError("error parsing embedding response", err, providerName) - } - - // Parse raw response for consistent format - var rawResponse interface{} - if err := sonic.Unmarshal(resp.Body(), &rawResponse); err != nil { - return nil, newBifrostOperationError("error parsing raw response for embedding", err, providerName) - } - - // Create BifrostResponse - response := cohereResp.ToBifrostEmbeddingResponse() - - response.Model = request.Model - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.EmbeddingRequest - response.ExtraFields.Latency = latency.Milliseconds() - - // Only include RawResponse if sendBackRawResponse is enabled + // Set raw response if enabled if provider.sendBackRawResponse { - response.ExtraFields.RawResponse = rawResponse + bifrostResponse.ExtraFields.RawResponse = rawResponse } - return response, nil + return bifrostResponse, nil } // Speech is not supported by the Cohere provider. diff --git a/core/providers/utils.go b/core/providers/utils.go index b8a1e465c..a9d6b3df4 100644 --- a/core/providers/utils.go +++ b/core/providers/utils.go @@ -766,6 +766,8 @@ func extractSuccessfulListModelsResponses( } // handleMultipleListModelsRequests handles multiple list models requests concurrently for different keys. +// It launches concurrent requests for all keys and waits for all goroutines to complete. +// It returns the aggregated response or an error if the request fails. 
func handleMultipleListModelsRequests( ctx context.Context, keys []schemas.Key, diff --git a/core/schemas/providers/anthropic/responses.go b/core/schemas/providers/anthropic/responses.go index fd9624f8c..049e085c5 100644 --- a/core/schemas/providers/anthropic/responses.go +++ b/core/schemas/providers/anthropic/responses.go @@ -190,7 +190,7 @@ func ToAnthropicResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) *A return anthropicReq } -// ToResponsesBifrostResponse converts an Anthropic response to BifrostResponse with Responses structure +// ToBifrostResponsesResponse converts an Anthropic response to BifrostResponse with Responses structure func (response *AnthropicMessageResponse) ToBifrostResponsesResponse() *schemas.BifrostResponsesResponse { if response == nil { return nil diff --git a/core/schemas/providers/cohere/chat.go b/core/schemas/providers/cohere/chat.go index ddada1d55..ae7765d8f 100644 --- a/core/schemas/providers/cohere/chat.go +++ b/core/schemas/providers/cohere/chat.go @@ -178,13 +178,14 @@ func ToCohereChatCompletionRequest(bifrostReq *schemas.BifrostChatRequest) *Cohe } // ToBifrostChatResponse converts a Cohere v2 response to Bifrost format -func (response *CohereChatResponse) ToBifrostChatResponse() *schemas.BifrostChatResponse { +func (response *CohereChatResponse) ToBifrostChatResponse(model string) *schemas.BifrostChatResponse { if response == nil { return nil } bifrostResponse := &schemas.BifrostChatResponse{ ID: response.ID, + Model: model, Object: "chat.completion", Choices: []schemas.BifrostResponseChoice{ { @@ -323,3 +324,156 @@ func (response *CohereChatResponse) ToBifrostChatResponse() *schemas.BifrostChat return bifrostResponse } + +func (chunk *CohereStreamEvent) ToBifrostChatCompletionStream() (*schemas.BifrostChatResponse, *schemas.BifrostError, bool) { + switch chunk.Type { + case StreamEventMessageStart: + if chunk.Delta != nil && chunk.Delta.Message != nil && chunk.Delta.Message.Role != nil { + // Create streaming 
response for this delta + streamResponse := &schemas.BifrostChatResponse{ + Object: "chat.completion.chunk", + Choices: []schemas.BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{ + Delta: &schemas.ChatStreamResponseChoiceDelta{ + Role: chunk.Delta.Message.Role, + }, + }, + }, + }, + } + + return streamResponse, nil, false + } + + case StreamEventContentDelta: + if chunk.Delta != nil && + chunk.Delta.Message != nil && + chunk.Delta.Message.Content != nil && + chunk.Delta.Message.Content.CohereStreamContentObject != nil && + chunk.Delta.Message.Content.CohereStreamContentObject.Text != nil { + // Try to cast content to CohereStreamContent + streamResponse := &schemas.BifrostChatResponse{ + Object: "chat.completion.chunk", + Choices: []schemas.BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{ + Delta: &schemas.ChatStreamResponseChoiceDelta{ + Content: chunk.Delta.Message.Content.CohereStreamContentObject.Text, + }, + }, + }, + }, + } + + return streamResponse, nil, false + } + + case StreamEventToolPlanDelta: + if chunk.Delta != nil && chunk.Delta.Message != nil && chunk.Delta.Message.ToolPlan != nil { + streamResponse := &schemas.BifrostChatResponse{ + Object: "chat.completion.chunk", + Choices: []schemas.BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{ + Delta: &schemas.ChatStreamResponseChoiceDelta{ + Thought: chunk.Delta.Message.ToolPlan, + }, + }, + }, + }, + } + + return streamResponse, nil, false + } + + case StreamEventContentStart: + // Content start event - just continue, actual content comes in content-delta + return nil, nil, false + + case StreamEventToolCallStart, StreamEventToolCallDelta: + if chunk.Delta != nil && chunk.Delta.Message != nil && chunk.Delta.Message.ToolCalls != nil && chunk.Delta.Message.ToolCalls.CohereToolCallObject != nil { + // Handle single tool call object 
(tool-call-start/delta events) + cohereToolCall := chunk.Delta.Message.ToolCalls.CohereToolCallObject + toolCall := schemas.ChatAssistantMessageToolCall{} + + if cohereToolCall.ID != nil { + toolCall.ID = cohereToolCall.ID + } + + if cohereToolCall.Function != nil { + if cohereToolCall.Function.Name != nil { + toolCall.Function.Name = cohereToolCall.Function.Name + } + toolCall.Function.Arguments = cohereToolCall.Function.Arguments + } + + streamResponse := &schemas.BifrostChatResponse{ + Object: "chat.completion.chunk", + Choices: []schemas.BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{ + Delta: &schemas.ChatStreamResponseChoiceDelta{ + ToolCalls: []schemas.ChatAssistantMessageToolCall{toolCall}, + }, + }, + }, + }, + } + + return streamResponse, nil, false + } + + case StreamEventToolCallEnd: + return nil, nil, false + + case StreamEventContentEnd: + return nil, nil, false + + case StreamEventMessageEnd: + if chunk.Delta != nil { + var finishReason *string + usage := &schemas.BifrostLLMUsage{} + // Set finish reason + if chunk.Delta.FinishReason != nil { + finishReason = schemas.Ptr(string(*chunk.Delta.FinishReason)) + } + + // Set usage information + if chunk.Delta.Usage != nil { + if chunk.Delta.Usage.Tokens != nil { + if chunk.Delta.Usage.Tokens.InputTokens != nil { + usage.PromptTokens = int(*chunk.Delta.Usage.Tokens.InputTokens) + } + if chunk.Delta.Usage.Tokens.OutputTokens != nil { + usage.CompletionTokens = int(*chunk.Delta.Usage.Tokens.OutputTokens) + } + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + } + } + + streamResponse := &schemas.BifrostChatResponse{ + Object: "chat.completion.chunk", + Choices: []schemas.BifrostResponseChoice{ + { + Index: 0, + FinishReason: finishReason, + ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{ + Delta: &schemas.ChatStreamResponseChoiceDelta{}, + }, + }, + }, + Usage: usage, + } + + return streamResponse, nil, true + } + return nil, 
nil, false + } + + return nil, nil, false +} diff --git a/core/schemas/providers/cohere/responses.go b/core/schemas/providers/cohere/responses.go index 000acc797..180839567 100644 --- a/core/schemas/providers/cohere/responses.go +++ b/core/schemas/providers/cohere/responses.go @@ -91,8 +91,8 @@ func ToCohereResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) *Cohe return cohereReq } -// ToResponsesBifrostResponse converts CohereChatResponse to BifrostResponse (Responses structure) -func (response *CohereChatResponse) ToResponsesBifrostResponsesResponse() *schemas.BifrostResponsesResponse { +// ToBifrostResponsesResponse converts CohereChatResponse to BifrostResponse (Responses structure) +func (response *CohereChatResponse) ToBifrostResponsesResponse() *schemas.BifrostResponsesResponse { if response == nil { return nil } diff --git a/docs/changelogs/v1.3.13.mdx b/docs/changelogs/v1.3.13.mdx index eb40c17f3..649991dfa 100644 --- a/docs/changelogs/v1.3.13.mdx +++ b/docs/changelogs/v1.3.13.mdx @@ -6,11 +6,14 @@ description: "v1.3.13 changelog" - chore: version update framework to 1.1.18 and core to 1.2.16 - Adds env variable support for postgres config +- feat: standardize finish reason and single response handling across providers +- feat: provider config hot reloading added (no need to restart Bifrost after updating provider configs now) -- feat: Adds support for models list for providers +- feat: standardize finish reason and single response handling across providers +- feat: provider config hot reloading added diff --git a/framework/changelog.md b/framework/changelog.md index 9a51bc0ed..626e93061 100644 --- a/framework/changelog.md +++ b/framework/changelog.md @@ -1,5 +1,4 @@ -- Adds env variable resolution for postgres config - chore: Upgrades core to 1.2.16 \ No newline at end of file diff --git a/transports/changelog.md b/transports/changelog.md index 1128e71a4..8427d5551 100644 --- a/transports/changelog.md +++ b/transports/changelog.md @@ -2,4 +2,6 
@@ - chore: version update framework to 1.1.18 and core to 1.2.16 -- Adds env variable support for postgres config \ No newline at end of file +- feat: Use all keys for list models request +- fix: handled panic when using Gemini models for Responses API requests through the OpenAI integration +- chore: Added id, object, and model fields to Chat Completion responses from Bedrock and Cohere providers \ No newline at end of file