Skip to content

Commit 6a929e7

Browse files
authored
feat: list models with all configured keys for provider (#700)
## Summary

Enhance the ListModels functionality to aggregate results from multiple API keys, improving model discovery and availability.

## Changes

- Modified `ListModels` interface to accept multiple keys instead of a single key
- Implemented key aggregation logic that combines model lists from multiple API keys
- Added deduplication of models based on model ID when aggregating results
- Created a best-effort approach that continues processing remaining keys even if some fail
- Added a new helper function `getAllSupportedKeys` to retrieve all valid keys for a provider
- Updated all provider implementations to support the multi-key approach

## Type of change

- [ ] Bug fix
- [x] Feature
- [ ] Refactor
- [ ] Documentation
- [ ] Chore/CI

## Affected areas

- [x] Core (Go)
- [ ] Transports (HTTP)
- [x] Providers/Integrations
- [ ] Plugins
- [ ] UI (Next.js)
- [ ] Docs

## How to test

Test the ListModels endpoint with multiple API keys configured for the same provider:

```sh
# Core/Transports
go version
go test ./...

# Test with multiple keys for a provider (e.g., OpenAI)
curl -X GET "http://localhost:8000/v1/models?provider=openai" -H "Authorization: Bearer your-bifrost-token"
```

## Breaking changes

- [ ] Yes
- [x] No

## Related issues

Improves model discovery by aggregating results from multiple API keys.

## Security considerations

The implementation maintains the same security model as before, with proper error handling to avoid leaking sensitive information.

## Checklist

- [x] I added/updated tests where appropriate
- [x] I verified builds succeed (Go and UI)
- [x] I verified the CI pipeline passes locally if applicable
1 parent e4c45fa commit 6a929e7

File tree

24 files changed

+675
-433
lines changed

24 files changed

+675
-433
lines changed

core/bifrost.go

Lines changed: 119 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,9 @@ func (bifrost *Bifrost) ListModelsRequest(ctx context.Context, req *schemas.Bifr
227227
},
228228
}
229229
}
230+
if ctx == nil {
231+
ctx = bifrost.ctx
232+
}
230233

