@@ -46,6 +46,7 @@ type Bifrost struct {
 	logger             schemas.Logger // logger instance, default logger is used if not provided
 	mcpManager         *MCPManager    // MCP integration manager (nil if MCP not configured)
 	dropExcessRequests atomic.Bool    // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead.
+	keySelector        schemas.KeySelector // Custom key selector function
 }
 
 // PluginPipeline encapsulates the execution of plugin PreHooks and PostHooks, tracks how many plugins ran, and manages short-circuiting and error aggregation.
@@ -86,10 +87,15 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) {
 		plugins:       atomic.Pointer[[]schemas.Plugin]{},
 		requestQueues: sync.Map{},
 		waitGroups:    sync.Map{},
+		keySelector:   config.KeySelector,
 	}
 	bifrost.plugins.Store(&config.Plugins)
 	bifrost.dropExcessRequests.Store(config.DropExcessRequests)
 
+	if bifrost.keySelector == nil {
+		bifrost.keySelector = WeightedRandomKeySelector
+	}
+
 	// Initialize object pools
 	bifrost.channelMessagePool = sync.Pool{
 		New: func() interface{} {
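
With this hunk, callers can plug in their own key-selection strategy via config.KeySelector; when it is left nil, Init falls back to WeightedRandomKeySelector. A minimal sketch of wiring a custom selector, assuming the schemas.KeySelector signature this diff introduces (the account value and selector body are illustrative, not part of the change):

// Sketch only: always pick the first supported key. myAccount is a
// hypothetical schemas.Account implementation; error handling elided.
cfg := schemas.BifrostConfig{
	Account: myAccount,
	KeySelector: func(ctx *context.Context, keys []schemas.Key, providerKey schemas.ModelProvider, model string) (schemas.Key, error) {
		if len(keys) == 0 {
			return schemas.Key{}, errors.New("no keys configured")
		}
		return keys[0], nil
	},
}
client, err := Init(context.Background(), cfg) // nil KeySelector => WeightedRandomKeySelector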
@@ -632,12 +638,12 @@ func (bifrost *Bifrost) UpdateProviderConcurrency(providerKey schemas.ModelProvi
 		return bifrost.prepareProvider(providerKey, providerConfig)
 	}
 
-	oldQueue := oldQueueValue.(chan ChannelMessage)
+	oldQueue := oldQueueValue.(chan *ChannelMessage)
 
 	bifrost.logger.Debug("gracefully stopping existing workers for provider %s", providerKey)
 
 	// Step 1: Create new queue with updated buffer size
-	newQueue := make(chan ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize)
+	newQueue := make(chan *ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize)
 
 	// Step 2: Transfer any buffered requests from old queue to new queue
 	// This prevents request loss during the transition
@@ -653,7 +659,7 @@ func (bifrost *Bifrost) UpdateProviderConcurrency(providerKey schemas.ModelProvi
 			// New queue is full, handle this request in a goroutine
 			// This is unlikely with proper buffer sizing but provides safety
 			transferWaitGroup.Add(1)
-			go func(m ChannelMessage) {
+			go func(m *ChannelMessage) {
 				defer transferWaitGroup.Done()
 				select {
 				case newQueue <- m:
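
For context, the transfer step above is a non-blocking drain: pull whatever is buffered in the old queue, forward it to the new queue, and push overflow through goroutines so the drain loop itself never blocks. A simplified sketch of the pattern (names mirror the diff; the enclosing function and wait-group are assumed):

// Drain buffered messages from oldQueue into newQueue without blocking.
drain:
for {
	select {
	case m := <-oldQueue:
		select {
		case newQueue <- m:
			// forwarded directly
		default:
			// newQueue full: hand off to a goroutine, as in the hunk above
			transferWaitGroup.Add(1)
			go func(m *ChannelMessage) {
				defer transferWaitGroup.Done()
				newQueue <- m
			}(m)
		}
	default:
		break drain // oldQueue is empty
	}
}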
@@ -1017,7 +1023,7 @@ func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, confi
 		return fmt.Errorf("failed to get config for provider: %v", err)
 	}
 
-	queue := make(chan ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize) // Buffered channel per provider
+	queue := make(chan *ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize) // Buffered channel per provider
 
 	bifrost.requestQueues.Store(providerKey, queue)
 
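
The queue's element type changes from chan ChannelMessage to chan *ChannelMessage throughout this diff: sending a pointer avoids copying the message struct at every channel hop, and the worker and the caller operate on the same pooled instance, which is what keeps releaseChannelMessage safe. A simplified illustration of the pattern (not Bifrost's actual code paths):

// Simplified: pooled messages flowing through a pointer channel.
var pool = sync.Pool{New: func() interface{} { return new(ChannelMessage) }}
queue := make(chan *ChannelMessage, 64)

msg := pool.Get().(*ChannelMessage)
queue <- msg // no struct copy; the receiver sees this exact instance

go func() {
	m := <-queue
	// ... process m, write to m.Response or m.Err ...
	pool.Put(m) // return the single shared instance to the pool
}()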
@@ -1044,13 +1050,13 @@ func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, confi
 // If the queue doesn't exist, it creates one at runtime and initializes the provider,
 // given the provider config is provided in the account interface implementation.
 // This function uses read locks to prevent race conditions during provider updates.
-func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (chan ChannelMessage, error) {
+func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (chan *ChannelMessage, error) {
 	// Use read lock to allow concurrent reads but prevent concurrent updates
 	providerMutex := bifrost.getProviderMutex(providerKey)
 	providerMutex.RLock()
 
 	if queueValue, exists := bifrost.requestQueues.Load(providerKey); exists {
-		queue := queueValue.(chan ChannelMessage)
+		queue := queueValue.(chan *ChannelMessage)
 		providerMutex.RUnlock()
 		return queue, nil
 	}
@@ -1063,7 +1069,7 @@ func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (cha
 
 	// Double-check after acquiring write lock (another goroutine might have created it)
 	if queueValue, exists := bifrost.requestQueues.Load(providerKey); exists {
-		queue := queueValue.(chan ChannelMessage)
+		queue := queueValue.(chan *ChannelMessage)
 		return queue, nil
 	}
 
@@ -1079,7 +1085,7 @@ func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (cha
 	}
 
 	queueValue, _ := bifrost.requestQueues.Load(providerKey)
-	queue := queueValue.(chan ChannelMessage)
+	queue := queueValue.(chan *ChannelMessage)
 
 	return queue, nil
 }
@@ -1341,9 +1347,8 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont
 
 	msg := bifrost.getChannelMessage(*preReq)
 	msg.Context = ctx
-	startTime := time.Now()
 	select {
-	case queue <- *msg:
+	case queue <- msg:
 		// Message was sent successfully
 	case <-ctx.Done():
 		bifrost.releaseChannelMessage(msg)
@@ -1355,7 +1360,7 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont
 		return nil, newBifrostErrorFromMsg("request dropped: queue is full")
 	}
 	select {
-	case queue <- *msg:
+	case queue <- msg:
 		// Message was sent successfully
 	case <-ctx.Done():
 		bifrost.releaseChannelMessage(msg)
@@ -1368,11 +1373,7 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont
 	pluginCount := len(*bifrost.plugins.Load())
 	select {
 	case result = <-msg.Response:
-		latency := time.Since(startTime).Milliseconds()
-		if result.ExtraFields.Latency == nil {
-			result.ExtraFields.Latency = Ptr(float64(latency))
-		}
-		resp, bifrostErr := pipeline.RunPostHooks(&ctx, result, nil, pluginCount)
+		resp, bifrostErr := pipeline.RunPostHooks(&msg.Context, result, nil, pluginCount)
 		if bifrostErr != nil {
 			bifrost.releaseChannelMessage(msg)
 			return nil, bifrostErr
@@ -1381,7 +1382,7 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont
 		return resp, nil
 	case bifrostErrVal := <-msg.Err:
 		bifrostErrPtr := &bifrostErrVal
-		resp, bifrostErrPtr = pipeline.RunPostHooks(&ctx, nil, bifrostErrPtr, pluginCount)
+		resp, bifrostErrPtr = pipeline.RunPostHooks(&msg.Context, nil, bifrostErrPtr, pluginCount)
 		bifrost.releaseChannelMessage(msg)
 		if bifrostErrPtr != nil {
 			return nil, bifrostErrPtr
@@ -1463,7 +1464,7 @@ func (bifrost *Bifrost) tryStreamRequest(req *schemas.BifrostRequest, ctx contex
 	msg.Context = ctx
 
 	select {
-	case queue <- *msg:
+	case queue <- msg:
 		// Message was sent successfully
 	case <-ctx.Done():
 		bifrost.releaseChannelMessage(msg)
@@ -1475,7 +1476,7 @@ func (bifrost *Bifrost) tryStreamRequest(req *schemas.BifrostRequest, ctx contex
 		return nil, newBifrostErrorFromMsg("request dropped: queue is full")
 	}
 	select {
-	case queue <- *msg:
+	case queue <- msg:
 		// Message was sent successfully
 	case <-ctx.Done():
 		bifrost.releaseChannelMessage(msg)
@@ -1506,7 +1507,7 @@ func (bifrost *Bifrost) tryStreamRequest(req *schemas.BifrostRequest, ctx contex
 
 // requestWorker handles incoming requests from the queue for a specific provider.
 // It manages retries, error handling, and response processing.
-func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas.ProviderConfig, queue chan ChannelMessage) {
+func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas.ProviderConfig, queue chan *ChannelMessage) {
 	defer func() {
 		if waitGroupValue, ok := bifrost.waitGroups.Load(provider.GetProviderKey()); ok {
 			waitGroup := waitGroupValue.(*sync.WaitGroup)
@@ -1541,6 +1542,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
 				}
 				continue
 			}
+			req.Context = context.WithValue(req.Context, schemas.BifrostContextKeySelectedKey, key.ID)
 		}
 
 		// Track attempts
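
The added line stamps the chosen key's ID into the request context before the provider call; combined with tryRequest now passing &msg.Context (instead of the caller's ctx) to RunPostHooks, post-hooks can observe which key actually served the request. A sketch of reading it back, assuming a hook that receives the same context pointer (the helper name and the string type of the key ID are illustrative):

// Illustrative post-hook helper: recover the key ID selected by the worker.
func logSelectedKey(ctx *context.Context) {
	if ctx == nil {
		return
	}
	if keyID, ok := (*ctx).Value(schemas.BifrostContextKeySelectedKey).(string); ok {
		log.Printf("request served with key %s", keyID) // assumes key.ID is a string
	}
}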
@@ -1576,12 +1578,12 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
 
 		// Attempt the request
 		if IsStreamRequestType(req.RequestType) {
-			stream, bifrostError = handleProviderStreamRequest(provider, &req, key, postHookRunner)
+			stream, bifrostError = handleProviderStreamRequest(provider, req, key, postHookRunner)
 			if bifrostError != nil && !bifrostError.IsBifrostError {
 				break // Don't retry client errors
 			}
 		} else {
-			result, bifrostError = handleProviderRequest(provider, &req, key)
+			result, bifrostError = handleProviderRequest(provider, req, key)
 			if bifrostError != nil {
 				break // Don't retry client errors
 			}
@@ -1930,9 +1932,19 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, prov
 		return supportedKeys[0], nil
 	}
 
+	selectedKey, err := bifrost.keySelector(ctx, supportedKeys, providerKey, model)
+	if err != nil {
+		return schemas.Key{}, err
+	}
+
+	return selectedKey, nil
+
+}
+
+func WeightedRandomKeySelector(ctx *context.Context, keys []schemas.Key, providerKey schemas.ModelProvider, model string) (schemas.Key, error) {
 	// Use a weighted random selection based on key weights
 	totalWeight := 0
-	for _, key := range supportedKeys {
+	for _, key := range keys {
 		totalWeight += int(key.Weight * 100) // Convert float to int for better performance
 	}
@@ -1942,15 +1954,15 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, prov
 
 	// Select key based on weight
 	currentWeight := 0
-	for _, key := range supportedKeys {
+	for _, key := range keys {
 		currentWeight += int(key.Weight * 100)
 		if randomValue < currentWeight {
 			return key, nil
 		}
 	}
 
 	// Fallback to first key if something goes wrong
-	return supportedKeys[0], nil
+	return keys[0], nil
 }
 
 // Shutdown gracefully stops all workers when triggered.
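
Worked example of the selector's math: weights are scaled by 100 and walked cumulatively, so keys weighted 0.7 and 0.3 give totalWeight 100, with random values 0-69 landing on the first key and 70-99 on the second. A self-contained sketch (simplified key type, illustrative weights):

// Standalone demonstration of the cumulative-weight walk in WeightedRandomKeySelector.
package main

import (
	"fmt"
	"math/rand"
)

func main() {
	keys := []struct {
		Name   string
		Weight float64
	}{{"primary", 0.7}, {"secondary", 0.3}}

	totalWeight := 0
	for _, k := range keys {
		totalWeight += int(k.Weight * 100) // 70 + 30 = 100
	}

	randomValue := rand.Intn(totalWeight) // uniform in [0, 100)

	currentWeight := 0
	for _, k := range keys {
		currentWeight += int(k.Weight * 100)
		if randomValue < currentWeight { // 0-69 -> primary, 70-99 -> secondary
			fmt.Println("selected:", k.Name)
			return
		}
	}
}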
@@ -1960,7 +1972,7 @@ func (bifrost *Bifrost) Shutdown() {
 
 	// Close all provider queues to signal workers to stop
 	bifrost.requestQueues.Range(func(key, value interface{}) bool {
-		close(value.(chan ChannelMessage))
+		close(value.(chan *ChannelMessage))
 		return true
 	})
 