Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 115 additions & 2 deletions core/schemas/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,114 @@ func (ar *AllowedRequests) IsOperationAllowed(operation Operation) bool {
}
}

// GetProviderCapabilities returns the capabilities for a given provider
func GetProviderCapabilities(provider ModelProvider) AllowedRequests {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@TejasGhatte pls change its name to get GetBaseProviderCapabilities cause it only contains the base providers, not all

switch provider {
case OpenAI:
return AllowedRequests{
TextCompletion: false,
ChatCompletion: true,
ChatCompletionStream: true,
Embedding: true,
Speech: true,
SpeechStream: true,
Transcription: true,
TranscriptionStream: true,
}
case Anthropic:
return AllowedRequests{
TextCompletion: true,
ChatCompletion: true,
ChatCompletionStream: true,
Embedding: false,
Speech: false,
SpeechStream: false,
Transcription: false,
TranscriptionStream: false,
}
case Cohere:
return AllowedRequests{
TextCompletion: false,
ChatCompletion: true,
ChatCompletionStream: true,
Embedding: true,
Speech: false,
SpeechStream: false,
Transcription: false,
TranscriptionStream: false,
}
case Bedrock:
return AllowedRequests{
TextCompletion: true,
ChatCompletion: true,
ChatCompletionStream: true,
Embedding: true,
Speech: false,
SpeechStream: false,
Transcription: false,
TranscriptionStream: false,
}
case Gemini:
return AllowedRequests{
TextCompletion: false,
ChatCompletion: true,
ChatCompletionStream: true,
Embedding: true,
Speech: true,
SpeechStream: true,
Transcription: true,
TranscriptionStream: true,
}
default:
// Unknown providers: safest default is no capabilities; upstream validation should reject these.
return AllowedRequests{}
}
}

// GetDefaultAllowedRequestsForProvider returns a default AllowedRequests based on provider capabilities
// This ensures that when no explicit restrictions are set, operations are limited to what the provider supports
func GetDefaultAllowedRequestsForProvider(baseProvider ModelProvider) *AllowedRequests {
capabilities := GetProviderCapabilities(baseProvider)
return &capabilities
}

// ValidateAllowedRequests checks if the allowed requests are valid for the given base provider
// and silently sets any unsupported operations to false instead of returning an error
func ValidateAllowedRequests(allowedRequests *AllowedRequests, baseProvider ModelProvider) {
if allowedRequests == nil {
// nil means all operations allowed, which is always valid
return
}

capabilities := GetProviderCapabilities(baseProvider)

// Silently set unsupported operations to false instead of returning an error
if allowedRequests.TextCompletion && !capabilities.TextCompletion {
allowedRequests.TextCompletion = false
}
if allowedRequests.ChatCompletion && !capabilities.ChatCompletion {
allowedRequests.ChatCompletion = false
}
if allowedRequests.ChatCompletionStream && !capabilities.ChatCompletionStream {
allowedRequests.ChatCompletionStream = false
}
if allowedRequests.Embedding && !capabilities.Embedding {
allowedRequests.Embedding = false
}
if allowedRequests.Speech && !capabilities.Speech {
allowedRequests.Speech = false
}
if allowedRequests.SpeechStream && !capabilities.SpeechStream {
allowedRequests.SpeechStream = false
}
if allowedRequests.Transcription && !capabilities.Transcription {
allowedRequests.Transcription = false
}
if allowedRequests.TranscriptionStream && !capabilities.TranscriptionStream {
allowedRequests.TranscriptionStream = false
}
}

