Skip to content

Commit 34c87f1

Browse files
feat: core extended
1 parent c6170cd commit 34c87f1

File tree

17 files changed

+174
-116
lines changed

17 files changed

+174
-116
lines changed

core/bifrost.go

Lines changed: 51 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,20 @@ type ChannelMessage struct {
3232
// It handles request routing, provider management, and response processing.
3333
type Bifrost struct {
3434
ctx context.Context
35-
account schemas.Account // account interface
36-
plugins []schemas.Plugin // list of plugins
37-
requestQueues sync.Map // provider request queues (thread-safe)
38-
waitGroups sync.Map // wait groups for each provider (thread-safe)
39-
providerMutexes sync.Map // mutexes for each provider to prevent concurrent updates (thread-safe)
40-
channelMessagePool sync.Pool // Pool for ChannelMessage objects, initial pool size is set in Init
41-
responseChannelPool sync.Pool // Pool for response channels, initial pool size is set in Init
42-
errorChannelPool sync.Pool // Pool for error channels, initial pool size is set in Init
43-
responseStreamPool sync.Pool // Pool for response stream channels, initial pool size is set in Init
44-
pluginPipelinePool sync.Pool // Pool for PluginPipeline objects
45-
logger schemas.Logger // logger instance, default logger is used if not provided
46-
mcpManager *MCPManager // MCP integration manager (nil if MCP not configured)
47-
dropExcessRequests atomic.Bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead.
35+
account schemas.Account // account interface
36+
plugins []schemas.Plugin // list of plugins
37+
requestQueues sync.Map // provider request queues (thread-safe)
38+
waitGroups sync.Map // wait groups for each provider (thread-safe)
39+
providerMutexes sync.Map // mutexes for each provider to prevent concurrent updates (thread-safe)
40+
channelMessagePool sync.Pool // Pool for ChannelMessage objects, initial pool size is set in Init
41+
responseChannelPool sync.Pool // Pool for response channels, initial pool size is set in Init
42+
errorChannelPool sync.Pool // Pool for error channels, initial pool size is set in Init
43+
responseStreamPool sync.Pool // Pool for response stream channels, initial pool size is set in Init
44+
pluginPipelinePool sync.Pool // Pool for PluginPipeline objects
45+
logger schemas.Logger // logger instance, default logger is used if not provided
46+
mcpManager *MCPManager // MCP integration manager (nil if MCP not configured)
47+
dropExcessRequests atomic.Bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead.
48+
keySelector schemas.KeySelector // Custom key selector function
4849
}
4950

5051
// PluginPipeline encapsulates the execution of plugin PreHooks and PostHooks, tracks how many plugins ran, and manages short-circuiting and error aggregation.
@@ -85,9 +86,14 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) {
8586
plugins: config.Plugins,
8687
requestQueues: sync.Map{},
8788
waitGroups: sync.Map{},
89+
keySelector: config.KeySelector,
8890
}
8991
bifrost.dropExcessRequests.Store(config.DropExcessRequests)
9092

