Skip to content

Commit ddcfbfa

Browse files
adds auth for web ui
1 parent d475f2b commit ddcfbfa

File tree

80 files changed

+1812
-230
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

80 files changed

+1812
-230
lines changed

core/bifrost.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) {
8989
if config.Logger == nil {
9090
config.Logger = NewDefaultLogger(schemas.LogLevelInfo)
9191
}
92-
92+
9393
providerUtils.SetLogger(config.Logger)
9494

9595
bifrost := &Bifrost{
@@ -338,10 +338,10 @@ func (bifrost *Bifrost) ListAllModels(ctx context.Context, request *schemas.Bifr
338338
for {
339339
// check for context cancellation
340340
select {
341-
case <-ctx.Done():
342-
bifrost.logger.Warn(fmt.Sprintf("context cancelled for provider %s", providerKey))
343-
return
344-
default:
341+
case <-ctx.Done():
342+
bifrost.logger.Warn(fmt.Sprintf("context cancelled for provider %s", providerKey))
343+
return
344+
default:
345345
}
346346

347347
iterations++

core/providers/anthropic/anthropic.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ func NewAnthropicProvider(config *schemas.ProviderConfig, logger schemas.Logger)
7777
client := &fasthttp.Client{
7878
ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds),
7979
WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds),
80-
MaxConnsPerHost: 10000,
80+
MaxConnsPerHost: 1024,
8181
MaxIdleConnDuration: 60 * time.Second,
8282
MaxConnWaitTimeout: 10 * time.Second,
8383
}
@@ -453,6 +453,9 @@ func HandleAnthropicChatCompletionStreaming(
453453
}
454454

455455
scanner := bufio.NewScanner(resp.BodyStream())
456+
buf := make([]byte, 0, 1024*1024)
457+
scanner.Buffer(buf, 10*1024*1024)
458+
456459
chunkIndex := 0
457460

458461
startTime := time.Now()

core/providers/azure/azure.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ func NewAzureProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*A
3333
config.CheckAndSetDefaults()
3434

3535
client := &fasthttp.Client{
36-
ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds),
37-
WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds),
38-
MaxConnsPerHost: 10000,
36+
ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds),
37+
WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds),
38+
MaxConnsPerHost: 1024,
3939
MaxIdleConnDuration: 60 * time.Second,
4040
MaxConnWaitTimeout: 10 * time.Second,
4141
}
@@ -471,8 +471,11 @@ func (provider *AzureProvider) ResponsesStream(ctx context.Context, postHookRunn
471471
if deployment == "" {
472472
return nil, providerUtils.NewConfigurationError(fmt.Sprintf("deployment not found for model %s", request.Model), provider.GetProviderKey())
473473
}
474-
475-
url := fmt.Sprintf("%s/openai/v1/responses?api-version=preview", key.AzureKeyConfig.Endpoint)
474+
apiVersion := key.AzureKeyConfig.APIVersion
475+
if apiVersion == nil {
476+
apiVersion = schemas.Ptr(AzureAPIVersionPreview)
477+
}
478+
url := fmt.Sprintf("%s/openai/v1/responses?api-version=%s", key.AzureKeyConfig.Endpoint, *apiVersion)
476479

477480
// Prepare Azure-specific headers
478481
authHeader := make(map[string]string)

core/providers/azure/types.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
package azure
22

3-
// DefaultAzureAPIVersion is the default Azure OpenAI API version to use when not specified.
4-
const DefaultAzureAPIVersion = "2024-10-21"
3+
// AzureAPIVersionDefault is the default Azure OpenAI API version to use when not specified.
4+
const AzureAPIVersionDefault = "2024-10-21"
5+
const AzureAPIVersionPreview = "preview"
56

67
type AzureModelCapabilities struct {
78
FineTune bool `json:"fine_tune"`

core/providers/bedrock/bedrock.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ import (
1919
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
2020
"github.com/aws/aws-sdk-go-v2/config"
2121
"github.com/bytedance/sonic"
22+
"github.com/maximhq/bifrost/core/providers/anthropic"
23+
"github.com/maximhq/bifrost/core/providers/cohere"
2224
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
2325
schemas "github.com/maximhq/bifrost/core/schemas"
2426
)

core/providers/bedrock/chat.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66

77
"github.com/bytedance/sonic"
88
"github.com/google/uuid"
9+
"github.com/maximhq/bifrost/core/providers/anthropic"
910
"github.com/maximhq/bifrost/core/schemas"
1011
)
1112

core/providers/bedrock/signer.go

Lines changed: 103 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"crypto/sha256"
77
"encoding/hex"
88
"fmt"
9-
"net/url"
109
"sort"
1110
"strconv"
1211
"strings"
@@ -124,6 +123,105 @@ func stripExcessSpaces(str string) string {
124123
return result.String()
125124
}
126125

126+
// percentEncodeRFC3986 encodes a string per RFC 3986
127+
// Keep unreserved characters (A-Z, a-z, 0-9, -, _, ., ~) as-is
128+
// Percent-encode everything else as %HH using uppercase hex
129+
func percentEncodeRFC3986(s string) string {
130+
var result strings.Builder
131+
result.Grow(len(s))
132+
133+
for i := 0; i < len(s); i++ {
134+
b := s[i]
135+
// RFC 3986 unreserved characters
136+
if (b >= 'A' && b <= 'Z') ||
137+
(b >= 'a' && b <= 'z') ||
138+
(b >= '0' && b <= '9') ||
139+
b == '-' || b == '_' || b == '.' || b == '~' {
140+
result.WriteByte(b)
141+
} else {
142+
// Percent-encode with uppercase hex
143+
result.WriteByte('%')
144+
result.WriteByte(uppercaseHex(b >> 4))
145+
result.WriteByte(uppercaseHex(b & 0x0F))
146+
}
147+
}
148+
149+
return result.String()
150+
}
151+
152+
// uppercaseHex returns the uppercase hex character for a nibble (0-15)
153+
func uppercaseHex(b byte) byte {
154+
if b < 10 {
155+
return '0' + b
156+
}
157+
return 'A' + (b - 10)
158+
}
159+
160+
// queryPair represents a query parameter name-value pair
161+
type queryPair struct {
162+
encodedName string
163+
encodedValue string
164+
}
165+
166+
// buildCanonicalQueryString builds a canonical query string per AWS SigV4 spec
167+
// using proper RFC 3986 percent-encoding
168+
func buildCanonicalQueryString(queryString string) string {
169+
if queryString == "" {
170+
return ""
171+
}
172+
173+
// Split the raw query string on '&' into pairs
174+
rawPairs := strings.Split(queryString, "&")
175+
pairs := make([]queryPair, 0, len(rawPairs))
176+
177+
for _, rawPair := range rawPairs {
178+
if rawPair == "" {
179+
continue
180+
}
181+
182+
// Split on the first '=' to get name and value
183+
var name, value string
184+
if idx := strings.IndexByte(rawPair, '='); idx >= 0 {
185+
name = rawPair[:idx]
186+
value = rawPair[idx+1:]
187+
} else {
188+
// No '=' means name only, empty value
189+
name = rawPair
190+
value = ""
191+
}
192+
193+
// Percent-encode name and value per RFC 3986
194+
encodedName := percentEncodeRFC3986(name)
195+
encodedValue := percentEncodeRFC3986(value)
196+
197+
pairs = append(pairs, queryPair{
198+
encodedName: encodedName,
199+
encodedValue: encodedValue,
200+
})
201+
}
202+
203+
// Sort pairs lexicographically by encoded name, then by encoded value
204+
sort.Slice(pairs, func(i, j int) bool {
205+
if pairs[i].encodedName != pairs[j].encodedName {
206+
return pairs[i].encodedName < pairs[j].encodedName
207+
}
208+
return pairs[i].encodedValue < pairs[j].encodedValue
209+
})
210+
211+
// Join encoded pairs with '&'
212+
var result strings.Builder
213+
for i, pair := range pairs {
214+
if i > 0 {
215+
result.WriteByte('&')
216+
}
217+
result.WriteString(pair.encodedName)
218+
result.WriteByte('=')
219+
result.WriteString(pair.encodedValue)
220+
}
221+
222+
return result.String()
223+
}
224+
127225
// signAWSRequestFastHTTP signs a fasthttp request using AWS Signature Version 4
128226
// This is a native implementation that avoids allocating http.Request
129227
func signAWSRequestFastHTTP(
@@ -191,9 +289,9 @@ func signAWSRequestFastHTTP(
191289
headerMap["host"] = []string{host}
192290

193291
// Include content-length if body is present
194-
if len(body) > 0 {
292+
if cl := req.Header.ContentLength(); cl >= 0 {
195293
headerNames = append(headerNames, "content-length")
196-
headerMap["content-length"] = []string{strconv.Itoa(len(body))}
294+
headerMap["content-length"] = []string{strconv.Itoa(cl)}
197295
}
198296

199297
// Collect other headers
@@ -238,25 +336,8 @@ func signAWSRequestFastHTTP(
238336

239337
signedHeaders := strings.Join(headerNames, ";")
240338

241-
// Parse and normalize query string
242-
var canonicalQueryString string
243-
if queryString != "" {
244-
values, _ := url.ParseQuery(queryString)
245-
// Sort keys
246-
keys := make([]string, 0, len(values))
247-
for k := range values {
248-
keys = append(keys, k)
249-
}
250-
sort.Strings(keys)
251-
252-
// Sort values for each key
253-
for _, k := range keys {
254-
sort.Strings(values[k])
255-
}
256-
257-
canonicalQueryString = values.Encode()
258-
canonicalQueryString = strings.ReplaceAll(canonicalQueryString, "+", "%20")
259-
}
339+
// Build canonical query string using RFC 3986 encoding
340+
canonicalQueryString := buildCanonicalQueryString(queryString)
260341

261342
// Build canonical request
262343
canonicalRequest := strings.Join([]string{

0 commit comments

Comments
 (0)