Skip to content

Commit ba3fe92

Browse files
move all providers to use fasthttp even for streaming
1 parent 6d7517f commit ba3fe92

File tree

13 files changed

+366
-603
lines changed

13 files changed

+366
-603
lines changed

core/providers/anthropic/anthropic.go

Lines changed: 44 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,9 @@ package anthropic
22

33
import (
44
"bufio"
5-
"bytes"
65
"context"
76
"errors"
87
"fmt"
9-
"io"
108
"net/http"
119
"strings"
1210
"sync"
@@ -22,7 +20,6 @@ import (
2220
type AnthropicProvider struct {
2321
logger schemas.Logger // Logger for provider operations
2422
client *fasthttp.Client // HTTP client for API requests
25-
streamClient *http.Client // HTTP client for streaming requests
2623
apiVersion string // API version for the provider
2724
networkConfig schemas.NetworkConfig // Network configuration including extra headers
2825
sendBackRawResponse bool // Whether to include raw response in BifrostResponse
@@ -83,11 +80,7 @@ func NewAnthropicProvider(config *schemas.ProviderConfig, logger schemas.Logger)
8380
MaxConnsPerHost: config.ConcurrencyAndBufferSize.Concurrency,
8481
}
8582

86-
// Initialize streaming HTTP client
87-
streamClient := &http.Client{
88-
Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds),
89-
}
90-
83+
9184
// Pre-warm response pools
9285
for i := 0; i < config.ConcurrencyAndBufferSize.Concurrency; i++ {
9386
anthropicTextResponsePool.Put(&AnthropicTextResponse{})
@@ -106,7 +99,6 @@ func NewAnthropicProvider(config *schemas.ProviderConfig, logger schemas.Logger)
10699
return &AnthropicProvider{
107100
logger: logger,
108101
client: client,
109-
streamClient: streamClient,
110102
apiVersion: "2023-06-01",
111103
networkConfig: config.NetworkConfig,
112104
sendBackRawResponse: config.SendBackRawResponse,
@@ -120,13 +112,10 @@ func (provider *AnthropicProvider) GetProviderKey() schemas.ModelProvider {
120112
}
121113

122114
// parseStreamAnthropicError parses Anthropic streaming error responses.
123-
func parseStreamAnthropicError(resp *http.Response, providerType schemas.ModelProvider) *schemas.BifrostError {
124-
statusCode := resp.StatusCode
125-
body, _ := io.ReadAll(resp.Body)
126-
resp.Body.Close()
127-
115+
func parseStreamAnthropicError(resp *fasthttp.Response, providerType schemas.ModelProvider) *schemas.BifrostError {
116+
statusCode := resp.StatusCode()
128117
var errorResp AnthropicError
129-
if err := sonic.Unmarshal(body, &errorResp); err != nil {
118+
if err := sonic.Unmarshal(resp.Body(), &errorResp); err != nil {
130119
return providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerType)
131120
}
132121

@@ -378,7 +367,7 @@ func (provider *AnthropicProvider) ChatCompletionStream(ctx context.Context, pos
378367
// Use shared Anthropic streaming logic
379368
return HandleAnthropicChatCompletionStreaming(
380369
ctx,
381-
provider.streamClient,
370+
provider.client,
382371
provider.networkConfig.BaseURL+"/v1/messages",
383372
reqBody,
384373
headers,
@@ -394,7 +383,7 @@ func (provider *AnthropicProvider) ChatCompletionStream(ctx context.Context, pos
394383
// This shared function reduces code duplication between providers that use the same SSE event format.
395384
func HandleAnthropicChatCompletionStreaming(
396385
ctx context.Context,
397-
httpClient *http.Client,
386+
client *fasthttp.Client,
398387
url string,
399388
requestBody interface{},
400389
headers map[string]string,
@@ -409,36 +398,28 @@ func HandleAnthropicChatCompletionStreaming(
409398
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, providerType)
410399
}
411400

412-
// Create HTTP request for streaming
413-
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
414-
if err != nil {
415-
if errors.Is(err, context.Canceled) {
416-
return nil, &schemas.BifrostError{
417-
IsBifrostError: false,
418-
Error: &schemas.ErrorField{
419-
Type: schemas.Ptr(schemas.RequestCancelled),
420-
Message: schemas.ErrRequestCancelled,
421-
Error: err,
422-
},
423-
}
424-
}
425-
if errors.Is(err, http.ErrHandlerTimeout) || errors.Is(err, context.DeadlineExceeded) {
426-
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestTimedOut, err, providerType)
427-
}
428-
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequest, err, providerType)
429-
}
401+
req := fasthttp.AcquireRequest()
402+
resp := fasthttp.AcquireResponse()
403+
defer fasthttp.ReleaseRequest(req)
404+
defer fasthttp.ReleaseResponse(resp)
405+
406+
req.Header.SetMethod(http.MethodPost)
407+
req.SetRequestURI(url)
408+
req.Header.SetContentType("application/json")
409+
providerUtils.SetExtraHeaders(req, extraHeaders, nil)
430410

431411
// Set headers
432412
for key, value := range headers {
433413
req.Header.Set(key, value)
434414
}
435415

436-
// Set any extra headers from network config
437-
providerUtils.SetExtraHeadersHTTP(req, extraHeaders, nil)
416+
req.SetBody(jsonBody)
417+
438418

439419
// Make the request
440-
resp, err := httpClient.Do(req)
441-
if err != nil {
420+
// Make the request
421+
_, bifrostErr := providerUtils.MakeRequestWithContext(ctx, client, req, resp)
422+
if bifrostErr != nil {
442423
if errors.Is(err, context.Canceled) {
443424
return nil, &schemas.BifrostError{
444425
IsBifrostError: false,
@@ -449,14 +430,14 @@ func HandleAnthropicChatCompletionStreaming(
449430
},
450431
}
451432
}
452-
if errors.Is(err, http.ErrHandlerTimeout) || errors.Is(err, context.DeadlineExceeded) {
433+
if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
453434
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestTimedOut, err, providerType)
454435
}
455436
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequest, err, providerType)
456437
}
457438

458439
// Check for HTTP errors
459-
if resp.StatusCode != http.StatusOK {
440+
if resp.StatusCode() != fasthttp.StatusOK {
460441
return nil, parseStreamAnthropicError(resp, providerType)
461442
}
462443

@@ -466,9 +447,8 @@ func HandleAnthropicChatCompletionStreaming(
466447
// Start streaming in a goroutine
467448
go func() {
468449
defer close(responseChan)
469-
defer resp.Body.Close()
470-
471-
scanner := bufio.NewScanner(resp.Body)
450+
451+
scanner := bufio.NewScanner(resp.BodyStream())
472452
chunkIndex := 0
473453

474454
startTime := time.Now()
@@ -659,35 +639,30 @@ func (provider *AnthropicProvider) ResponsesStream(ctx context.Context, postHook
659639
}
660640

661641
// Create HTTP request for streaming
662-
req, err := http.NewRequestWithContext(ctx, http.MethodPost, provider.networkConfig.BaseURL+"/v1/messages", bytes.NewReader(jsonBody))
663-
if err != nil {
664-
if errors.Is(err, context.Canceled) {
665-
return nil, &schemas.BifrostError{
666-
IsBifrostError: false,
667-
Error: &schemas.ErrorField{
668-
Type: schemas.Ptr(schemas.RequestCancelled),
669-
Message: schemas.ErrRequestCancelled,
670-
Error: err,
671-
},
672-
}
673-
}
674-
if errors.Is(err, http.ErrHandlerTimeout) || errors.Is(err, context.DeadlineExceeded) {
675-
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey())
676-
}
677-
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequest, err, provider.GetProviderKey())
678-
}
642+
req := fasthttp.AcquireRequest()
643+
resp := fasthttp.AcquireResponse()
644+
defer fasthttp.ReleaseRequest(req)
645+
defer fasthttp.ReleaseResponse(resp)
646+
647+
url := fmt.Sprintf("%s/v1/messages", provider.networkConfig.BaseURL)
648+
649+
req.Header.SetMethod(http.MethodPost)
650+
req.SetRequestURI(url)
651+
req.Header.SetContentType("application/json")
679652

680653
// Set headers
681654
for key, value := range headers {
682655
req.Header.Set(key, value)
683656
}
684-
685657
// Set any extra headers from network config
686-
providerUtils.SetExtraHeadersHTTP(req, provider.networkConfig.ExtraHeaders, nil)
658+
providerUtils.SetExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil)
659+
// Set body
660+
req.SetBody(jsonBody)
687661

688662
// Make the request
689-
resp, err := provider.streamClient.Do(req)
690-
if err != nil {
663+
// Make the request
664+
_, bifrostErr := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
665+
if bifrostErr != nil {
691666
if errors.Is(err, context.Canceled) {
692667
return nil, &schemas.BifrostError{
693668
IsBifrostError: false,
@@ -698,14 +673,14 @@ func (provider *AnthropicProvider) ResponsesStream(ctx context.Context, postHook
698673
},
699674
}
700675
}
701-
if errors.Is(err, http.ErrHandlerTimeout) || errors.Is(err, context.DeadlineExceeded) {
676+
if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
702677
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey())
703678
}
704679
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequest, err, provider.GetProviderKey())
705680
}
706681

707682
// Check for HTTP errors
708-
if resp.StatusCode != http.StatusOK {
683+
if resp.StatusCode() != fasthttp.StatusOK {
709684
return nil, parseStreamAnthropicError(resp, provider.GetProviderKey())
710685
}
711686

@@ -715,9 +690,8 @@ func (provider *AnthropicProvider) ResponsesStream(ctx context.Context, postHook
715690
// Start streaming in a goroutine
716691
go func() {
717692
defer close(responseChan)
718-
defer resp.Body.Close()
719-
720-
scanner := bufio.NewScanner(resp.Body)
693+
694+
scanner := bufio.NewScanner(resp.BodyStream())
721695
chunkIndex := 0
722696

723697
startTime := time.Now()

core/providers/azure/azure.go

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ const AzureAuthorizationTokenKey schemas.BifrostContextKey = "azure-authorizatio
2121
// AzureProvider implements the Provider interface for Azure's OpenAI API.
2222
type AzureProvider struct {
2323
logger schemas.Logger // Logger for provider operations
24-
client *fasthttp.Client // HTTP client for API requests
25-
streamClient *http.Client // HTTP client for streaming requests
24+
client *fasthttp.Client // HTTP client for API requests
2625
networkConfig schemas.NetworkConfig // Network configuration including extra headers
2726
sendBackRawResponse bool // Whether to include raw response in BifrostResponse
2827
}
@@ -39,18 +38,12 @@ func NewAzureProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*A
3938
MaxConnsPerHost: config.ConcurrencyAndBufferSize.Concurrency,
4039
}
4140

42-
// Initialize streaming HTTP client
43-
streamClient := &http.Client{
44-
Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds),
45-
}
46-
4741
// Configure proxy if provided
4842
client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger)
4943

5044
return &AzureProvider{
5145
logger: logger,
5246
client: client,
53-
streamClient: streamClient,
5447
networkConfig: config.NetworkConfig,
5548
sendBackRawResponse: config.SendBackRawResponse,
5649
}, nil
@@ -279,7 +272,7 @@ func (provider *AzureProvider) TextCompletionStream(ctx context.Context, postHoo
279272

280273
return openai.HandleOpenAITextCompletionStreaming(
281274
ctx,
282-
provider.streamClient,
275+
provider.client,
283276
url,
284277
request,
285278
authHeader,
@@ -364,7 +357,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo
364357
// Use shared streaming logic from OpenAI
365358
return openai.HandleOpenAIChatCompletionStreaming(
366359
ctx,
367-
provider.streamClient,
360+
provider.client,
368361
url,
369362
request,
370363
authHeader,
@@ -489,7 +482,7 @@ func (provider *AzureProvider) ResponsesStream(ctx context.Context, postHookRunn
489482
// Use shared streaming logic from OpenAI
490483
return openai.HandleOpenAIResponsesStreaming(
491484
ctx,
492-
provider.streamClient,
485+
provider.client,
493486
url,
494487
request,
495488
authHeader,

core/providers/cerebras.go

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ package providers
44

55
import (
66
"context"
7-
"net/http"
87
"strings"
98
"time"
109

@@ -17,8 +16,7 @@ import (
1716
// CerebrasProvider implements the Provider interface for Cerebras's API.
1817
type CerebrasProvider struct {
1918
logger schemas.Logger // Logger for provider operations
20-
client *fasthttp.Client // HTTP client for API requests
21-
streamClient *http.Client // HTTP client for streaming requests
19+
client *fasthttp.Client // HTTP client for API requests
2220
networkConfig schemas.NetworkConfig // Network configuration including extra headers
2321
sendBackRawResponse bool // Whether to include raw response in BifrostResponse
2422
}
@@ -35,11 +33,6 @@ func NewCerebrasProvider(config *schemas.ProviderConfig, logger schemas.Logger)
3533
MaxConnsPerHost: config.ConcurrencyAndBufferSize.BufferSize,
3634
}
3735

38-
// Initialize streaming HTTP client
39-
streamClient := &http.Client{
40-
Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds),
41-
}
42-
4336
// Configure proxy if provided
4437
client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger)
4538

@@ -51,8 +44,7 @@ func NewCerebrasProvider(config *schemas.ProviderConfig, logger schemas.Logger)
5144

5245
return &CerebrasProvider{
5346
logger: logger,
54-
client: client,
55-
streamClient: streamClient,
47+
client: client,
5648
networkConfig: config.NetworkConfig,
5749
sendBackRawResponse: config.SendBackRawResponse,
5850
}, nil
@@ -91,7 +83,7 @@ func (provider *CerebrasProvider) TextCompletion(ctx context.Context, key schema
9183
func (provider *CerebrasProvider) TextCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
9284
return openai.HandleOpenAITextCompletionStreaming(
9385
ctx,
94-
provider.streamClient,
86+
provider.client,
9587
provider.networkConfig.BaseURL+"/v1/completions",
9688
request,
9789
map[string]string{"Authorization": "Bearer " + key.Value},
@@ -126,7 +118,7 @@ func (provider *CerebrasProvider) ChatCompletionStream(ctx context.Context, post
126118
// Use shared OpenAI-compatible streaming logic
127119
return openai.HandleOpenAIChatCompletionStreaming(
128120
ctx,
129-
provider.streamClient,
121+
provider.client,
130122
provider.networkConfig.BaseURL+"/v1/chat/completions",
131123
request,
132124
map[string]string{"Authorization": "Bearer " + key.Value},

0 commit comments

Comments
 (0)