@@ -31,6 +31,9 @@ import (
3131 "github.com/sashabaranov/go-openai/jsonschema"
3232)
3333
34+ const apiKey = "joshua"
35+ const bearerKey = "Bearer " + apiKey
36+
3437const testPrompt = `### System:
3538You are an AI assistant that follows instruction extremely well. Help as much as you can.
3639
@@ -50,11 +53,19 @@ type modelApplyRequest struct {
5053
5154func getModelStatus (url string ) (response map [string ]interface {}) {
5255 // Create the HTTP request
53- resp , err := http .Get (url )
56+ req , err := http .NewRequest ("GET" , url , nil )
57+ req .Header .Set ("Content-Type" , "application/json" )
58+ req .Header .Set ("Authorization" , bearerKey )
5459 if err != nil {
5560 fmt .Println ("Error creating request:" , err )
5661 return
5762 }
63+ client := & http.Client {}
64+ resp , err := client .Do (req )
65+ if err != nil {
66+ fmt .Println ("Error sending request:" , err )
67+ return
68+ }
5869 defer resp .Body .Close ()
5970
6071 body , err := io .ReadAll (resp .Body )
@@ -72,14 +83,15 @@ func getModelStatus(url string) (response map[string]interface{}) {
7283 return
7384}
7485
75- func getModels (url string ) (response []gallery.GalleryModel ) {
86+ func getModels (url string ) ([]gallery.GalleryModel , error ) {
87+ response := []gallery.GalleryModel {}
7688 uri := downloader .URI (url )
7789 // TODO: No tests currently seem to exercise file:// urls. Fix?
78- uri .DownloadAndUnmarshal ("" , func (url string , i []byte ) error {
90+ err := uri .DownloadWithAuthorizationAndCallback ("" , bearerKey , func (url string , i []byte ) error {
7991 // Unmarshal YAML data into a struct
8092 return json .Unmarshal (i , & response )
8193 })
82- return
94+ return response , err
8395}
8496
8597func postModelApplyRequest (url string , request modelApplyRequest ) (response map [string ]interface {}) {
@@ -101,6 +113,7 @@ func postModelApplyRequest(url string, request modelApplyRequest) (response map[
101113 return
102114 }
103115 req .Header .Set ("Content-Type" , "application/json" )
116+ req .Header .Set ("Authorization" , bearerKey )
104117
105118 // Make the request
106119 client := & http.Client {}
@@ -140,6 +153,7 @@ func postRequestJSON[B any](url string, bodyJson *B) error {
140153 }
141154
142155 req .Header .Set ("Content-Type" , "application/json" )
156+ req .Header .Set ("Authorization" , bearerKey )
143157
144158 client := & http.Client {}
145159 resp , err := client .Do (req )
@@ -175,6 +189,7 @@ func postRequestResponseJSON[B1 any, B2 any](url string, reqJson *B1, respJson *
175189 }
176190
177191 req .Header .Set ("Content-Type" , "application/json" )
192+ req .Header .Set ("Authorization" , bearerKey )
178193
179194 client := & http.Client {}
180195 resp , err := client .Do (req )
@@ -195,6 +210,35 @@ func postRequestResponseJSON[B1 any, B2 any](url string, reqJson *B1, respJson *
195210 return json .Unmarshal (body , respJson )
196211}
197212
213+ func postInvalidRequest (url string ) (error , int ) {
214+
215+ req , err := http .NewRequest ("POST" , url , bytes .NewBufferString ("invalid request" ))
216+ if err != nil {
217+ return err , - 1
218+ }
219+
220+ req .Header .Set ("Content-Type" , "application/json" )
221+
222+ client := & http.Client {}
223+ resp , err := client .Do (req )
224+ if err != nil {
225+ return err , - 1
226+ }
227+
228+ defer resp .Body .Close ()
229+
230+ body , err := io .ReadAll (resp .Body )
231+ if err != nil {
232+ return err , - 1
233+ }
234+
235+ if resp .StatusCode < 200 || resp .StatusCode >= 400 {
236+ return fmt .Errorf ("unexpected status code: %d, body: %s" , resp .StatusCode , string (body )), resp .StatusCode
237+ }
238+
239+ return nil , resp .StatusCode
240+ }
241+
198242//go:embed backend-assets/*
199243var backendAssets embed.FS
200244
@@ -260,6 +304,7 @@ var _ = Describe("API test", func() {
260304 config .WithContext (c ),
261305 config .WithGalleries (galleries ),
262306 config .WithModelPath (modelDir ),
307+ config .WithApiKeys ([]string {apiKey }),
263308 config .WithBackendAssets (backendAssets ),
264309 config .WithBackendAssetsOutput (backendAssetsDir ))... )
265310 Expect (err ).ToNot (HaveOccurred ())
@@ -269,7 +314,7 @@ var _ = Describe("API test", func() {
269314
270315 go app .Listen ("127.0.0.1:9090" )
271316
272- defaultConfig := openai .DefaultConfig ("" )
317+ defaultConfig := openai .DefaultConfig (apiKey )
273318 defaultConfig .BaseURL = "http://127.0.0.1:9090/v1"
274319
275320 client2 = openaigo .NewClient ("" )
@@ -295,10 +340,19 @@ var _ = Describe("API test", func() {
295340 Expect (err ).To (HaveOccurred ())
296341 })
297342
343+ Context ("Auth Tests" , func () {
344+ It ("Should fail if the api key is missing" , func () {
345+ err , sc := postInvalidRequest ("http://127.0.0.1:9090/models/available" )
346+ Expect (err ).ToNot (BeNil ())
347+ Expect (sc ).To (Equal (403 ))
348+ })
349+ })
350+
298351 Context ("Applying models" , func () {
299352
300353 It ("applies models from a gallery" , func () {
301- models := getModels ("http://127.0.0.1:9090/models/available" )
354+ models , err := getModels ("http://127.0.0.1:9090/models/available" )
355+ Expect (err ).To (BeNil ())
302356 Expect (len (models )).To (Equal (2 ), fmt .Sprint (models ))
303357 Expect (models [0 ].Installed ).To (BeFalse (), fmt .Sprint (models ))
304358 Expect (models [1 ].Installed ).To (BeFalse (), fmt .Sprint (models ))
@@ -331,7 +385,8 @@ var _ = Describe("API test", func() {
331385 Expect (content ["backend" ]).To (Equal ("bert-embeddings" ))
332386 Expect (content ["foo" ]).To (Equal ("bar" ))
333387
334- models = getModels ("http://127.0.0.1:9090/models/available" )
388+ models , err = getModels ("http://127.0.0.1:9090/models/available" )
389+ Expect (err ).To (BeNil ())
335390 Expect (len (models )).To (Equal (2 ), fmt .Sprint (models ))
336391 Expect (models [0 ].Name ).To (Or (Equal ("bert" ), Equal ("bert2" )))
337392 Expect (models [1 ].Name ).To (Or (Equal ("bert" ), Equal ("bert2" )))
0 commit comments