Skip to content

Commit 5fe52fe

Browse files
TejasGhattePratham-Mishra04
authored andcommitted
feat: list models with all configured keys for provider
1 parent e4c45fa commit 5fe52fe

File tree

24 files changed

+585
-393
lines changed

24 files changed

+585
-393
lines changed

core/bifrost.go

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,9 @@ func (bifrost *Bifrost) ListModelsRequest(ctx context.Context, req *schemas.Bifr
227227
},
228228
}
229229
}
230+
if ctx == nil {
231+
ctx = bifrost.ctx
232+
}
230233

231234
request := &schemas.BifrostListModelsRequest{
232235
Provider: req.Provider,
@@ -258,23 +261,13 @@ func (bifrost *Bifrost) ListModelsRequest(ctx context.Context, req *schemas.Bifr
258261
baseProvider = config.CustomProviderConfig.BaseProviderType
259262
}
260263

261-
// Get API key for the provider if required
262-
key := schemas.Key{}
263-
if providerRequiresKey(baseProvider) {
264-
key, err = bifrost.selectKeyFromProviderForModel(&ctx, schemas.ListModelsRequest, req.Provider, "", baseProvider)
265-
if err != nil {
266-
return nil, &schemas.BifrostError{
267-
IsBifrostError: false,
268-
Error: &schemas.ErrorField{
269-
Message: err.Error(),
270-
Error: err,
271-
},
272-
}
273-
}
264+
keys, err := bifrost.getAllSupportedKeys(&ctx, req.Provider, baseProvider)
265+
if err != nil {
266+
return nil, newBifrostError(err)
274267
}
275268

276269
response, bifrostErr := executeRequestWithRetries(config, func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
277-
return provider.ListModels(ctx, key, request)
270+
return provider.ListModels(ctx, keys, request)
278271
}, schemas.ListModelsRequest, req.Provider, "")
279272
if bifrostErr != nil {
280273
return nil, bifrostErr
@@ -285,8 +278,6 @@ func (bifrost *Bifrost) ListModelsRequest(ctx context.Context, req *schemas.Bifr
285278
// ListAllModels lists all models from all configured providers.
286279
// It accumulates responses from all providers with a limit of 1000 per provider to get all results.
287280
func (bifrost *Bifrost) ListAllModels(ctx context.Context, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
288-
startTime := time.Now()
289-
290281
if request == nil {
291282
request = &schemas.BifrostListModelsRequest{}
292283
}
@@ -296,12 +287,14 @@ func (bifrost *Bifrost) ListAllModels(ctx context.Context, request *schemas.Bifr
296287
return nil, &schemas.BifrostError{
297288
IsBifrostError: false,
298289
Error: &schemas.ErrorField{
299-
Message: "failed to get configured providers",
290+
Message: err.Error(),
300291
Error: err,
301292
},
302293
}
303294
}
304295

296+
startTime := time.Now()
297+
305298
// Accumulate all models from all providers
306299
allModels := make([]schemas.Model, 0)
307300
var firstError *schemas.BifrostError
@@ -321,7 +314,7 @@ func (bifrost *Bifrost) ListAllModels(ctx context.Context, request *schemas.Bifr
321314
for {
322315
iterations++
323316
if iterations > schemas.MaxPaginationRequests {
324-
bifrost.logger.Warn(fmt.Sprintf("reached maximum pagination requests (%d) for provider %s", schemas.MaxPaginationRequests, providerKey))
317+
bifrost.logger.Warn(fmt.Sprintf("reached maximum pagination requests (%d) for provider %s, please increase the page size", schemas.MaxPaginationRequests, providerKey))
325318
break
326319
}
327320

@@ -2324,6 +2317,42 @@ func (bifrost *Bifrost) releaseBifrostRequest(req *schemas.BifrostRequest) {
23242317
bifrost.bifrostRequestPool.Put(req)
23252318
}
23262319

2320+
// getAllSupportedKeys retrieves all valid keys for a ListModels request.
2321+
// allowing the provider to aggregate results from multiple keys.
2322+
func (bifrost *Bifrost) getAllSupportedKeys(ctx *context.Context, providerKey schemas.ModelProvider, baseProviderType schemas.ModelProvider) ([]schemas.Key, error) {
2323+
// Check if key has been set in the context explicitly
2324+
if ctx != nil {
2325+
key, ok := (*ctx).Value(schemas.BifrostContextKeyDirectKey).(schemas.Key)
2326+
if ok {
2327+
// If a direct key is specified, return it as a single-element slice
2328+
return []schemas.Key{key}, nil
2329+
}
2330+
}
2331+
2332+
keys, err := bifrost.account.GetKeysForProvider(ctx, providerKey)
2333+
if err != nil {
2334+
return nil, err
2335+
}
2336+
2337+
if len(keys) == 0 {
2338+
return nil, fmt.Errorf("no keys found for provider: %v", providerKey)
2339+
}
2340+
2341+
// Filter keys for ListModels - only check if key has a value
2342+
var supportedKeys []schemas.Key
2343+
for _, k := range keys {
2344+
if strings.TrimSpace(k.Value) != "" || canProviderKeyValueBeEmpty(baseProviderType) {
2345+
supportedKeys = append(supportedKeys, k)
2346+
}
2347+
}
2348+
2349+
if len(supportedKeys) == 0 {
2350+
return nil, fmt.Errorf("no valid keys found for provider: %v", providerKey)
2351+
}
2352+
2353+
return supportedKeys, nil
2354+
}
2355+
23272356
// selectKeyFromProviderForModel selects an appropriate API key for a given provider and model.
23282357
// It uses weighted random selection if multiple keys are available.
23292358
func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, requestType schemas.RequestType, providerKey schemas.ModelProvider, model string, baseProviderType schemas.ModelProvider) (schemas.Key, error) {

core/providers/anthropic.go

Lines changed: 36 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -121,20 +121,6 @@ func (provider *AnthropicProvider) GetProviderKey() schemas.ModelProvider {
121121
return getProviderName(schemas.Anthropic, provider.customProviderConfig)
122122
}
123123

124-
// parseStreamAnthropicError parses Anthropic streaming error responses.
125-
func parseStreamAnthropicError(resp *http.Response, providerType schemas.ModelProvider) *schemas.BifrostError {
126-
statusCode := resp.StatusCode
127-
body, _ := io.ReadAll(resp.Body)
128-
resp.Body.Close()
129-
130-
var errorResp anthropic.AnthropicError
131-
if err := sonic.Unmarshal(body, &errorResp); err != nil {
132-
return newBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerType)
133-
}
134-
135-
return newProviderAPIError(errorResp.Error.Message, nil, statusCode, providerType, &errorResp.Error.Type, nil)
136-
}
137-
138124
// completeRequest sends a request to Anthropic's API and handles the response.
139125
// It constructs the API URL, sets up authentication, and processes the response.
140126
// Returns the response body or an error if the request fails.
@@ -188,14 +174,9 @@ func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestB
188174
return bodyCopy, latency, nil
189175
}
190176

191-
// ListModels performs a list models request to Anthropic's API.
192-
func (provider *AnthropicProvider) ListModels(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
193-
if err := checkOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ListModelsRequest); err != nil {
194-
return nil, err
195-
}
196-
197-
providerName := provider.GetProviderKey()
198-
177+
// listModelsByKey performs a list models request for a single key.
178+
// Returns the response and latency, or an error if the request fails.
179+
func (provider *AnthropicProvider) listModelsByKey(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
199180
// Create request
200181
req := fasthttp.AcquireRequest()
201182
resp := fasthttp.AcquireResponse()
@@ -206,8 +187,7 @@ func (provider *AnthropicProvider) ListModels(ctx context.Context, key schemas.K
206187
setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil)
207188

208189
// Build URL using centralized URL construction
209-
requestURL := anthropic.ToAnthropicListModelsURL(request, provider.networkConfig.BaseURL+"/v1/models")
210-
req.SetRequestURI(requestURL)
190+
req.SetRequestURI(fmt.Sprintf("%s/v1/models?limit=%d", provider.networkConfig.BaseURL, schemas.DefaultPageSize))
211191
req.Header.SetMethod(http.MethodGet)
212192
req.Header.SetContentType("application/json")
213193
req.Header.Set("x-api-key", key.Value)
@@ -221,14 +201,10 @@ func (provider *AnthropicProvider) ListModels(ctx context.Context, key schemas.K
221201

222202
// Handle error response
223203
if resp.StatusCode() != fasthttp.StatusOK {
224-
provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", provider.GetProviderKey(), string(resp.Body())))
225-
226204
var errorResp anthropic.AnthropicError
227-
228205
bifrostErr := handleProviderAPIError(resp, &errorResp)
229206
bifrostErr.Error.Type = &errorResp.Error.Type
230207
bifrostErr.Error.Message = errorResp.Error.Message
231-
232208
return nil, bifrostErr
233209
}
234210

@@ -240,11 +216,7 @@ func (provider *AnthropicProvider) ListModels(ctx context.Context, key schemas.K
240216
}
241217

242218
// Create final response
243-
response := anthropicResponse.ToBifrostListModelsResponse(providerName)
244-
245-
// Set ExtraFields
246-
response.ExtraFields.Provider = providerName
247-
response.ExtraFields.RequestType = schemas.ListModelsRequest
219+
response := anthropicResponse.ToBifrostListModelsResponse(provider.GetProviderKey())
248220
response.ExtraFields.Latency = latency.Milliseconds()
249221

250222
// Set raw response if enabled
@@ -255,6 +227,23 @@ func (provider *AnthropicProvider) ListModels(ctx context.Context, key schemas.K
255227
return response, nil
256228
}
257229

230+
// ListModels performs a list models request to Anthropic's API.
231+
// It fetches models using all provided keys and aggregates the results.
232+
// Uses a best-effort approach: continues with remaining keys even if some fail.
233+
// Requests are made concurrently for improved performance.
234+
func (provider *AnthropicProvider) ListModels(ctx context.Context, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
235+
if err := checkOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ListModelsRequest); err != nil {
236+
return nil, err
237+
}
238+
return handleMultipleListModelsRequests(
239+
ctx,
240+
keys,
241+
request,
242+
provider.listModelsByKey,
243+
provider.logger,
244+
)
245+
}
246+
258247
// TextCompletion performs a text completion request to Anthropic's API.
259248
// It formats the request, sends it to Anthropic, and processes the response.
260249
// Returns a BifrostResponse containing the completion results or an error if the request fails.
@@ -852,3 +841,17 @@ func (provider *AnthropicProvider) Transcription(ctx context.Context, key schema
852841
func (provider *AnthropicProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
853842
return nil, newUnsupportedOperationError("transcription stream", "anthropic")
854843
}
844+
845+
// parseStreamAnthropicError parses Anthropic streaming error responses.
846+
func parseStreamAnthropicError(resp *http.Response, providerType schemas.ModelProvider) *schemas.BifrostError {
847+
statusCode := resp.StatusCode
848+
body, _ := io.ReadAll(resp.Body)
849+
resp.Body.Close()
850+
851+
var errorResp anthropic.AnthropicError
852+
if err := sonic.Unmarshal(body, &errorResp); err != nil {
853+
return newBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerType)
854+
}
855+
856+
return newProviderAPIError(errorResp.Error.Message, nil, statusCode, providerType, &errorResp.Error.Type, nil)
857+
}

core/providers/azure.go

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,10 @@ func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody
123123
return bodyCopy, latency, nil
124124
}
125125

126-
// ListModels performs a list models request to Azure's API.
127-
// It retrieves all models accessible by the Azure OpenAI resource
128-
func (provider *AzureProvider) ListModels(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
126+
// listModelsForKey performs a list models request for a single key.
127+
128+
// Returns the response and latency, or an error if the request fails.
129+
func (provider *AzureProvider) listModelsByKey(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
129130
// Validate Azure key configuration
130131
if key.AzureKeyConfig == nil {
131132
return nil, newConfigurationError("azure key config not set", schemas.Azure)
@@ -174,8 +175,8 @@ func (provider *AzureProvider) ListModels(ctx context.Context, key schemas.Key,
174175

175176
// Handle error response
176177
if resp.StatusCode() != fasthttp.StatusOK {
177-
provider.logger.Debug(fmt.Sprintf("error from azure provider: %s", string(resp.Body())))
178-
return nil, parseOpenAIError(resp, schemas.ListModelsRequest, provider.GetProviderKey(), "")
178+
bifrostErr := parseOpenAIError(resp, schemas.ListModelsRequest, provider.GetProviderKey(), "")
179+
return nil, bifrostErr
179180
}
180181

181182
// Read the response body and copy it before releasing the response
@@ -194,20 +195,27 @@ func (provider *AzureProvider) ListModels(ctx context.Context, key schemas.Key,
194195
if response == nil {
195196
return nil, newBifrostOperationError("failed to convert Azure model list response", nil, schemas.Azure)
196197
}
197-
198-
response = response.ApplyPagination(request.PageSize, request.PageToken)
199-
200-
response.ExtraFields.Provider = schemas.Azure
201198
response.ExtraFields.Latency = latency.Milliseconds()
202-
response.ExtraFields.RequestType = schemas.ListModelsRequest
203-
204199
if provider.sendBackRawResponse {
205200
response.ExtraFields.RawResponse = rawResponse
206201
}
207202

208203
return response, nil
209204
}
210205

206+
// ListModels performs a list models request to Azure's API.
207+
// It retrieves all models accessible by the Azure OpenAI resource
208+
// Requests are made concurrently for improved performance.
209+
func (provider *AzureProvider) ListModels(ctx context.Context, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
210+
return handleMultipleListModelsRequests(
211+
ctx,
212+
keys,
213+
request,
214+
provider.listModelsByKey,
215+
provider.logger,
216+
)
217+
}
218+
211219
// TextCompletion performs a text completion request to Azure's API.
212220
// It formats the request, sends it to Azure, and processes the response.
213221
// Returns a BifrostResponse containing the completion results or an error if the request fails.

0 commit comments

Comments
 (0)