@@ -123,6 +123,114 @@ func (ar *AllowedRequests) IsOperationAllowed(operation Operation) bool {
123123 }
124124}
125125
126+ // GetProviderCapabilities returns the capabilities for a given provider
127+ func GetProviderCapabilities (provider ModelProvider ) AllowedRequests {
128+ switch provider {
129+ case OpenAI :
130+ return AllowedRequests {
131+ TextCompletion : false ,
132+ ChatCompletion : true ,
133+ ChatCompletionStream : true ,
134+ Embedding : true ,
135+ Speech : true ,
136+ SpeechStream : true ,
137+ Transcription : true ,
138+ TranscriptionStream : true ,
139+ }
140+ case Anthropic :
141+ return AllowedRequests {
142+ TextCompletion : true ,
143+ ChatCompletion : true ,
144+ ChatCompletionStream : true ,
145+ Embedding : false ,
146+ Speech : false ,
147+ SpeechStream : false ,
148+ Transcription : false ,
149+ TranscriptionStream : false ,
150+ }
151+ case Cohere :
152+ return AllowedRequests {
153+ TextCompletion : false ,
154+ ChatCompletion : true ,
155+ ChatCompletionStream : true ,
156+ Embedding : true ,
157+ Speech : false ,
158+ SpeechStream : false ,
159+ Transcription : false ,
160+ TranscriptionStream : false ,
161+ }
162+ case Bedrock :
163+ return AllowedRequests {
164+ TextCompletion : true ,
165+ ChatCompletion : true ,
166+ ChatCompletionStream : true ,
167+ Embedding : true ,
168+ Speech : false ,
169+ SpeechStream : false ,
170+ Transcription : false ,
171+ TranscriptionStream : false ,
172+ }
173+ case Gemini :
174+ return AllowedRequests {
175+ TextCompletion : false ,
176+ ChatCompletion : true ,
177+ ChatCompletionStream : true ,
178+ Embedding : true ,
179+ Speech : true ,
180+ SpeechStream : true ,
181+ Transcription : true ,
182+ TranscriptionStream : true ,
183+ }
184+ default :
185+ // Unknown providers: safest default is no capabilities; upstream validation should reject these.
186+ return AllowedRequests {}
187+ }
188+ }
189+
190+ // GetDefaultAllowedRequestsForProvider returns a default AllowedRequests based on provider capabilities
191+ // This ensures that when no explicit restrictions are set, operations are limited to what the provider supports
192+ func GetDefaultAllowedRequestsForProvider (baseProvider ModelProvider ) * AllowedRequests {
193+ capabilities := GetProviderCapabilities (baseProvider )
194+ return & capabilities
195+ }
196+
197+ // ValidateAllowedRequests checks if the allowed requests are valid for the given base provider
198+ // and silently sets any unsupported operations to false instead of returning an error
199+ func ValidateAllowedRequests (allowedRequests * AllowedRequests , baseProvider ModelProvider ) {
200+ if allowedRequests == nil {
201+ // nil means all operations allowed, which is always valid
202+ return
203+ }
204+
205+ capabilities := GetProviderCapabilities (baseProvider )
206+
207+ // Silently set unsupported operations to false instead of returning an error
208+ if allowedRequests .TextCompletion && ! capabilities .TextCompletion {
209+ allowedRequests .TextCompletion = false
210+ }
211+ if allowedRequests .ChatCompletion && ! capabilities .ChatCompletion {
212+ allowedRequests .ChatCompletion = false
213+ }
214+ if allowedRequests .ChatCompletionStream && ! capabilities .ChatCompletionStream {
215+ allowedRequests .ChatCompletionStream = false
216+ }
217+ if allowedRequests .Embedding && ! capabilities .Embedding {
218+ allowedRequests .Embedding = false
219+ }
220+ if allowedRequests .Speech && ! capabilities .Speech {
221+ allowedRequests .Speech = false
222+ }
223+ if allowedRequests .SpeechStream && ! capabilities .SpeechStream {
224+ allowedRequests .SpeechStream = false
225+ }
226+ if allowedRequests .Transcription && ! capabilities .Transcription {
227+ allowedRequests .Transcription = false
228+ }
229+ if allowedRequests .TranscriptionStream && ! capabilities .TranscriptionStream {
230+ allowedRequests .TranscriptionStream = false
231+ }
232+ }
233+
126234type CustomProviderConfig struct {
127235 CustomProviderKey string `json:"-"` // Custom provider key, internally set by Bifrost
128236 BaseProviderType ModelProvider `json:"base_provider_type"` // Base provider type
@@ -131,8 +239,13 @@ type CustomProviderConfig struct {
131239
132240// IsOperationAllowed checks if a specific operation is allowed for this custom provider
133241func (cpc * CustomProviderConfig ) IsOperationAllowed (operation Operation ) bool {
134- if cpc == nil || cpc .AllowedRequests == nil {
135- return true // Default to allowed if no restrictions
242+ if cpc == nil {
243+ return true // Default to allowed if no custom provider config
244+ }
245+ if cpc .AllowedRequests == nil {
246+ // Use provider capabilities as default when no explicit restrictions are set
247+ defaultAllowedRequests := GetDefaultAllowedRequestsForProvider (cpc .BaseProviderType )
248+ return defaultAllowedRequests .IsOperationAllowed (operation )
136249 }
137250 return cpc .AllowedRequests .IsOperationAllowed (operation )
138251}
0 commit comments