Skip to content

Commit 6d784ea

Browse files
committed
fix: base provider level request level restriction
1 parent 2ac4f85 commit 6d784ea

File tree

6 files changed

+297
-9
lines changed

6 files changed

+297
-9
lines changed

core/schemas/provider.go

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,117 @@ func (ar *AllowedRequests) IsOperationAllowed(operation Operation) bool {
123123
}
124124
}
125125

126+
// ProviderCapabilities defines what operations each provider supports
127+
type ProviderCapabilities struct {
128+
TextCompletion bool
129+
ChatCompletion bool
130+
ChatCompletionStream bool
131+
Embedding bool
132+
Speech bool
133+
SpeechStream bool
134+
Transcription bool
135+
TranscriptionStream bool
136+
}
137+
138+
// GetProviderCapabilities returns the capabilities for a given provider
139+
func GetProviderCapabilities(provider ModelProvider) ProviderCapabilities {
140+
switch provider {
141+
case OpenAI:
142+
return ProviderCapabilities{
143+
TextCompletion: false,
144+
ChatCompletion: true,
145+
ChatCompletionStream: true,
146+
Embedding: true,
147+
Speech: true,
148+
SpeechStream: true,
149+
Transcription: true,
150+
TranscriptionStream: true,
151+
}
152+
case Anthropic:
153+
return ProviderCapabilities{
154+
TextCompletion: true,
155+
ChatCompletion: true,
156+
ChatCompletionStream: true,
157+
Embedding: false,
158+
Speech: false,
159+
SpeechStream: false,
160+
Transcription: false,
161+
TranscriptionStream: false,
162+
}
163+
case Cohere:
164+
return ProviderCapabilities{
165+
TextCompletion: false,
166+
ChatCompletion: true,
167+
ChatCompletionStream: true,
168+
Embedding: true,
169+
Speech: false,
170+
SpeechStream: false,
171+
Transcription: false,
172+
TranscriptionStream: false,
173+
}
174+
case Bedrock:
175+
return ProviderCapabilities{
176+
TextCompletion: true,
177+
ChatCompletion: true,
178+
ChatCompletionStream: true,
179+
Embedding: true,
180+
Speech: false,
181+
SpeechStream: false,
182+
Transcription: false,
183+
TranscriptionStream: false,
184+
}
185+
default:
186+
// For unknown providers, assume all capabilities (fallback)
187+
return ProviderCapabilities{
188+
TextCompletion: true,
189+
ChatCompletion: true,
190+
ChatCompletionStream: true,
191+
Embedding: true,
192+
Speech: true,
193+
SpeechStream: true,
194+
Transcription: true,
195+
TranscriptionStream: true,
196+
}
197+
}
198+
}
199+
200+
// ValidateAllowedRequests checks if the allowed requests are valid for the given base provider
201+
// and silently sets any unsupported operations to false instead of returning an error
202+
func ValidateAllowedRequests(allowedRequests *AllowedRequests, baseProvider ModelProvider) {
203+
if allowedRequests == nil {
204+
// nil means all operations allowed, which is always valid
205+
return
206+
}
207+
208+
capabilities := GetProviderCapabilities(baseProvider)
209+
210+
// Silently set unsupported operations to false instead of returning an error
211+
if allowedRequests.TextCompletion && !capabilities.TextCompletion {
212+
allowedRequests.TextCompletion = false
213+
}
214+
if allowedRequests.ChatCompletion && !capabilities.ChatCompletion {
215+
allowedRequests.ChatCompletion = false
216+
}
217+
if allowedRequests.ChatCompletionStream && !capabilities.ChatCompletionStream {
218+
allowedRequests.ChatCompletionStream = false
219+
}
220+
if allowedRequests.Embedding && !capabilities.Embedding {
221+
allowedRequests.Embedding = false
222+
}
223+
if allowedRequests.Speech && !capabilities.Speech {
224+
allowedRequests.Speech = false
225+
}
226+
if allowedRequests.SpeechStream && !capabilities.SpeechStream {
227+
allowedRequests.SpeechStream = false
228+
}
229+
if allowedRequests.Transcription && !capabilities.Transcription {
230+
allowedRequests.Transcription = false
231+
}
232+
if allowedRequests.TranscriptionStream && !capabilities.TranscriptionStream {
233+
allowedRequests.TranscriptionStream = false
234+
}
235+
}
236+
126237
type CustomProviderConfig struct {
127238
CustomProviderKey string `json:"-"` // Custom provider key, internally set by Bifrost
128239
BaseProviderType ModelProvider `json:"base_provider_type"` // Base provider type

docs/features/custom-providers.mdx

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,73 @@ curl --location 'http://localhost:8080/api/providers' \
107107
}
108108
```
109109

110+
</Tab>
111+
112+
<Tab title="Using GO SDK">
113+
114+
```go
115+
type MyAccount struct{}
116+
117+
OpenAICustom = schemas.ModelProvider("openai-custom")
118+
119+
func (a *MyAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) {
120+
return []schemas.ModelProvider{schemas.OpenAI, schemas.Anthropic, OpenAICustom}, nil
121+
}
122+
123+
func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) {
124+
switch provider {
125+
case schemas.OpenAI:
126+
return []schemas.Key{{
127+
Value: os.Getenv("OPENAI_API_KEY"),
128+
Models: []string{},
129+
Weight: 1.0,
130+
}}, nil
131+
case schemas.Anthropic:
132+
return []schemas.Key{{
133+
Value: os.Getenv("ANTHROPIC_API_KEY"),
134+
Models: []string{},
135+
Weight: 1.0,
136+
}}, nil
137+
case OpenAICustom:
138+
return []schemas.Key{{
139+
Value: os.Getenv("PROVIDER_API_KEY"),
140+
Models: []string{},
141+
Weight: 1.0,
142+
}}, nil
143+
}
144+
return nil, fmt.Errorf("provider %s not supported", provider)
145+
}
146+
147+
func (a *MyAccount) GetConfigForProvider(provider schemas.ModelProvider) (*schemas.ProviderConfig, error) {
148+
switch provider {
149+
case OpenAICustom:
150+
// Return config with base provider and allowed requests
151+
return &schemas.ProviderConfig{
152+
NetworkConfig: schemas.DefaultNetworkConfig,
153+
ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize,
154+
BaseProviderType: schemas.OpenAI,
155+
AllowedRequests: &schemas.AllowedRequests{
156+
TextCompletion: false,
157+
ChatCompletion: true,
158+
ChatCompletionStream: true,
159+
Embedding: true,
160+
Speech: false,
161+
SpeechStream: false,
162+
Transcription: false,
163+
TranscriptionStream: false,
164+
},
165+
}, nil
166+
default:
167+
// Return default config for other providers
168+
return &schemas.ProviderConfig{
169+
NetworkConfig: schemas.DefaultNetworkConfig,
170+
ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize,
171+
}, nil
172+
}
173+
}
174+
```
175+
176+
110177
</Tab>
111178

112179
</Tabs>

transports/bifrost-http/handlers/providers.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,11 @@ func (h *ProviderHandler) AddProvider(ctx *fasthttp.RequestCtx) {
178178
return
179179
}
180180

181+
// Validate allowed requests against base provider capabilities
182+
if req.CustomProviderConfig.AllowedRequests != nil {
183+
schemas.ValidateAllowedRequests(req.CustomProviderConfig.AllowedRequests, req.CustomProviderConfig.BaseProviderType)
184+
}
185+
181186
// CustomProviderKey is internally set by Bifrost, no validation needed
182187
baseProvider = req.CustomProviderConfig.BaseProviderType
183188
}

transports/bifrost-http/lib/config.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1841,6 +1841,12 @@ func ValidateCustomProvider(config configstore.ProviderConfig, provider schemas.
18411841
if !bifrost.IsSupportedBaseProvider(cpc.BaseProviderType) {
18421842
return fmt.Errorf("custom provider validation failed: unsupported base_provider_type: %s", cpc.BaseProviderType)
18431843
}
1844+
1845+
// Validate allowed requests against base provider capabilities
1846+
if cpc.AllowedRequests != nil {
1847+
schemas.ValidateAllowedRequests(cpc.AllowedRequests, cpc.BaseProviderType)
1848+
}
1849+
18441850
return nil
18451851
}
18461852

ui/app/providers/configure/provider-form.tsx

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import { Tabs, TabsList, TabsTrigger } from "@/components/ui/tabs";
1717
import { TagInput } from "@/components/ui/tag-input";
1818
import { Textarea } from "@/components/ui/textarea";
1919
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip";
20-
import { DEFAULT_ALLOWED_REQUESTS, DEFAULT_NETWORK_CONFIG, DEFAULT_PERFORMANCE_CONFIG } from "@/lib/constants/config";
20+
import { DEFAULT_NETWORK_CONFIG, DEFAULT_PERFORMANCE_CONFIG, getProviderDefaultAllowedRequests } from "@/lib/constants/config";
2121
import { ProviderIconType, RenderProviderIcon } from "@/lib/constants/icons";
2222
import { PROVIDER_LABELS, PROVIDERS as Providers } from "@/lib/constants/logs";
2323
import { getErrorMessage, useCreateProviderMutation, useUpdateProviderMutation } from "@/lib/store";
@@ -111,11 +111,12 @@ const createInitialState = (provider?: ProviderResponse | null, defaultProvider?
111111

112112
// Check if this is a custom provider
113113
const isCustomProvider = provider && !Providers.includes(provider.name as any);
114+
const baseProviderType = provider?.custom_provider_config?.base_provider_type || "";
114115

115116
return {
116117
selectedProvider: providerName,
117118
customProviderName: isCustomProvider ? provider.name : "",
118-
baseProviderType: provider?.custom_provider_config?.base_provider_type || "",
119+
baseProviderType,
119120
keys: isNewProvider && keysRequired ? [createDefaultKey()] : !isNewProvider && keysRequired && provider?.keys ? provider.keys : [],
120121
networkConfig: provider?.network_config || DEFAULT_NETWORK_CONFIG,
121122
performanceConfig: provider?.concurrency_and_buffer_size || DEFAULT_PERFORMANCE_CONFIG,
@@ -126,7 +127,7 @@ const createInitialState = (provider?: ProviderResponse | null, defaultProvider?
126127
password: "",
127128
},
128129
sendBackRawResponse: provider?.send_back_raw_response || false,
129-
allowedRequests: provider?.custom_provider_config?.allowed_requests || DEFAULT_ALLOWED_REQUESTS,
130+
allowedRequests: provider?.custom_provider_config?.allowed_requests || getProviderDefaultAllowedRequests(baseProviderType),
130131
};
131132
};
132133

@@ -189,6 +190,18 @@ export default function ProviderForm({ provider, onSave, onCancel, existingProvi
189190
const isCustomProvider =
190191
selectedProvider === "custom" || !!customProviderName || !!baseProviderType || !Providers.includes(selectedProvider as any);
191192

193+
// Update allowed requests when base provider type changes for custom providers
194+
useEffect(() => {
195+
if (baseProviderType && isCustomProvider && !isEditingExisting) {
196+
const newAllowedRequests = getProviderDefaultAllowedRequests(baseProviderType);
197+
setFormData((prev) => ({
198+
...prev,
199+
allowedRequests: newAllowedRequests,
200+
isDirty: true,
201+
}));
202+
}
203+
}, [baseProviderType, isCustomProvider, isEditingExisting]);
204+
192205
const performanceValid =
193206
performanceConfig.concurrency > 0 && performanceConfig.buffer_size > 0 && performanceConfig.concurrency < performanceConfig.buffer_size;
194207

@@ -410,10 +423,15 @@ export default function ProviderForm({ provider, onSave, onCancel, existingProvi
410423
/^[a-z0-9_-]+$/,
411424
"Custom provider name must be lowercase alphanumeric and may include ‘-’ or ‘_’ (no spaces)",
412425
),
413-
Validator.custom(
414-
!allProviders.some((p) => p.name === customProviderName.trim() && p.name !== (provider?.name || "")),
415-
"A provider with this name already exists",
416-
),
426+
// Only check for duplicate names when creating new providers, not when updating existing ones
427+
...(provider
428+
? []
429+
: [
430+
Validator.custom(
431+
!allProviders.some((p) => p.name === customProviderName.trim()),
432+
"A provider with this name already exists",
433+
),
434+
]),
417435
Validator.required(baseProviderType, "Base provider type is required for custom providers"),
418436
Validator.custom(
419437
!Providers.includes(customProviderName.trim() as any),
@@ -646,7 +664,7 @@ export default function ProviderForm({ provider, onSave, onCancel, existingProvi
646664
password: "",
647665
},
648666
sendBackRawResponse: false,
649-
allowedRequests: DEFAULT_ALLOWED_REQUESTS,
667+
allowedRequests: getProviderDefaultAllowedRequests("openai"),
650668
});
651669

652670
const [selectedTab, setSelectedTab] = useState(tabs[0]?.id || "api-keys");
@@ -843,7 +861,10 @@ export default function ProviderForm({ provider, onSave, onCancel, existingProvi
843861
{/* Allowed Requests Configuration */}
844862
<div className="space-y-2">
845863
<div className="text-sm font-medium">Allowed Request Types</div>
846-
<p className="text-muted-foreground text-xs">Select which request types this custom provider can handle</p>
864+
<p className="text-muted-foreground text-xs">
865+
Select which request types this custom provider can handle. Automatically disabled for features not supported by the
866+
underlying provider.
867+
</p>
847868

848869
<div className="grid grid-cols-2 gap-4">
849870
<div className="space-y-3">
@@ -853,6 +874,7 @@ export default function ProviderForm({ provider, onSave, onCancel, existingProvi
853874
size="md"
854875
checked={allowedRequests.text_completion}
855876
onCheckedChange={(checked) => updateAllowedRequest("text_completion", checked)}
877+
disabled={!getProviderDefaultAllowedRequests(baseProviderType || "openai").text_completion}
856878
/>
857879
</div>
858880
<div className="flex items-center justify-between">
@@ -861,6 +883,7 @@ export default function ProviderForm({ provider, onSave, onCancel, existingProvi
861883
size="md"
862884
checked={allowedRequests.chat_completion}
863885
onCheckedChange={(checked) => updateAllowedRequest("chat_completion", checked)}
886+
disabled={!getProviderDefaultAllowedRequests(baseProviderType || "openai").chat_completion}
864887
/>
865888
</div>
866889
<div className="flex items-center justify-between">
@@ -869,6 +892,7 @@ export default function ProviderForm({ provider, onSave, onCancel, existingProvi
869892
size="md"
870893
checked={allowedRequests.chat_completion_stream}
871894
onCheckedChange={(checked) => updateAllowedRequest("chat_completion_stream", checked)}
895+
disabled={!getProviderDefaultAllowedRequests(baseProviderType || "openai").chat_completion_stream}
872896
/>
873897
</div>
874898
<div className="flex items-center justify-between">
@@ -877,6 +901,7 @@ export default function ProviderForm({ provider, onSave, onCancel, existingProvi
877901
size="md"
878902
checked={allowedRequests.embedding}
879903
onCheckedChange={(checked) => updateAllowedRequest("embedding", checked)}
904+
disabled={!getProviderDefaultAllowedRequests(baseProviderType || "openai").embedding}
880905
/>
881906
</div>
882907
</div>
@@ -887,6 +912,7 @@ export default function ProviderForm({ provider, onSave, onCancel, existingProvi
887912
size="md"
888913
checked={allowedRequests.speech}
889914
onCheckedChange={(checked) => updateAllowedRequest("speech", checked)}
915+
disabled={!getProviderDefaultAllowedRequests(baseProviderType || "openai").speech}
890916
/>
891917
</div>
892918
<div className="flex items-center justify-between">
@@ -895,6 +921,7 @@ export default function ProviderForm({ provider, onSave, onCancel, existingProvi
895921
size="md"
896922
checked={allowedRequests.speech_stream}
897923
onCheckedChange={(checked) => updateAllowedRequest("speech_stream", checked)}
924+
disabled={!getProviderDefaultAllowedRequests(baseProviderType || "openai").speech_stream}
898925
/>
899926
</div>
900927
<div className="flex items-center justify-between">
@@ -903,6 +930,7 @@ export default function ProviderForm({ provider, onSave, onCancel, existingProvi
903930
size="md"
904931
checked={allowedRequests.transcription}
905932
onCheckedChange={(checked) => updateAllowedRequest("transcription", checked)}
933+
disabled={!getProviderDefaultAllowedRequests(baseProviderType || "openai").transcription}
906934
/>
907935
</div>
908936
<div className="flex items-center justify-between">
@@ -911,6 +939,7 @@ export default function ProviderForm({ provider, onSave, onCancel, existingProvi
911939
size="md"
912940
checked={allowedRequests.transcription_stream}
913941
onCheckedChange={(checked) => updateAllowedRequest("transcription_stream", checked)}
942+
disabled={!getProviderDefaultAllowedRequests(baseProviderType || "openai").transcription_stream}
914943
/>
915944
</div>
916945
</div>

0 commit comments

Comments
 (0)