Skip to content

Commit 555c4fc

Browse files
feat: core extended
1 parent 5e6bb75 commit 555c4fc

File tree

23 files changed

+213
-165
lines changed

23 files changed

+213
-165
lines changed

core/bifrost.go

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ type Bifrost struct {
4646
logger schemas.Logger // logger instance, default logger is used if not provided
4747
mcpManager *MCPManager // MCP integration manager (nil if MCP not configured)
4848
dropExcessRequests atomic.Bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead.
49+
keySelector schemas.KeySelector // Custom key selector function
4950
}
5051

5152
// PluginPipeline encapsulates the execution of plugin PreHooks and PostHooks, tracks how many plugins ran, and manages short-circuiting and error aggregation.
@@ -86,10 +87,15 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) {
8687
plugins: atomic.Pointer[[]schemas.Plugin]{},
8788
requestQueues: sync.Map{},
8889
waitGroups: sync.Map{},
90+
keySelector: config.KeySelector,
8991
}
9092
bifrost.plugins.Store(&config.Plugins)
9193
bifrost.dropExcessRequests.Store(config.DropExcessRequests)
9294

95+
if bifrost.keySelector == nil {
96+
bifrost.keySelector = WeightedRandomKeySelector
97+
}
98+
9399
// Initialize object pools
94100
bifrost.channelMessagePool = sync.Pool{
95101
New: func() interface{} {
@@ -632,12 +638,12 @@ func (bifrost *Bifrost) UpdateProviderConcurrency(providerKey schemas.ModelProvi
632638
return bifrost.prepareProvider(providerKey, providerConfig)
633639
}
634640

635-
oldQueue := oldQueueValue.(chan ChannelMessage)
641+
oldQueue := oldQueueValue.(chan *ChannelMessage)
636642

637643
bifrost.logger.Debug("gracefully stopping existing workers for provider %s", providerKey)
638644

639645
// Step 1: Create new queue with updated buffer size
640-
newQueue := make(chan ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize)
646+
newQueue := make(chan *ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize)
641647

642648
// Step 2: Transfer any buffered requests from old queue to new queue
643649
// This prevents request loss during the transition
@@ -653,7 +659,7 @@ func (bifrost *Bifrost) UpdateProviderConcurrency(providerKey schemas.ModelProvi
653659
// New queue is full, handle this request in a goroutine
654660
// This is unlikely with proper buffer sizing but provides safety
655661
transferWaitGroup.Add(1)
656-
go func(m ChannelMessage) {
662+
go func(m *ChannelMessage) {
657663
defer transferWaitGroup.Done()
658664
select {
659665
case newQueue <- m:
@@ -1017,7 +1023,7 @@ func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, confi
10171023
return fmt.Errorf("failed to get config for provider: %v", err)
10181024
}
10191025

1020-
queue := make(chan ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize) // Buffered channel per provider
1026+
queue := make(chan *ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize) // Buffered channel per provider
10211027

10221028
bifrost.requestQueues.Store(providerKey, queue)
10231029

@@ -1044,13 +1050,13 @@ func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, confi
10441050
// If the queue doesn't exist, it creates one at runtime and initializes the provider,
10451051
// given the provider config is provided in the account interface implementation.
10461052
// This function uses read locks to prevent race conditions during provider updates.
1047-
func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (chan ChannelMessage, error) {
1053+
func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (chan *ChannelMessage, error) {
10481054
// Use read lock to allow concurrent reads but prevent concurrent updates
10491055
providerMutex := bifrost.getProviderMutex(providerKey)
10501056
providerMutex.RLock()
10511057

10521058
if queueValue, exists := bifrost.requestQueues.Load(providerKey); exists {
1053-
queue := queueValue.(chan ChannelMessage)
1059+
queue := queueValue.(chan *ChannelMessage)
10541060
providerMutex.RUnlock()
10551061
return queue, nil
10561062
}
@@ -1063,7 +1069,7 @@ func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (cha
10631069

10641070
// Double-check after acquiring write lock (another goroutine might have created it)
10651071
if queueValue, exists := bifrost.requestQueues.Load(providerKey); exists {
1066-
queue := queueValue.(chan ChannelMessage)
1072+
queue := queueValue.(chan *ChannelMessage)
10671073
return queue, nil
10681074
}
10691075

@@ -1079,7 +1085,7 @@ func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (cha
10791085
}
10801086

10811087
queueValue, _ := bifrost.requestQueues.Load(providerKey)
1082-
queue := queueValue.(chan ChannelMessage)
1088+
queue := queueValue.(chan *ChannelMessage)
10831089

10841090
return queue, nil
10851091
}
@@ -1341,9 +1347,8 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont
13411347

13421348
msg := bifrost.getChannelMessage(*preReq)
13431349
msg.Context = ctx
1344-
startTime := time.Now()
13451350
select {
1346-
case queue <- *msg:
1351+
case queue <- msg:
13471352
// Message was sent successfully
13481353
case <-ctx.Done():
13491354
bifrost.releaseChannelMessage(msg)
@@ -1355,7 +1360,7 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont
13551360
return nil, newBifrostErrorFromMsg("request dropped: queue is full")
13561361
}
13571362
select {
1358-
case queue <- *msg:
1363+
case queue <- msg:
13591364
// Message was sent successfully
13601365
case <-ctx.Done():
13611366
bifrost.releaseChannelMessage(msg)
@@ -1368,11 +1373,7 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont
13681373
pluginCount := len(*bifrost.plugins.Load())
13691374
select {
13701375
case result = <-msg.Response:
1371-
latency := time.Since(startTime).Milliseconds()
1372-
if result.ExtraFields.Latency == nil {
1373-
result.ExtraFields.Latency = Ptr(float64(latency))
1374-
}
1375-
resp, bifrostErr := pipeline.RunPostHooks(&ctx, result, nil, pluginCount)
1376+
resp, bifrostErr := pipeline.RunPostHooks(&msg.Context, result, nil, pluginCount)
13761377
if bifrostErr != nil {
13771378
bifrost.releaseChannelMessage(msg)
13781379
return nil, bifrostErr
@@ -1381,7 +1382,7 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont
13811382
return resp, nil
13821383
case bifrostErrVal := <-msg.Err:
13831384
bifrostErrPtr := &bifrostErrVal
1384-
resp, bifrostErrPtr = pipeline.RunPostHooks(&ctx, nil, bifrostErrPtr, pluginCount)
1385+
resp, bifrostErrPtr = pipeline.RunPostHooks(&msg.Context, nil, bifrostErrPtr, pluginCount)
13851386
bifrost.releaseChannelMessage(msg)
13861387
if bifrostErrPtr != nil {
13871388
return nil, bifrostErrPtr
@@ -1463,7 +1464,7 @@ func (bifrost *Bifrost) tryStreamRequest(req *schemas.BifrostRequest, ctx contex
14631464
msg.Context = ctx
14641465

14651466
select {
1466-
case queue <- *msg:
1467+
case queue <- msg:
14671468
// Message was sent successfully
14681469
case <-ctx.Done():
14691470
bifrost.releaseChannelMessage(msg)
@@ -1475,7 +1476,7 @@ func (bifrost *Bifrost) tryStreamRequest(req *schemas.BifrostRequest, ctx contex
14751476
return nil, newBifrostErrorFromMsg("request dropped: queue is full")
14761477
}
14771478
select {
1478-
case queue <- *msg:
1479+
case queue <- msg:
14791480
// Message was sent successfully
14801481
case <-ctx.Done():
14811482
bifrost.releaseChannelMessage(msg)
@@ -1506,7 +1507,7 @@ func (bifrost *Bifrost) tryStreamRequest(req *schemas.BifrostRequest, ctx contex
15061507

15071508
// requestWorker handles incoming requests from the queue for a specific provider.
15081509
// It manages retries, error handling, and response processing.
1509-
func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas.ProviderConfig, queue chan ChannelMessage) {
1510+
func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas.ProviderConfig, queue chan *ChannelMessage) {
15101511
defer func() {
15111512
if waitGroupValue, ok := bifrost.waitGroups.Load(provider.GetProviderKey()); ok {
15121513
waitGroup := waitGroupValue.(*sync.WaitGroup)
@@ -1541,6 +1542,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
15411542
}
15421543
continue
15431544
}
1545+
req.Context = context.WithValue(req.Context, schemas.BifrostContextKeySelectedKey, key.ID)
15441546
}
15451547

15461548
// Track attempts
@@ -1576,12 +1578,12 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
15761578

15771579
// Attempt the request
15781580
if IsStreamRequestType(req.RequestType) {
1579-
stream, bifrostError = handleProviderStreamRequest(provider, &req, key, postHookRunner)
1581+
stream, bifrostError = handleProviderStreamRequest(provider, req, key, postHookRunner)
15801582
if bifrostError != nil && !bifrostError.IsBifrostError {
15811583
break // Don't retry client errors
15821584
}
15831585
} else {
1584-
result, bifrostError = handleProviderRequest(provider, &req, key)
1586+
result, bifrostError = handleProviderRequest(provider, req, key)
15851587
if bifrostError != nil {
15861588
break // Don't retry client errors
15871589
}
@@ -1930,9 +1932,19 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, prov
19301932
return supportedKeys[0], nil
19311933
}
19321934

1935+
selectedKey, err := bifrost.keySelector(ctx, supportedKeys, providerKey, model)
1936+
if err != nil {
1937+
return schemas.Key{}, err
1938+
}
1939+
1940+
return selectedKey, nil
1941+
1942+
}
1943+
1944+
func WeightedRandomKeySelector(ctx *context.Context, keys []schemas.Key, providerKey schemas.ModelProvider, model string) (schemas.Key, error) {
19331945
// Use a weighted random selection based on key weights
19341946
totalWeight := 0
1935-
for _, key := range supportedKeys {
1947+
for _, key := range keys {
19361948
totalWeight += int(key.Weight * 100) // Convert float to int for better performance
19371949
}
19381950

@@ -1942,15 +1954,15 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, prov
19421954

19431955
// Select key based on weight
19441956
currentWeight := 0
1945-
for _, key := range supportedKeys {
1957+
for _, key := range keys {
19461958
currentWeight += int(key.Weight * 100)
19471959
if randomValue < currentWeight {
19481960
return key, nil
19491961
}
19501962
}
19511963

19521964
// Fallback to first key if something goes wrong
1953-
return supportedKeys[0], nil
1965+
return keys[0], nil
19541966
}
19551967

19561968
// Shutdown gracefully stops all workers when triggered.
@@ -1960,7 +1972,7 @@ func (bifrost *Bifrost) Shutdown() {
19601972

19611973
// Close all provider queues to signal workers to stop
19621974
bifrost.requestQueues.Range(func(key, value interface{}) bool {
1963-
close(value.(chan ChannelMessage))
1975+
close(value.(chan *ChannelMessage))
19641976
return true
19651977
})
19661978

core/providers/anthropic.go

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -124,11 +124,11 @@ func (provider *AnthropicProvider) GetProviderKey() schemas.ModelProvider {
124124
// completeRequest sends a request to Anthropic's API and handles the response.
125125
// It constructs the API URL, sets up authentication, and processes the response.
126126
// Returns the response body or an error if the request fails.
127-
func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestBody interface{}, url string, key string) ([]byte, *schemas.BifrostError) {
127+
func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestBody interface{}, url string, key string) ([]byte, time.Duration, *schemas.BifrostError) {
128128
// Marshal the request body
129129
jsonData, err := sonic.Marshal(requestBody)
130130
if err != nil {
131-
return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, provider.GetProviderKey())
131+
return nil, 0, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, provider.GetProviderKey())
132132
}
133133

134134
// Create the request with the JSON body
@@ -149,9 +149,9 @@ func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestB
149149
req.SetBody(jsonData)
150150

151151
// Send the request
152-
bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp)
152+
latency, bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp)
153153
if bifrostErr != nil {
154-
return nil, bifrostErr
154+
return nil, latency, bifrostErr
155155
}
156156

157157
// Handle error response
@@ -164,13 +164,14 @@ func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestB
164164
bifrostErr.Error.Type = &errorResp.Error.Type
165165
bifrostErr.Error.Message = errorResp.Error.Message
166166

167-
return nil, bifrostErr
167+
return nil, latency, bifrostErr
168168
}
169169

170-
// Read the response body
171-
body := resp.Body()
170+
// Read the response body and copy it before releasing the response
171+
// to avoid use-after-free since resp.Body() references fasthttp's internal buffer
172+
bodyCopy := append([]byte(nil), resp.Body()...)
172173

173-
return body, nil
174+
return bodyCopy, latency, nil
174175
}
175176

176177
// TextCompletion performs a text completion request to Anthropic's API.
@@ -188,7 +189,7 @@ func (provider *AnthropicProvider) TextCompletion(ctx context.Context, key schem
188189
}
189190

190191
// Use struct directly for JSON marshaling
191-
responseBody, err := provider.completeRequest(ctx, reqBody, provider.networkConfig.BaseURL+"/v1/complete", key.Value)
192+
responseBody, latency, err := provider.completeRequest(ctx, reqBody, provider.networkConfig.BaseURL+"/v1/complete", key.Value)
192193
if err != nil {
193194
return nil, err
194195
}
@@ -208,6 +209,7 @@ func (provider *AnthropicProvider) TextCompletion(ctx context.Context, key schem
208209
bifrostResponse.ExtraFields.Provider = provider.GetProviderKey()
209210
bifrostResponse.ExtraFields.ModelRequested = request.Model
210211
bifrostResponse.ExtraFields.RequestType = schemas.TextCompletionRequest
212+
bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
211213

212214
// Set raw response if enabled
213215
if provider.sendBackRawResponse {
@@ -239,7 +241,7 @@ func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, key schem
239241
}
240242

241243
// Use struct directly for JSON marshaling
242-
responseBody, err := provider.completeRequest(ctx, reqBody, provider.networkConfig.BaseURL+"/v1/messages", key.Value)
244+
responseBody, latency, err := provider.completeRequest(ctx, reqBody, provider.networkConfig.BaseURL+"/v1/messages", key.Value)
243245
if err != nil {
244246
return nil, err
245247
}
@@ -260,6 +262,7 @@ func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, key schem
260262
bifrostResponse.ExtraFields.Provider = provider.GetProviderKey()
261263
bifrostResponse.ExtraFields.ModelRequested = request.Model
262264
bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest
265+
bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
263266

264267
// Set raw response if enabled
265268
if provider.sendBackRawResponse {
@@ -284,7 +287,7 @@ func (provider *AnthropicProvider) Responses(ctx context.Context, key schemas.Ke
284287
}
285288

286289
// Use struct directly for JSON marshaling
287-
responseBody, err := provider.completeRequest(ctx, reqBody, provider.networkConfig.BaseURL+"/v1/messages", key.Value)
290+
responseBody, latency, err := provider.completeRequest(ctx, reqBody, provider.networkConfig.BaseURL+"/v1/messages", key.Value)
288291
if err != nil {
289292
return nil, err
290293
}
@@ -305,6 +308,7 @@ func (provider *AnthropicProvider) Responses(ctx context.Context, key schemas.Ke
305308
bifrostResponse.ExtraFields.Provider = provider.GetProviderKey()
306309
bifrostResponse.ExtraFields.ModelRequested = request.Model
307310
bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest
311+
bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
308312

309313
// Set raw response if enabled
310314
if provider.sendBackRawResponse {

0 commit comments

Comments
 (0)