231234
request := &schemas.BifrostListModelsRequest{
232235
Provider: req.Provider,
@@ -258,23 +261,13 @@ func (bifrost *Bifrost) ListModelsRequest(ctx context.Context, req *schemas.Bifr
258261
baseProvider = config.CustomProviderConfig.BaseProviderType
259262
}
260263

261-
// Get API key for the provider if required
262-
key := schemas.Key{}
263-
if providerRequiresKey(baseProvider) {
264-
key, err = bifrost.selectKeyFromProviderForModel(&ctx, schemas.ListModelsRequest, req.Provider, "", baseProvider)
265-
if err != nil {
266-
return nil, &schemas.BifrostError{
267-
IsBifrostError: false,
268-
Error: &schemas.ErrorField{
269-
Message: err.Error(),
270-
Error: err,
271-
},
272-
}
273-
}
264+
keys, err := bifrost.getAllSupportedKeys(&ctx, req.Provider, baseProvider)
265+
if err != nil {
266+
return nil, newBifrostError(err)
274267
}
275268

276269
response, bifrostErr := executeRequestWithRetries(config, func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
277-
return provider.ListModels(ctx, key, request)
270+
return provider.ListModels(ctx, keys, request)
278271
}, schemas.ListModelsRequest, req.Provider, "")
279272
if bifrostErr != nil {
280273
return nil, bifrostErr
@@ -285,8 +278,6 @@ func (bifrost *Bifrost) ListModelsRequest(ctx context.Context, req *schemas.Bifr
285278
// ListAllModels lists all models from all configured providers.
286279
// It accumulates responses from all providers with a limit of 1000 per provider to get all results.
287280
func (bifrost *Bifrost) ListAllModels(ctx context.Context, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
288-
startTime := time.Now()
289-
290281
if request == nil {
291282
request = &schemas.BifrostListModelsRequest{}
292283
}
@@ -296,64 +287,102 @@ func (bifrost *Bifrost) ListAllModels(ctx context.Context, request *schemas.Bifr
296287
return nil, &schemas.BifrostError{
297288
IsBifrostError: false,
298289
Error: &schemas.ErrorField{
299-
Message: "failed to get configured providers",
290+
Message: err.Error(),
300291
Error: err,
301292
},
302293
}
303294
}
304295

305-
// Accumulate all models from all providers
306-
allModels := make([]schemas.Model, 0)
307-
var firstError *schemas.BifrostError
296+
startTime := time.Now()
297+
298+
// Result structure for collecting provider responses
299+
type providerResult struct {
300+
models []schemas.Model
301+
err *schemas.BifrostError
302+
}
303+
304+
results := make(chan providerResult, len(providerKeys))
305+
var wg sync.WaitGroup
308306

307+
// Launch concurrent requests for all providers
309308
for _, providerKey := range providerKeys {
310309
if strings.TrimSpace(string(providerKey)) == "" {
311310
continue
312311
}
313312

314-
// Create request for this provider with limit of 1000
315-
providerRequest := &schemas.BifrostListModelsRequest{
316-
Provider: providerKey,
317-
PageSize: schemas.DefaultPageSize,
318-
}
313+
wg.Add(1)
314+
go func(providerKey schemas.ModelProvider) {
315+
defer wg.Done()
319316

320-
iterations := 0
321-
for {
322-
iterations++
323-
if iterations > schemas.MaxPaginationRequests {
324-
bifrost.logger.Warn(fmt.Sprintf("reached maximum pagination requests (%d) for provider %s", schemas.MaxPaginationRequests, providerKey))
325-
break
317+
providerModels := make([]schemas.Model, 0)
318+
var providerErr *schemas.BifrostError
319+
320+
// Create request for this provider using the default page size
321+
providerRequest := &schemas.BifrostListModelsRequest{
322+
Provider: providerKey,
323+
PageSize: schemas.DefaultPageSize,
326324
}
327325

328-
response, bifrostErr := bifrost.ListModelsRequest(ctx, providerRequest)
329-
if bifrostErr != nil {
330-
// Log the error but continue with other providers
331-
// Skip logging "no keys found" and "not supported" errors as they are expected when a provider is not configured
332-
if !strings.Contains(bifrostErr.Error.Message, "no keys found") &&
333-
!strings.Contains(bifrostErr.Error.Message, "not supported") {
334-
bifrost.logger.Warn(fmt.Sprintf("failed to list models for provider %s: %v", providerKey, bifrostErr.Error.Message))
326+
iterations := 0
327+
for {
328+
// check for context cancellation
329+
select {
330+
case <-ctx.Done():
331+
bifrost.logger.Warn(fmt.Sprintf("context cancelled for provider %s", providerKey))
332+
return
333+
default:
335334
}
336-
if firstError == nil {
337-
firstError = bifrostErr
335+
336+
iterations++
337+
if iterations > schemas.MaxPaginationRequests {
338+
bifrost.logger.Warn(fmt.Sprintf("reached maximum pagination requests (%d) for provider %s, please increase the page size", schemas.MaxPaginationRequests, providerKey))
339+
break
338340
}
339-
break
340-
}
341341

342-
if response == nil {
343-
break
344-
}
342+
response, bifrostErr := bifrost.ListModelsRequest(ctx, providerRequest)
343+
if bifrostErr != nil {
344+
// Skip logging "no keys found" and "not supported" errors as they are expected when a provider is not configured
345+
if !strings.Contains(bifrostErr.Error.Message, "no keys found") &&
346+
!strings.Contains(bifrostErr.Error.Message, "not supported") {
347+
providerErr = bifrostErr
348+
bifrost.logger.Warn(fmt.Sprintf("failed to list models for provider %s: %v", providerKey, bifrostErr.Error.Message))
349+
}
350+
break
351+
}
345352

346-
if len(response.Data) > 0 {
347-
allModels = append(allModels, response.Data...)
348-
}
353+
if response == nil || len(response.Data) == 0 {
354+
break
355+
}
349356

350-
// Check if there are more pages
351-
if response.NextPageToken == "" {
352-
break
357+
providerModels = append(providerModels, response.Data...)
358+
359+
// Check if there are more pages
360+
if response.NextPageToken == "" {
361+
break
362+
}
363+
364+
// Set the page token for the next request
365+
providerRequest.PageToken = response.NextPageToken
353366
}
354367

355-
// Set the page token for the next request
356-
providerRequest.PageToken = response.NextPageToken
368+
results <- providerResult{models: providerModels, err: providerErr}
369+
}(providerKey)
370+
}
371+
372+
// Wait for all goroutines to complete
373+
wg.Wait()
374+
close(results)
375+
376+
// Accumulate all models from all providers
377+
allModels := make([]schemas.Model, 0)
378+
var firstError *schemas.BifrostError
379+
380+
for result := range results {
381+
if len(result.models) > 0 {
382+
allModels = append(allModels, result.models...)
383+
}
384+
if result.err != nil && firstError == nil {
385+
firstError = result.err
357386
}
358387
}
359388

@@ -367,15 +396,12 @@ func (bifrost *Bifrost) ListAllModels(ctx context.Context, request *schemas.Bifr
367396
return allModels[i].ID < allModels[j].ID
368397
})
369398

370-
// Calculate total elapsed time
371-
elapsedTime := time.Since(startTime).Milliseconds()
372-
373399
// Return aggregated response with the total wall-clock latency of the whole operation
374400
response := &schemas.BifrostListModelsResponse{
375401
Data: allModels,
376402
ExtraFields: schemas.BifrostResponseExtraFields{
377403
RequestType: schemas.ListModelsRequest,
378-
Latency: elapsedTime,
404+
Latency: time.Since(startTime).Milliseconds(),
379405
},
380406
}
381407

@@ -2324,6 +2350,42 @@ func (bifrost *Bifrost) releaseBifrostRequest(req *schemas.BifrostRequest) {
23242350
bifrost.bifrostRequestPool.Put(req)
23252351
}
23262352

2353+
// getAllSupportedKeys retrieves all valid keys for a ListModels request,
2354+
// allowing the provider to aggregate results from multiple keys.
2355+
func (bifrost *Bifrost) getAllSupportedKeys(ctx *context.Context, providerKey schemas.ModelProvider, baseProviderType schemas.ModelProvider) ([]schemas.Key, error) {
2356+
// Check if key has been set in the context explicitly
2357+
if ctx != nil {
2358+
key, ok := (*ctx).Value(schemas.BifrostContextKeyDirectKey).(schemas.Key)
2359+
if ok {
2360+
// If a direct key is specified, return it as a single-element slice
2361+
return []schemas.Key{key}, nil
2362+
}
2363+
}
2364+
2365+
keys, err := bifrost.account.GetKeysForProvider(ctx, providerKey)
2366+
if err != nil {
2367+
return nil, err
2368+
}
2369+
2370+
if len(keys) == 0 {
2371+
return nil, fmt.Errorf("no keys found for provider: %v", providerKey)
2372+
}
2373+
2374+
// Filter keys for ListModels - keep keys that have a non-blank value, or every key when the base provider allows empty key values
2375+
var supportedKeys []schemas.Key
2376+
for _, k := range keys {
2377+
if strings.TrimSpace(k.Value) != "" || canProviderKeyValueBeEmpty(baseProviderType) {
2378+
supportedKeys = append(supportedKeys, k)
2379+
}
2380+
}
2381+
2382+
if len(supportedKeys) == 0 {
2383+
return nil, fmt.Errorf("no valid keys found for provider: %v", providerKey)
2384+
}
2385+
2386+
return supportedKeys, nil
2387+
}
2388+
23272389
// selectKeyFromProviderForModel selects an appropriate API key for a given provider and model.
23282390
// It uses weighted random selection if multiple keys are available.
23292391
func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, requestType schemas.RequestType, providerKey schemas.ModelProvider, model string, baseProviderType schemas.ModelProvider) (schemas.Key, error) {

core/providers/anthropic.go

Lines changed: 36 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -121,20 +121,6 @@ func (provider *AnthropicProvider) GetProviderKey() schemas.ModelProvider {
121121
return getProviderName(schemas.Anthropic, provider.customProviderConfig)
122122
}
123123

124-
// parseStreamAnthropicError parses Anthropic streaming error responses.
125-
func parseStreamAnthropicError(resp *http.Response, providerType schemas.ModelProvider) *schemas.BifrostError {
126-
statusCode := resp.StatusCode
127-
body, _ := io.ReadAll(resp.Body)
128-
resp.Body.Close()
129-
130-
var errorResp anthropic.AnthropicError
131-
if err := sonic.Unmarshal(body, &errorResp); err != nil {
132-
return newBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerType)
133-
}
134-
135-
return newProviderAPIError(errorResp.Error.Message, nil, statusCode, providerType, &errorResp.Error.Type, nil)
136-
}
137-
138124
// completeRequest sends a request to Anthropic's API and handles the response.
139125
// It constructs the API URL, sets up authentication, and processes the response.
140126
// Returns the response body or an error if the request fails.
@@ -188,14 +174,9 @@ func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestB
188174
return bodyCopy, latency, nil
189175
}
190176

191-
// ListModels performs a list models request to Anthropic's API.
192-
func (provider *AnthropicProvider) ListModels(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
193-
if err := checkOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ListModelsRequest); err != nil {
194-
return nil, err
195-
}
196-
197-
providerName := provider.GetProviderKey()
198-
177+
// listModelsByKey performs a list models request for a single key.
178+
// Returns the response (with the request latency recorded in its ExtraFields), or an error if the request fails.
179+
func (provider *AnthropicProvider) listModelsByKey(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
199180
// Create request
200181
req := fasthttp.AcquireRequest()
201182
resp := fasthttp.AcquireResponse()
@@ -206,8 +187,7 @@ func (provider *AnthropicProvider) ListModels(ctx context.Context, key schemas.K
206187
setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil)
207188

208189
// Build URL using centralized URL construction
209-
requestURL := anthropic.ToAnthropicListModelsURL(request, provider.networkConfig.BaseURL+"/v1/models")
210-
req.SetRequestURI(requestURL)
190+
req.SetRequestURI(fmt.Sprintf("%s/v1/models?limit=%d", provider.networkConfig.BaseURL, schemas.DefaultPageSize))
211191
req.Header.SetMethod(http.MethodGet)
212192
req.Header.SetContentType("application/json")
213193
req.Header.Set("x-api-key", key.Value)
@@ -221,14 +201,10 @@ func (provider *AnthropicProvider) ListModels(ctx context.Context, key schemas.K
221201

222202
// Handle error response
223203
if resp.StatusCode() != fasthttp.StatusOK {
224-
provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", provider.GetProviderKey(), string(resp.Body())))
225-
226204
var errorResp anthropic.AnthropicError
227-
228205
bifrostErr := handleProviderAPIError(resp, &errorResp)
229206
bifrostErr.Error.Type = &errorResp.Error.Type
230207
bifrostErr.Error.Message = errorResp.Error.Message
231-
232208
return nil, bifrostErr
233209
}
234210

@@ -240,11 +216,7 @@ func (provider *AnthropicProvider) ListModels(ctx context.Context, key schemas.K
240216
}
241217

242218
// Create final response
243-
response := anthropicResponse.ToBifrostListModelsResponse(providerName)
244-
245-
// Set ExtraFields
246-
response.ExtraFields.Provider = providerName
247-
response.ExtraFields.RequestType = schemas.ListModelsRequest
219+
response := anthropicResponse.ToBifrostListModelsResponse(provider.GetProviderKey())
248220
response.ExtraFields.Latency = latency.Milliseconds()
249221

250222
// Set raw response if enabled
@@ -255,6 +227,23 @@ func (provider *AnthropicProvider) ListModels(ctx context.Context, key schemas.K
255227
return response, nil
256228
}
257229

230+
// ListModels performs a list models request to Anthropic's API.
231+
// It fetches models using all provided keys and aggregates the results.
232+
// Uses a best-effort approach: continues with remaining keys even if some fail.
233+
// Requests are made concurrently for improved performance.
234+
func (provider *AnthropicProvider) ListModels(ctx context.Context, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
235+
if err := checkOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ListModelsRequest); err != nil {
236+
return nil, err
237+
}
238+
return handleMultipleListModelsRequests(
239+
ctx,
240+
keys,
241+
request,
242+
provider.listModelsByKey,
243+
provider.logger,
244+
)
245+
}
246+
258247
// TextCompletion performs a text completion request to Anthropic's API.
259248
// It formats the request, sends it to Anthropic, and processes the response.
260249
// Returns a BifrostResponse containing the completion results or an error if the request fails.
@@ -852,3 +841,17 @@ func (provider *AnthropicProvider) Transcription(ctx context.Context, key schema
852841
func (provider *AnthropicProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
853842
return nil, newUnsupportedOperationError("transcription stream", "anthropic")
854843
}
844+
845+
// parseStreamAnthropicError parses Anthropic streaming error responses.
846+
func parseStreamAnthropicError(resp *http.Response, providerType schemas.ModelProvider) *schemas.BifrostError {
847+
statusCode := resp.StatusCode
848+
body, _ := io.ReadAll(resp.Body)
849+
resp.Body.Close()
850+
851+
var errorResp anthropic.AnthropicError
852+
if err := sonic.Unmarshal(body, &errorResp); err != nil {
853+
return newBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerType)
854+
}
855+
856+
return newProviderAPIError(errorResp.Error.Message, nil, statusCode, providerType, &errorResp.Error.Type, nil)
857+
}

0 commit comments

Comments
 (0)