Skip to content

Commit e5ac17c

Browse files
committed
fix: base provider level request level restriction
1 parent 2a67922 commit e5ac17c

File tree

9 files changed

+338
-36
lines changed

9 files changed

+338
-36
lines changed

core/schemas/provider.go

Lines changed: 115 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,114 @@ func (ar *AllowedRequests) IsOperationAllowed(operation Operation) bool {
123123
}
124124
}
125125

126+
// GetProviderCapabilities returns the capabilities for a given provider
127+
func GetProviderCapabilities(provider ModelProvider) AllowedRequests {
128+
switch provider {
129+
case OpenAI:
130+
return AllowedRequests{
131+
TextCompletion: false,
132+
ChatCompletion: true,
133+
ChatCompletionStream: true,
134+
Embedding: true,
135+
Speech: true,
136+
SpeechStream: true,
137+
Transcription: true,
138+
TranscriptionStream: true,
139+
}
140+
case Anthropic:
141+
return AllowedRequests{
142+
TextCompletion: true,
143+
ChatCompletion: true,
144+
ChatCompletionStream: true,
145+
Embedding: false,
146+
Speech: false,
147+
SpeechStream: false,
148+
Transcription: false,
149+
TranscriptionStream: false,
150+
}
151+
case Cohere:
152+
return AllowedRequests{
153+
TextCompletion: false,
154+
ChatCompletion: true,
155+
ChatCompletionStream: true,
156+
Embedding: true,
157+
Speech: false,
158+
SpeechStream: false,
159+
Transcription: false,
160+
TranscriptionStream: false,
161+
}
162+
case Bedrock:
163+
return AllowedRequests{
164+
TextCompletion: true,
165+
ChatCompletion: true,
166+
ChatCompletionStream: true,
167+
Embedding: true,
168+
Speech: false,
169+
SpeechStream: false,
170+
Transcription: false,
171+
TranscriptionStream: false,
172+
}
173+
case Gemini:
174+
return AllowedRequests{
175+
TextCompletion: false,
176+
ChatCompletion: true,
177+
ChatCompletionStream: true,
178+
Embedding: true,
179+
Speech: true,
180+
SpeechStream: true,
181+
Transcription: true,
182+
TranscriptionStream: true,
183+
}
184+
default:
185+
// Unknown providers: safest default is no capabilities; upstream validation should reject these.
186+
return AllowedRequests{}
187+
}
188+
}
189+
190+
// GetDefaultAllowedRequestsForProvider returns a default AllowedRequests based on provider capabilities
191+
// This ensures that when no explicit restrictions are set, operations are limited to what the provider supports
192+
func GetDefaultAllowedRequestsForProvider(baseProvider ModelProvider) *AllowedRequests {
193+
capabilities := GetProviderCapabilities(baseProvider)
194+
return &capabilities
195+
}
196+
197+
// ValidateAllowedRequests checks if the allowed requests are valid for the given base provider
198+
// and silently sets any unsupported operations to false instead of returning an error
199+
func ValidateAllowedRequests(allowedRequests *AllowedRequests, baseProvider ModelProvider) {
200+
if allowedRequests == nil {
201+
// nil means all operations allowed, which is always valid
202+
return
203+
}
204+
205+
capabilities := GetProviderCapabilities(baseProvider)
206+
207+
// Silently set unsupported operations to false instead of returning an error
208+
if allowedRequests.TextCompletion && !capabilities.TextCompletion {
209+
allowedRequests.TextCompletion = false
210+
}
211+
if allowedRequests.ChatCompletion && !capabilities.ChatCompletion {
212+
allowedRequests.ChatCompletion = false
213+
}
214+
if allowedRequests.ChatCompletionStream && !capabilities.ChatCompletionStream {
215+
allowedRequests.ChatCompletionStream = false
216+
}
217+
if allowedRequests.Embedding && !capabilities.Embedding {
218+
allowedRequests.Embedding = false
219+
}
220+
if allowedRequests.Speech && !capabilities.Speech {
221+
allowedRequests.Speech = false
222+
}
223+
if allowedRequests.SpeechStream && !capabilities.SpeechStream {
224+
allowedRequests.SpeechStream = false
225+
}
226+
if allowedRequests.Transcription && !capabilities.Transcription {
227+
allowedRequests.Transcription = false
228+
}
229+
if allowedRequests.TranscriptionStream && !capabilities.TranscriptionStream {
230+
allowedRequests.TranscriptionStream = false
231+
}
232+
}
233+
126234
type CustomProviderConfig struct {
127235
CustomProviderKey string `json:"-"` // Custom provider key, internally set by Bifrost
128236
BaseProviderType ModelProvider `json:"base_provider_type"` // Base provider type
@@ -131,8 +239,13 @@ type CustomProviderConfig struct {
131239

132240
// IsOperationAllowed checks if a specific operation is allowed for this custom provider
133241
func (cpc *CustomProviderConfig) IsOperationAllowed(operation Operation) bool {
134-
if cpc == nil || cpc.AllowedRequests == nil {
135-
return true // Default to allowed if no restrictions
242+
if cpc == nil {
243+
return true // Default to allowed if no custom provider config
244+
}
245+
if cpc.AllowedRequests == nil {
246+
// Use provider capabilities as default when no explicit restrictions are set
247+
defaultAllowedRequests := GetDefaultAllowedRequestsForProvider(cpc.BaseProviderType)
248+
return defaultAllowedRequests.IsOperationAllowed(operation)
136249
}
137250
return cpc.AllowedRequests.IsOperationAllowed(operation)
138251
}

docs/features/custom-providers.mdx

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,73 @@ curl --location 'http://localhost:8080/api/providers' \
107107
}
108108
```
109109

110+
</Tab>
111+
112+
<Tab title="Using Go SDK">
113+
114+
```go
115+
type MyAccount struct{}
116+
117+
const OpenAICustom = schemas.ModelProvider("openai-custom")
118+
119+
func (a *MyAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) {
120+
return []schemas.ModelProvider{schemas.OpenAI, schemas.Anthropic, OpenAICustom}, nil
121+
}
122+
123+
func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) {
124+
switch provider {
125+
case schemas.OpenAI:
126+
return []schemas.Key{{
127+
Value: os.Getenv("OPENAI_API_KEY"),
128+
Models: []string{},
129+
Weight: 1.0,
130+
}}, nil
131+
case schemas.Anthropic:
132+
return []schemas.Key{{
133+
Value: os.Getenv("ANTHROPIC_API_KEY"),
134+
Models: []string{},
135+
Weight: 1.0,
136+
}}, nil
137+
case OpenAICustom:
138+
return []schemas.Key{{
139+
Value: os.Getenv("PROVIDER_API_KEY"),
140+
Models: []string{},
141+
Weight: 1.0,
142+
}}, nil
143+
}
144+
return nil, fmt.Errorf("provider %s not supported", provider)
145+
}
146+
147+
func (a *MyAccount) GetConfigForProvider(provider schemas.ModelProvider) (*schemas.ProviderConfig, error) {
148+
switch provider {
149+
case OpenAICustom:
150+
// Return config with base provider and allowed requests
151+
return &schemas.ProviderConfig{
152+
NetworkConfig: schemas.DefaultNetworkConfig,
153+
ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize,
154+
BaseProviderType: schemas.OpenAI,
155+
AllowedRequests: &schemas.AllowedRequests{
156+
TextCompletion: false,
157+
ChatCompletion: true,
158+
ChatCompletionStream: true,
159+
Embedding: true,
160+
Speech: false,
161+
SpeechStream: false,
162+
Transcription: false,
163+
TranscriptionStream: false,
164+
},
165+
}, nil
166+
default:
167+
// Return default config for other providers
168+
return &schemas.ProviderConfig{
169+
NetworkConfig: schemas.DefaultNetworkConfig,
170+
ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize,
171+
}, nil
172+
}
173+
}
174+
```
175+
176+
110177
</Tab>
111178

112179
</Tabs>

transports/bifrost-http/handlers/providers.go

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -161,23 +161,6 @@ func (h *ProviderHandler) AddProvider(ctx *fasthttp.RequestCtx) {
161161
// baseProvider tracks the effective base provider type for validations/keys
162162
baseProvider := req.Provider
163163
if req.CustomProviderConfig != nil {
164-
// custom provider key should not be same as standard provider names
165-
if bifrost.IsStandardProvider(baseProvider) {
166-
SendError(ctx, fasthttp.StatusBadRequest, "Custom provider cannot be same as a standard provider", h.logger)
167-
return
168-
}
169-
170-
if req.CustomProviderConfig.BaseProviderType == "" {
171-
SendError(ctx, fasthttp.StatusBadRequest, "BaseProviderType is required when CustomProviderConfig is provided", h.logger)
172-
return
173-
}
174-
175-
// check if base provider is a supported base provider
176-
if !bifrost.IsSupportedBaseProvider(req.CustomProviderConfig.BaseProviderType) {
177-
SendError(ctx, fasthttp.StatusBadRequest, "BaseProviderType must be a standard provider", h.logger)
178-
return
179-
}
180-
181164
// CustomProviderKey is internally set by Bifrost, no validation needed
182165
baseProvider = req.CustomProviderConfig.BaseProviderType
183166
}
@@ -219,6 +202,14 @@ func (h *ProviderHandler) AddProvider(ctx *fasthttp.RequestCtx) {
219202
CustomProviderConfig: req.CustomProviderConfig,
220203
}
221204

205+
// Validate custom provider configuration (including AllowedRequests sanitization)
206+
if req.CustomProviderConfig != nil {
207+
if err := lib.ValidateCustomProvider(config, req.Provider); err != nil {
208+
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid custom provider config: %v", err), h.logger)
209+
return
210+
}
211+
}
212+
222213
// Add provider to store (env vars will be processed by store)
223214
if err := h.store.AddProvider(req.Provider, config); err != nil {
224215
h.logger.Warn(fmt.Sprintf("Failed to add provider %s: %v", req.Provider, err))

transports/bifrost-http/lib/config.go

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1850,6 +1850,26 @@ func ValidateCustomProvider(config configstore.ProviderConfig, provider schemas.
18501850
if !bifrost.IsSupportedBaseProvider(cpc.BaseProviderType) {
18511851
return fmt.Errorf("custom provider validation failed: unsupported base_provider_type: %s", cpc.BaseProviderType)
18521852
}
1853+
1854+
// Validate allowed requests against base provider capabilities
1855+
if cpc.AllowedRequests != nil {
1856+
schemas.ValidateAllowedRequests(cpc.AllowedRequests, cpc.BaseProviderType)
1857+
} else {
1858+
// Set default allowed requests based on provider capabilities when nil
1859+
// This ensures the auto-disable feature works even when no explicit restrictions are set
1860+
caps := schemas.GetProviderCapabilities(cpc.BaseProviderType)
1861+
cpc.AllowedRequests = &schemas.AllowedRequests{
1862+
TextCompletion: caps.TextCompletion,
1863+
ChatCompletion: caps.ChatCompletion,
1864+
ChatCompletionStream: caps.ChatCompletionStream,
1865+
Embedding: caps.Embedding,
1866+
Speech: caps.Speech,
1867+
SpeechStream: caps.SpeechStream,
1868+
Transcription: caps.Transcription,
1869+
TranscriptionStream: caps.TranscriptionStream,
1870+
}
1871+
}
1872+
18531873
return nil
18541874
}
18551875

@@ -1882,7 +1902,8 @@ func ValidateCustomProviderUpdate(newConfig, existingConfig configstore.Provider
18821902
provider, existingCPC.BaseProviderType, newCPC.BaseProviderType)
18831903
}
18841904

1885-
return nil
1905+
// Validate the new config (including AllowedRequests sanitization)
1906+
return ValidateCustomProvider(newConfig, provider)
18861907
}
18871908

18881909
func (s *Config) AddProviderKeysToSemanticCacheConfig(config *schemas.PluginConfig) error {

ui/app/providers/configure/page.tsx

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"use client";
22

3-
import { closeConfigureDialog, useAppDispatch, useAppSelector, useGetProvidersQuery } from "@/lib/store";
3+
import { closeConfigureDialog, setSelectedProvider, useAppDispatch, useAppSelector, useGetProvidersQuery } from "@/lib/store";
44
import { useRouter } from "next/navigation";
5+
import { useEffect } from "react";
56
import ProviderForm from "./provider-form";
67

78
export default function ConfigurePage() {
@@ -10,6 +11,16 @@ export default function ConfigurePage() {
1011
const selectedProvider = useAppSelector((state) => state.provider.selectedProvider);
1112
const { data: providersData, refetch } = useGetProvidersQuery();
1213

14+
// Auto-select first available provider when no provider is selected and data is loaded
15+
useEffect(() => {
16+
if (!selectedProvider && providersData?.providers && providersData.providers.length > 0) {
17+
const providerToSelect = providersData.providers[0];
18+
19+
// Update Redux state to select this provider
20+
dispatch(setSelectedProvider(providerToSelect));
21+
}
22+
}, [selectedProvider, providersData?.providers, dispatch]);
23+
1324
const handleSave = () => {
1425
refetch();
1526
dispatch(closeConfigureDialog());

0 commit comments

Comments
 (0)