93+
if bifrost.keySelector == nil {
94+
bifrost.keySelector = WeightedRandomKeySelector
95+
}
96+
9197
// Initialize object pools
9298
bifrost.channelMessagePool = sync.Pool{
9399
New: func() interface{} {
@@ -328,12 +334,12 @@ func (bifrost *Bifrost) UpdateProviderConcurrency(providerKey schemas.ModelProvi
328334
return bifrost.prepareProvider(providerKey, providerConfig)
329335
}
330336

331-
oldQueue := oldQueueValue.(chan ChannelMessage)
337+
oldQueue := oldQueueValue.(chan *ChannelMessage)
332338

333339
bifrost.logger.Debug("gracefully stopping existing workers for provider %s", providerKey)
334340

335341
// Step 1: Create new queue with updated buffer size
336-
newQueue := make(chan ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize)
342+
newQueue := make(chan *ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize)
337343

338344
// Step 2: Transfer any buffered requests from old queue to new queue
339345
// This prevents request loss during the transition
@@ -349,7 +355,7 @@ func (bifrost *Bifrost) UpdateProviderConcurrency(providerKey schemas.ModelProvi
349355
// New queue is full, handle this request in a goroutine
350356
// This is unlikely with proper buffer sizing but provides safety
351357
transferWaitGroup.Add(1)
352-
go func(m ChannelMessage) {
358+
go func(m *ChannelMessage) {
353359
defer transferWaitGroup.Done()
354360
select {
355361
case newQueue <- m:
@@ -712,7 +718,7 @@ func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, confi
712718
return fmt.Errorf("failed to get config for provider: %v", err)
713719
}
714720

715-
queue := make(chan ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize) // Buffered channel per provider
721+
queue := make(chan *ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize) // Buffered channel per provider
716722

717723
bifrost.requestQueues.Store(providerKey, queue)
718724

@@ -738,13 +744,13 @@ func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, confi
738744
// If the queue doesn't exist, it creates one at runtime and initializes the provider,
739745
// given the provider config is provided in the account interface implementation.
740746
// This function takes a read lock for the lookup and, when the queue has to be created, a write lock with a double-check, to prevent race conditions during provider updates.
741-
func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (chan ChannelMessage, error) {
747+
func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (chan *ChannelMessage, error) {
742748
// Use read lock to allow concurrent reads but prevent concurrent updates
743749
providerMutex := bifrost.getProviderMutex(providerKey)
744750
providerMutex.RLock()
745751

746752
if queueValue, exists := bifrost.requestQueues.Load(providerKey); exists {
747-
queue := queueValue.(chan ChannelMessage)
753+
queue := queueValue.(chan *ChannelMessage)
748754
providerMutex.RUnlock()
749755
return queue, nil
750756
}
@@ -757,7 +763,7 @@ func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (cha
757763

758764
// Double-check after acquiring write lock (another goroutine might have created it)
759765
if queueValue, exists := bifrost.requestQueues.Load(providerKey); exists {
760-
queue := queueValue.(chan ChannelMessage)
766+
queue := queueValue.(chan *ChannelMessage)
761767
return queue, nil
762768
}
763769

@@ -773,7 +779,7 @@ func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (cha
773779
}
774780

775781
queueValue, _ := bifrost.requestQueues.Load(providerKey)
776-
queue := queueValue.(chan ChannelMessage)
782+
queue := queueValue.(chan *ChannelMessage)
777783

778784
return queue, nil
779785
}
@@ -992,7 +998,7 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont
992998
msg.Context = ctx
993999

9941000
select {
995-
case queue <- *msg:
1001+
case queue <- msg:
9961002
// Message was sent successfully
9971003
case <-ctx.Done():
9981004
bifrost.releaseChannelMessage(msg)
@@ -1004,7 +1010,7 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont
10041010
return nil, newBifrostErrorFromMsg("request dropped: queue is full")
10051011
}
10061012
select {
1007-
case queue <- *msg:
1013+
case queue <- msg:
10081014
// Message was sent successfully
10091015
case <-ctx.Done():
10101016
bifrost.releaseChannelMessage(msg)
@@ -1016,7 +1022,7 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont
10161022
var resp *schemas.BifrostResponse
10171023
select {
10181024
case result = <-msg.Response:
1019-
resp, bifrostErr := pipeline.RunPostHooks(&ctx, result, nil, len(bifrost.plugins))
1025+
resp, bifrostErr := pipeline.RunPostHooks(&msg.Context, result, nil, len(bifrost.plugins))
10201026
if bifrostErr != nil {
10211027
bifrost.releaseChannelMessage(msg)
10221028
return nil, bifrostErr
@@ -1025,7 +1031,7 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont
10251031
return resp, nil
10261032
case bifrostErrVal := <-msg.Err:
10271033
bifrostErrPtr := &bifrostErrVal
1028-
resp, bifrostErrPtr = pipeline.RunPostHooks(&ctx, nil, bifrostErrPtr, len(bifrost.plugins))
1034+
resp, bifrostErrPtr = pipeline.RunPostHooks(&msg.Context, nil, bifrostErrPtr, len(bifrost.plugins))
10291035
bifrost.releaseChannelMessage(msg)
10301036
if bifrostErrPtr != nil {
10311037
return nil, bifrostErrPtr
@@ -1110,7 +1116,7 @@ func (bifrost *Bifrost) tryStreamRequest(req *schemas.BifrostRequest, ctx contex
11101116
msg.Context = ctx
11111117

11121118
select {
1113-
case queue <- *msg:
1119+
case queue <- msg:
11141120
// Message was sent successfully
11151121
case <-ctx.Done():
11161122
bifrost.releaseChannelMessage(msg)
@@ -1122,7 +1128,7 @@ func (bifrost *Bifrost) tryStreamRequest(req *schemas.BifrostRequest, ctx contex
11221128
return nil, newBifrostErrorFromMsg("request dropped: queue is full")
11231129
}
11241130
select {
1125-
case queue <- *msg:
1131+
case queue <- msg:
11261132
// Message was sent successfully
11271133
case <-ctx.Done():
11281134
bifrost.releaseChannelMessage(msg)
@@ -1142,7 +1148,7 @@ func (bifrost *Bifrost) tryStreamRequest(req *schemas.BifrostRequest, ctx contex
11421148

11431149
// requestWorker handles incoming requests from the queue for a specific provider.
11441150
// It manages retries, error handling, and response processing.
1145-
func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas.ProviderConfig, queue chan ChannelMessage) {
1151+
func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas.ProviderConfig, queue chan *ChannelMessage) {
11461152
defer func() {
11471153
if waitGroupValue, ok := bifrost.waitGroups.Load(provider.GetProviderKey()); ok {
11481154
waitGroup := waitGroupValue.(*sync.WaitGroup)
@@ -1177,6 +1183,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
11771183
}
11781184
continue
11791185
}
1186+
req.Context = context.WithValue(req.Context, schemas.BifrostContextKeySelectedKey, key.ID)
11801187
}
11811188

11821189
// Track attempts
@@ -1212,12 +1219,12 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
12121219

12131220
// Attempt the request
12141221
if IsStreamRequestType(req.Type) {
1215-
stream, bifrostError = handleProviderStreamRequest(provider, &req, key, postHookRunner, req.Type)
1222+
stream, bifrostError = handleProviderStreamRequest(provider, req, key, postHookRunner, req.Type)
12161223
if bifrostError != nil && !bifrostError.IsBifrostError {
12171224
break // Don't retry client errors
12181225
}
12191226
} else {
1220-
result, bifrostError = handleProviderRequest(provider, &req, key, req.Type)
1227+
result, bifrostError = handleProviderRequest(provider, req, key, req.Type)
12211228
if bifrostError != nil {
12221229
break // Don't retry client errors
12231230
}
@@ -1518,9 +1525,19 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, prov
15181525
return supportedKeys[0], nil
15191526
}
15201527

1528+
selectedKey, err := bifrost.keySelector(ctx, supportedKeys, providerKey, model)
1529+
if err != nil {
1530+
return schemas.Key{}, err
1531+
}
1532+
1533+
return selectedKey, nil
1534+
1535+
}
1536+
1537+
func WeightedRandomKeySelector(ctx *context.Context, keys []schemas.Key, providerKey schemas.ModelProvider, model string) (schemas.Key, error) {
15211538
// Use a weighted random selection based on key weights
15221539
totalWeight := 0
1523-
for _, key := range supportedKeys {
1540+
for _, key := range keys {
15241541
totalWeight += int(key.Weight * 100) // Convert float to int for better performance
15251542
}
15261543

@@ -1530,15 +1547,15 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, prov
15301547

15311548
// Select key based on weight
15321549
currentWeight := 0
1533-
for _, key := range supportedKeys {
1550+
for _, key := range keys {
15341551
currentWeight += int(key.Weight * 100)
15351552
if randomValue < currentWeight {
15361553
return key, nil
15371554
}
15381555
}
15391556

15401557
// Fallback to first key if something goes wrong
1541-
return supportedKeys[0], nil
1558+
return keys[0], nil
15421559
}
15431560

15441561
// Shutdown gracefully stops all workers when triggered.
@@ -1548,7 +1565,7 @@ func (bifrost *Bifrost) Shutdown() {
15481565

15491566
// Close all provider queues to signal workers to stop
15501567
bifrost.requestQueues.Range(func(key, value interface{}) bool {
1551-
close(value.(chan ChannelMessage))
1568+
close(value.(chan *ChannelMessage))
15521569
return true
15531570
})
15541571

core/providers/anthropic.go

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -280,11 +280,11 @@ func (provider *AnthropicProvider) prepareTextCompletionParams(params map[string
280280
// completeRequest sends a request to Anthropic's API and handles the response.
281281
// It constructs the API URL, sets up authentication, and processes the response.
282282
// Returns the response body, the request latency, and an error if the request fails.
283-
func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestBody map[string]interface{}, url string, key string) ([]byte, *schemas.BifrostError) {
283+
func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestBody map[string]interface{}, url string, key string) ([]byte, time.Duration, *schemas.BifrostError) {
284284
// Marshal the request body
285285
jsonData, err := sonic.Marshal(requestBody)
286286
if err != nil {
287-
return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, provider.GetProviderKey())
287+
return nil, 0, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, provider.GetProviderKey())
288288
}
289289

290290
// Create the request with the JSON body
@@ -305,9 +305,9 @@ func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestB
305305
req.SetBody(jsonData)
306306

307307
// Send the request
308-
bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp)
308+
latency, bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp)
309309
if bifrostErr != nil {
310-
return nil, bifrostErr
310+
return nil, latency, bifrostErr
311311
}
312312

313313
// Handle error response
@@ -320,13 +320,13 @@ func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestB
320320
bifrostErr.Error.Type = &errorResp.Error.Type
321321
bifrostErr.Error.Message = errorResp.Error.Message
322322

323-
return nil, bifrostErr
323+
return nil, latency, bifrostErr
324324
}
325325

326326
// Read the response body
327327
body := resp.Body()
328328

329-
return body, nil
329+
return body, latency, nil
330330
}
331331

332332
// TextCompletion performs a text completion request to Anthropic's API.
@@ -345,7 +345,7 @@ func (provider *AnthropicProvider) TextCompletion(ctx context.Context, model str
345345
"prompt": fmt.Sprintf("\n\nHuman: %s\n\nAssistant:", text),
346346
}, preparedParams)
347347

348-
responseBody, err := provider.completeRequest(ctx, requestBody, provider.networkConfig.BaseURL+"/v1/complete", key.Value)
348+
responseBody, latency, err := provider.completeRequest(ctx, requestBody, provider.networkConfig.BaseURL+"/v1/complete", key.Value)
349349
if err != nil {
350350
return nil, err
351351
}
@@ -383,6 +383,7 @@ func (provider *AnthropicProvider) TextCompletion(ctx context.Context, model str
383383
Model: response.Model,
384384
ExtraFields: schemas.BifrostResponseExtraFields{
385385
Provider: provider.GetProviderKey(),
386+
Latency: latency.Milliseconds(),
386387
},
387388
}
388389

@@ -414,7 +415,7 @@ func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, model str
414415
"messages": formattedMessages,
415416
}, preparedParams)
416417

417-
responseBody, err := provider.completeRequest(ctx, requestBody, provider.networkConfig.BaseURL+"/v1/messages", key.Value)
418+
responseBody, latency, err := provider.completeRequest(ctx, requestBody, provider.networkConfig.BaseURL+"/v1/messages", key.Value)
418419
if err != nil {
419420
return nil, err
420421
}
@@ -437,6 +438,7 @@ func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, model str
437438

438439
bifrostResponse.ExtraFields = schemas.BifrostResponseExtraFields{
439440
Provider: provider.GetProviderKey(),
441+
Latency: latency.Milliseconds(),
440442
}
441443

442444
// Set raw response if enabled

0 commit comments

Comments
 (0)