@@ -32,19 +32,20 @@ type ChannelMessage struct {
 // It handles request routing, provider management, and response processing.
 type Bifrost struct {
 	ctx                 context.Context
-	account             schemas.Account  // account interface
-	plugins             []schemas.Plugin // list of plugins
-	requestQueues       sync.Map         // provider request queues (thread-safe)
-	waitGroups          sync.Map         // wait groups for each provider (thread-safe)
-	providerMutexes     sync.Map         // mutexes for each provider to prevent concurrent updates (thread-safe)
-	channelMessagePool  sync.Pool        // Pool for ChannelMessage objects, initial pool size is set in Init
-	responseChannelPool sync.Pool        // Pool for response channels, initial pool size is set in Init
-	errorChannelPool    sync.Pool        // Pool for error channels, initial pool size is set in Init
-	responseStreamPool  sync.Pool        // Pool for response stream channels, initial pool size is set in Init
-	pluginPipelinePool  sync.Pool        // Pool for PluginPipeline objects
-	logger              schemas.Logger   // logger instance, default logger is used if not provided
-	mcpManager          *MCPManager      // MCP integration manager (nil if MCP not configured)
-	dropExcessRequests  atomic.Bool      // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead.
+	account             schemas.Account     // account interface
+	plugins             []schemas.Plugin    // list of plugins
+	requestQueues       sync.Map            // provider request queues (thread-safe)
+	waitGroups          sync.Map            // wait groups for each provider (thread-safe)
+	providerMutexes     sync.Map            // mutexes for each provider to prevent concurrent updates (thread-safe)
+	channelMessagePool  sync.Pool           // Pool for ChannelMessage objects, initial pool size is set in Init
+	responseChannelPool sync.Pool           // Pool for response channels, initial pool size is set in Init
+	errorChannelPool    sync.Pool           // Pool for error channels, initial pool size is set in Init
+	responseStreamPool  sync.Pool           // Pool for response stream channels, initial pool size is set in Init
+	pluginPipelinePool  sync.Pool           // Pool for PluginPipeline objects
+	logger              schemas.Logger      // logger instance, default logger is used if not provided
+	mcpManager          *MCPManager         // MCP integration manager (nil if MCP not configured)
+	dropExcessRequests  atomic.Bool         // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead.
+	keySelector         schemas.KeySelector // Custom key selector function
 }
 
 // PluginPipeline encapsulates the execution of plugin PreHooks and PostHooks, tracks how many plugins ran, and manages short-circuiting and error aggregation.
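The new `keySelector` field makes key selection pluggable. Judging from the `WeightedRandomKeySelector` signature added further down, `schemas.KeySelector` is a function of the form `func(*context.Context, []schemas.Key, schemas.ModelProvider, string) (schemas.Key, error)`. A minimal sketch of supplying a custom selector through `schemas.BifrostConfig` (the import path and `myAccount` are assumptions, not from this diff):

```go
package main

import (
	"context"
	"fmt"

	bifrost "github.com/maximhq/bifrost/core" // assumed module path
	"github.com/maximhq/bifrost/core/schemas"
)

// firstKeySelector always picks the first supported key, overriding the
// weighted-random default. It matches the selector signature in this diff.
func firstKeySelector(ctx *context.Context, keys []schemas.Key, providerKey schemas.ModelProvider, model string) (schemas.Key, error) {
	if len(keys) == 0 {
		return schemas.Key{}, fmt.Errorf("no keys for provider %s", providerKey)
	}
	return keys[0], nil
}

func main() {
	client, err := bifrost.Init(context.Background(), schemas.BifrostConfig{
		Account:     myAccount{},      // myAccount: hypothetical schemas.Account implementation
		KeySelector: firstKeySelector, // leave nil to fall back to WeightedRandomKeySelector
	})
	if err != nil {
		panic(err)
	}
	defer client.Shutdown()
}
```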
@@ -85,9 +86,14 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) {
 		plugins:       config.Plugins,
 		requestQueues: sync.Map{},
 		waitGroups:    sync.Map{},
+		keySelector:   config.KeySelector,
 	}
 	bifrost.dropExcessRequests.Store(config.DropExcessRequests)
 
+	if bifrost.keySelector == nil {
+		bifrost.keySelector = WeightedRandomKeySelector
+	}
+
 	// Initialize object pools
 	bifrost.channelMessagePool = sync.Pool{
 		New: func() interface{} {
@@ -328,12 +334,12 @@ func (bifrost *Bifrost) UpdateProviderConcurrency(providerKey schemas.ModelProvi
 		return bifrost.prepareProvider(providerKey, providerConfig)
 	}
 
-	oldQueue := oldQueueValue.(chan ChannelMessage)
+	oldQueue := oldQueueValue.(chan *ChannelMessage)
 
 	bifrost.logger.Debug("gracefully stopping existing workers for provider %s", providerKey)
 
 	// Step 1: Create new queue with updated buffer size
-	newQueue := make(chan ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize)
+	newQueue := make(chan *ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize)
 
 	// Step 2: Transfer any buffered requests from old queue to new queue
 	// This prevents request loss during the transition
@@ -349,7 +355,7 @@ func (bifrost *Bifrost) UpdateProviderConcurrency(providerKey schemas.ModelProvi
 			// New queue is full, handle this request in a goroutine
 			// This is unlikely with proper buffer sizing but provides safety
 			transferWaitGroup.Add(1)
-			go func(m ChannelMessage) {
+			go func(m *ChannelMessage) {
 				defer transferWaitGroup.Done()
 				select {
 				case newQueue <- m:
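Every queue in this commit switches from `chan ChannelMessage` to `chan *ChannelMessage`. Sending pointers avoids copying the message struct on every send and, more importantly, lets `releaseChannelMessage` return the exact enqueued object to the `sync.Pool`s initialized in `Init`. A self-contained sketch of that pointer-through-channel pattern (the `message` type is a stand-in, not Bifrost's real struct):

```go
package main

import (
	"fmt"
	"sync"
)

// message stands in for Bifrost's ChannelMessage; the real struct carries
// the request, its context, and response/error channels.
type message struct{ id int }

var pool = sync.Pool{New: func() interface{} { return &message{} }}

func main() {
	queue := make(chan *message, 4) // pointer elements: no per-send copy

	m := pool.Get().(*message)
	m.id = 42
	queue <- m // sender and worker share the same pooled object

	got := <-queue
	fmt.Println(got.id)
	pool.Put(got) // analogous to releaseChannelMessage
}
```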
@@ -712,7 +718,7 @@ func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, confi
 		return fmt.Errorf("failed to get config for provider: %v", err)
 	}
 
-	queue := make(chan ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize) // Buffered channel per provider
+	queue := make(chan *ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize) // Buffered channel per provider
 
 	bifrost.requestQueues.Store(providerKey, queue)
 
@@ -738,13 +744,13 @@ func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, confi
 // If the queue doesn't exist, it creates one at runtime and initializes the provider,
 // given the provider config is provided in the account interface implementation.
 // This function uses read locks to prevent race conditions during provider updates.
-func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (chan ChannelMessage, error) {
+func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (chan *ChannelMessage, error) {
 	// Use read lock to allow concurrent reads but prevent concurrent updates
 	providerMutex := bifrost.getProviderMutex(providerKey)
 	providerMutex.RLock()
 
 	if queueValue, exists := bifrost.requestQueues.Load(providerKey); exists {
-		queue := queueValue.(chan ChannelMessage)
+		queue := queueValue.(chan *ChannelMessage)
 		providerMutex.RUnlock()
 		return queue, nil
 	}
@@ -757,7 +763,7 @@ func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (cha
 
 	// Double-check after acquiring write lock (another goroutine might have created it)
 	if queueValue, exists := bifrost.requestQueues.Load(providerKey); exists {
-		queue := queueValue.(chan ChannelMessage)
+		queue := queueValue.(chan *ChannelMessage)
 		return queue, nil
 	}
 
@@ -773,7 +779,7 @@ func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (cha
 	}
 
 	queueValue, _ := bifrost.requestQueues.Load(providerKey)
-	queue := queueValue.(chan ChannelMessage)
+	queue := queueValue.(chan *ChannelMessage)
 
 	return queue, nil
 }
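`getProviderQueue` now type-asserts to `chan *ChannelMessage` on all three paths of its locking scheme: an `RLock` fast path, then a write lock with a re-check of `requestQueues` before creating the queue. A stripped-down sketch of that double-checked pattern (names illustrative, not Bifrost's API):

```go
package main

import "sync"

// getOrCreate mirrors getProviderQueue's locking: a read-locked fast path,
// then a write lock with a second Load before creating the value, so two
// goroutines can never both create it.
func getOrCreate(m *sync.Map, mu *sync.RWMutex, key string, create func() interface{}) interface{} {
	mu.RLock()
	if v, ok := m.Load(key); ok {
		mu.RUnlock()
		return v
	}
	mu.RUnlock()

	mu.Lock()
	defer mu.Unlock()
	if v, ok := m.Load(key); ok { // double-check after acquiring write lock
		return v
	}
	v := create()
	m.Store(key, v)
	return v
}
```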
@@ -992,7 +998,7 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont
 	msg.Context = ctx
 
 	select {
-	case queue <- *msg:
+	case queue <- msg:
 		// Message was sent successfully
 	case <-ctx.Done():
 		bifrost.releaseChannelMessage(msg)
@@ -1004,7 +1010,7 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont
 			return nil, newBifrostErrorFromMsg("request dropped: queue is full")
 		}
 		select {
-		case queue <- *msg:
+		case queue <- msg:
 			// Message was sent successfully
 		case <-ctx.Done():
 			bifrost.releaseChannelMessage(msg)
@@ -1016,7 +1022,7 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont
 	var resp *schemas.BifrostResponse
 	select {
 	case result = <-msg.Response:
-		resp, bifrostErr := pipeline.RunPostHooks(&ctx, result, nil, len(bifrost.plugins))
+		resp, bifrostErr := pipeline.RunPostHooks(&msg.Context, result, nil, len(bifrost.plugins))
 		if bifrostErr != nil {
 			bifrost.releaseChannelMessage(msg)
 			return nil, bifrostErr
@@ -1025,7 +1031,7 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont
 		return resp, nil
 	case bifrostErrVal := <-msg.Err:
 		bifrostErrPtr := &bifrostErrVal
-		resp, bifrostErrPtr = pipeline.RunPostHooks(&ctx, nil, bifrostErrPtr, len(bifrost.plugins))
+		resp, bifrostErrPtr = pipeline.RunPostHooks(&msg.Context, nil, bifrostErrPtr, len(bifrost.plugins))
 		bifrost.releaseChannelMessage(msg)
 		if bifrostErrPtr != nil {
 			return nil, bifrostErrPtr
@@ -1110,7 +1116,7 @@ func (bifrost *Bifrost) tryStreamRequest(req *schemas.BifrostRequest, ctx contex
 	msg.Context = ctx
 
 	select {
-	case queue <- *msg:
+	case queue <- msg:
 		// Message was sent successfully
 	case <-ctx.Done():
 		bifrost.releaseChannelMessage(msg)
@@ -1122,7 +1128,7 @@ func (bifrost *Bifrost) tryStreamRequest(req *schemas.BifrostRequest, ctx contex
 			return nil, newBifrostErrorFromMsg("request dropped: queue is full")
 		}
 		select {
-		case queue <- *msg:
+		case queue <- msg:
 			// Message was sent successfully
 		case <-ctx.Done():
 			bifrost.releaseChannelMessage(msg)
@@ -1142,7 +1148,7 @@ func (bifrost *Bifrost) tryStreamRequest(req *schemas.BifrostRequest, ctx contex
 
 // requestWorker handles incoming requests from the queue for a specific provider.
 // It manages retries, error handling, and response processing.
-func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas.ProviderConfig, queue chan ChannelMessage) {
+func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas.ProviderConfig, queue chan *ChannelMessage) {
	defer func() {
 		if waitGroupValue, ok := bifrost.waitGroups.Load(provider.GetProviderKey()); ok {
 			waitGroup := waitGroupValue.(*sync.WaitGroup)
@@ -1177,6 +1183,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
 				}
 				continue
 			}
+			req.Context = context.WithValue(req.Context, schemas.BifrostContextKeySelectedKey, key.ID)
 		}
 
 		// Track attempts
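With the key decision recorded on the request context, code further down the pipeline (a plugin PostHook, for instance) can presumably recover which key served the request. A hedged sketch, assuming `key.ID` is a string and `logger` is any `schemas.Logger`:

```go
// Sketch: reading the selected key ID back out of a request context.
if keyID, ok := ctx.Value(schemas.BifrostContextKeySelectedKey).(string); ok {
	logger.Debug("request served with key %s", keyID)
}
```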
@@ -1212,12 +1219,12 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
 
 		// Attempt the request
 		if IsStreamRequestType(req.Type) {
-			stream, bifrostError = handleProviderStreamRequest(provider, &req, key, postHookRunner, req.Type)
+			stream, bifrostError = handleProviderStreamRequest(provider, req, key, postHookRunner, req.Type)
 			if bifrostError != nil && !bifrostError.IsBifrostError {
 				break // Don't retry client errors
 			}
 		} else {
-			result, bifrostError = handleProviderRequest(provider, &req, key, req.Type)
+			result, bifrostError = handleProviderRequest(provider, req, key, req.Type)
 			if bifrostError != nil {
 				break // Don't retry client errors
 			}
@@ -1518,9 +1525,19 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, prov
 		return supportedKeys[0], nil
 	}
 
+	selectedKey, err := bifrost.keySelector(ctx, supportedKeys, providerKey, model)
+	if err != nil {
+		return schemas.Key{}, err
+	}
+
+	return selectedKey, nil
+
+}
+
+func WeightedRandomKeySelector(ctx *context.Context, keys []schemas.Key, providerKey schemas.ModelProvider, model string) (schemas.Key, error) {
 	// Use a weighted random selection based on key weights
 	totalWeight := 0
-	for _, key := range supportedKeys {
+	for _, key := range keys {
 		totalWeight += int(key.Weight * 100) // Convert float to int for better performance
 	}
 
@@ -1530,15 +1547,15 @@
 
 	// Select key based on weight
 	currentWeight := 0
-	for _, key := range supportedKeys {
+	for _, key := range keys {
 		currentWeight += int(key.Weight * 100)
 		if randomValue < currentWeight {
 			return key, nil
 		}
 	}
 
 	// Fallback to first key if something goes wrong
-	return supportedKeys[0], nil
+	return keys[0], nil
 }
 
 // Shutdown gracefully stops all workers when triggered.
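The selection logic itself is unchanged by the extraction: weights are scaled by 100 into integers, `randomValue` is drawn in `[0, totalWeight)`, and the first key whose cumulative weight exceeds it is returned. A quick worked example (the provider constant and model name are illustrative):

```go
// Weights 0.7 and 0.3 scale to 70 and 30, so totalWeight = 100.
// randomValue in [0,69] selects "primary" (cumulative weight 70);
// [70,99] selects "secondary" (cumulative weight 100), i.e. a 70/30 split.
keys := []schemas.Key{
	{ID: "primary", Weight: 0.7},
	{ID: "secondary", Weight: 0.3},
}
ctx := context.Background()
key, _ := bifrost.WeightedRandomKeySelector(&ctx, keys, schemas.OpenAI, "gpt-4o") // error is always nil on this path
```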
@@ -1548,7 +1565,7 @@ func (bifrost *Bifrost) Shutdown() {
 
 	// Close all provider queues to signal workers to stop
 	bifrost.requestQueues.Range(func(key, value interface{}) bool {
-		close(value.(chan ChannelMessage))
+		close(value.(chan *ChannelMessage))
 		return true
 	})
 