type CustomProviderConfig struct {
CustomProviderKey string `json:"-"` // Custom provider key, internally set by Bifrost
BaseProviderType ModelProvider `json:"base_provider_type"` // Base provider type
Expand All @@ -131,8 +239,13 @@ type CustomProviderConfig struct {

// IsOperationAllowed checks if a specific operation is allowed for this custom provider
func (cpc *CustomProviderConfig) IsOperationAllowed(operation Operation) bool {
if cpc == nil || cpc.AllowedRequests == nil {
return true // Default to allowed if no restrictions
if cpc == nil {
return true // Default to allowed if no custom provider config
}
if cpc.AllowedRequests == nil {
// Use provider capabilities as default when no explicit restrictions are set
defaultAllowedRequests := GetDefaultAllowedRequestsForProvider(cpc.BaseProviderType)
return defaultAllowedRequests.IsOperationAllowed(operation)
}
return cpc.AllowedRequests.IsOperationAllowed(operation)
}
Expand Down
67 changes: 67 additions & 0 deletions docs/features/custom-providers.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,73 @@ curl --location 'http://localhost:8080/api/providers' \
}
```

</Tab>

<Tab title="Using Go SDK">

```go
type MyAccount struct{}

const OpenAICustom = schemas.ModelProvider("openai-custom")

func (a *MyAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) {
return []schemas.ModelProvider{schemas.OpenAI, schemas.Anthropic, OpenAICustom}, nil
}

func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) {
switch provider {
case schemas.OpenAI:
return []schemas.Key{{
Value: os.Getenv("OPENAI_API_KEY"),
Models: []string{},
Weight: 1.0,
}}, nil
case schemas.Anthropic:
return []schemas.Key{{
Value: os.Getenv("ANTHROPIC_API_KEY"),
Models: []string{},
Weight: 1.0,
}}, nil
case OpenAICustom:
return []schemas.Key{{
Value: os.Getenv("PROVIDER_API_KEY"),
Models: []string{},
Weight: 1.0,
}}, nil
}
return nil, fmt.Errorf("provider %s not supported", provider)
}

func (a *MyAccount) GetConfigForProvider(provider schemas.ModelProvider) (*schemas.ProviderConfig, error) {
switch provider {
case OpenAICustom:
// Return config with base provider and allowed requests
return &schemas.ProviderConfig{
NetworkConfig: schemas.DefaultNetworkConfig,
ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize,
BaseProviderType: schemas.OpenAI,
AllowedRequests: &schemas.AllowedRequests{
TextCompletion: false,
ChatCompletion: true,
ChatCompletionStream: true,
Embedding: true,
Speech: false,
SpeechStream: false,
Transcription: false,
TranscriptionStream: false,
},
}, nil
default:
// Return default config for other providers
return &schemas.ProviderConfig{
NetworkConfig: schemas.DefaultNetworkConfig,
ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize,
}, nil
}
}
```


</Tab>

</Tabs>
Expand Down
25 changes: 8 additions & 17 deletions transports/bifrost-http/handlers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,23 +161,6 @@ func (h *ProviderHandler) AddProvider(ctx *fasthttp.RequestCtx) {
// baseProvider tracks the effective base provider type for validations/keys
baseProvider := req.Provider
if req.CustomProviderConfig != nil {
// custom provider key should not be same as standard provider names
if bifrost.IsStandardProvider(baseProvider) {
SendError(ctx, fasthttp.StatusBadRequest, "Custom provider cannot be same as a standard provider", h.logger)
return
}

if req.CustomProviderConfig.BaseProviderType == "" {
SendError(ctx, fasthttp.StatusBadRequest, "BaseProviderType is required when CustomProviderConfig is provided", h.logger)
return
}

// check if base provider is a supported base provider
if !bifrost.IsSupportedBaseProvider(req.CustomProviderConfig.BaseProviderType) {
SendError(ctx, fasthttp.StatusBadRequest, "BaseProviderType must be a standard provider", h.logger)
return
}

// CustomProviderKey is internally set by Bifrost, no validation needed
baseProvider = req.CustomProviderConfig.BaseProviderType
}
Expand Down Expand Up @@ -219,6 +202,14 @@ func (h *ProviderHandler) AddProvider(ctx *fasthttp.RequestCtx) {
CustomProviderConfig: req.CustomProviderConfig,
}

// Validate custom provider configuration (including AllowedRequests sanitization)
if req.CustomProviderConfig != nil {
if err := lib.ValidateCustomProvider(config, req.Provider); err != nil {
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid custom provider config: %v", err), h.logger)
return
}
}

// Add provider to store (env vars will be processed by store)
if err := h.store.AddProvider(req.Provider, config); err != nil {
h.logger.Warn(fmt.Sprintf("Failed to add provider %s: %v", req.Provider, err))
Expand Down
23 changes: 22 additions & 1 deletion transports/bifrost-http/lib/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -1850,6 +1850,26 @@ func ValidateCustomProvider(config configstore.ProviderConfig, provider schemas.
if !bifrost.IsSupportedBaseProvider(cpc.BaseProviderType) {
return fmt.Errorf("custom provider validation failed: unsupported base_provider_type: %s", cpc.BaseProviderType)
}

// Validate allowed requests against base provider capabilities
if cpc.AllowedRequests != nil {
schemas.ValidateAllowedRequests(cpc.AllowedRequests, cpc.BaseProviderType)
} else {
// Set default allowed requests based on provider capabilities when nil
// This ensures the auto-disable feature works even when no explicit restrictions are set
caps := schemas.GetProviderCapabilities(cpc.BaseProviderType)
cpc.AllowedRequests = &schemas.AllowedRequests{
TextCompletion: caps.TextCompletion,
ChatCompletion: caps.ChatCompletion,
ChatCompletionStream: caps.ChatCompletionStream,
Embedding: caps.Embedding,
Speech: caps.Speech,
SpeechStream: caps.SpeechStream,
Transcription: caps.Transcription,
TranscriptionStream: caps.TranscriptionStream,
}
}

return nil
}

Expand Down Expand Up @@ -1882,7 +1902,8 @@ func ValidateCustomProviderUpdate(newConfig, existingConfig configstore.Provider
provider, existingCPC.BaseProviderType, newCPC.BaseProviderType)
}

return nil
// Validate the new config (including AllowedRequests sanitization)
return ValidateCustomProvider(newConfig, provider)
}

func (s *Config) AddProviderKeysToSemanticCacheConfig(config *schemas.PluginConfig) error {
Expand Down
13 changes: 12 additions & 1 deletion ui/app/providers/configure/page.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"use client";

import { closeConfigureDialog, useAppDispatch, useAppSelector, useGetProvidersQuery } from "@/lib/store";
import { closeConfigureDialog, setSelectedProvider, useAppDispatch, useAppSelector, useGetProvidersQuery } from "@/lib/store";
import { useRouter } from "next/navigation";
import { useEffect } from "react";
import ProviderForm from "./provider-form";

export default function ConfigurePage() {
Expand All @@ -10,6 +11,16 @@ export default function ConfigurePage() {
const selectedProvider = useAppSelector((state) => state.provider.selectedProvider);
const { data: providersData, refetch } = useGetProvidersQuery();

// Auto-select first available provider when no provider is selected and data is loaded
useEffect(() => {
if (!selectedProvider && providersData?.providers && providersData.providers.length > 0) {
const providerToSelect = providersData.providers[0];

// Update Redux state to select this provider
dispatch(setSelectedProvider(providerToSelect));
}
}, [selectedProvider, providersData?.providers, dispatch]);

const handleSave = () => {
refetch();
dispatch(closeConfigureDialog());
Expand Down
Loading