5
5
6
6
server : ServerProcess
7
7
8
- IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
9
- IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
10
-
11
- response = requests .get (IMG_URL_0 )
12
- response .raise_for_status () # Raise an exception for bad status codes
13
- IMG_BASE64_URI_0 = "data:image/png;base64," + base64 .b64encode (response .content ).decode ("utf-8" )
14
- IMG_BASE64_0 = base64 .b64encode (response .content ).decode ("utf-8" )
15
-
16
- response = requests .get (IMG_URL_1 )
17
- response .raise_for_status () # Raise an exception for bad status codes
18
- IMG_BASE64_URI_1 = "data:image/png;base64," + base64 .b64encode (response .content ).decode ("utf-8" )
19
- IMG_BASE64_1 = base64 .b64encode (response .content ).decode ("utf-8" )
8
# Raw image bytes keyed by URL, so that each test image is downloaded at most
# once per test session instead of once per parametrized case.
_IMG_CACHE = {}


def _fetch_img_base64(url: str) -> str:
    """Return the content served at *url* as base64-encoded text.

    The raw bytes are cached in ``_IMG_CACHE`` after the first successful
    download. Raises a ``requests`` exception on a bad HTTP status or when the
    server does not answer within the timeout.
    """
    if url not in _IMG_CACHE:
        # An explicit timeout keeps a stalled connection from hanging the
        # whole test run; requests waits indefinitely by default.
        response = requests.get(url, timeout=60)
        response.raise_for_status()  # Raise an exception for bad status codes
        _IMG_CACHE[url] = response.content
    return base64.b64encode(_IMG_CACHE[url]).decode("utf-8")


def get_img_url(id: str) -> str:
    """Resolve a symbolic test-image identifier to the value sent to the server.

    Resolution is lazy: nothing is downloaded until a base64 variant is
    actually requested, so importing this module needs no network access.

    Recognized identifiers:
      * ``IMG_URL_0`` / ``IMG_URL_1``: the remote URL of the image.
      * ``IMG_BASE64_0`` / ``IMG_BASE64_1``: the image content, base64-encoded.
      * ``IMG_BASE64_URI_0`` / ``IMG_BASE64_URI_1``: the same base64 content
        prefixed with ``data:image/png;base64,``.

    Any other string (for example a literal URL or deliberately malformed
    data) is returned unchanged.
    """
    img_url_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
    img_url_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"

    plain_urls = {"IMG_URL_0": img_url_0, "IMG_URL_1": img_url_1}
    base64_sources = {"IMG_BASE64_0": img_url_0, "IMG_BASE64_1": img_url_1}
    data_uri_sources = {"IMG_BASE64_URI_0": img_url_0, "IMG_BASE64_URI_1": img_url_1}

    if id in plain_urls:
        return plain_urls[id]
    if id in base64_sources:
        return _fetch_img_base64(base64_sources[id])
    if id in data_uri_sources:
        return "data:image/png;base64," + _fetch_img_base64(data_uri_sources[id])
    # Not a symbolic name: the caller supplied the payload itself.
    return id
20
33
21
34
# Key of the JSON prompt object whose value is the list of image payloads sent along with the prompt.
JSON_MULTIMODAL_KEY = "multimodal_data"
22
35
# Key of the JSON prompt object whose value is the prompt text.
JSON_PROMPT_STRING_KEY = "prompt_string"
@@ -28,7 +41,7 @@ def create_server():
28
41
29
42
def test_models_supports_multimodal_capability ():
30
43
global server
31
- server .start () # vision model may take longer to load due to download size
44
+ server .start ()
32
45
res = server .make_request ("GET" , "/models" , data = {})
33
46
assert res .status_code == 200
34
47
model_info = res .body ["models" ][0 ]
@@ -38,7 +51,7 @@ def test_models_supports_multimodal_capability():
38
51
39
52
def test_v1_models_supports_multimodal_capability ():
40
53
global server
41
- server .start () # vision model may take longer to load due to download size
54
+ server .start ()
42
55
res = server .make_request ("GET" , "/v1/models" , data = {})
43
56
assert res .status_code == 200
44
57
model_info = res .body ["models" ][0 ]
@@ -50,10 +63,10 @@ def test_v1_models_supports_multimodal_capability():
50
63
"prompt, image_url, success, re_content" ,
51
64
[
52
65
# test model is trained on CIFAR-10, but it's quite dumb due to small size
53
- ("What is this:\n " , IMG_URL_0 , True , "(cat)+" ),
54
- ("What is this:\n " , "IMG_BASE64_URI_0" , True , "(cat)+" ), # exceptional, so that we don't cog up the log
55
- ("What is this:\n " , IMG_URL_1 , True , "(frog)+" ),
56
- ("Test test\n " , IMG_URL_1 , True , "(frog)+" ), # test invalidate cache
66
+ ("What is this:\n " , " IMG_URL_0" , True , "(cat)+" ),
67
+ ("What is this:\n " , "IMG_BASE64_URI_0" , True , "(cat)+" ),
68
+ ("What is this:\n " , " IMG_URL_1" , True , "(frog)+" ),
69
+ ("Test test\n " , " IMG_URL_1" , True , "(frog)+" ), # test invalidate cache
57
70
("What is this:\n " , "malformed" , False , None ),
58
71
("What is this:\n " , "https://google.com/404" , False , None ), # non-existent image
59
72
("What is this:\n " , "https://ggml.ai" , False , None ), # non-image data
@@ -62,17 +75,15 @@ def test_v1_models_supports_multimodal_capability():
62
75
)
63
76
def test_vision_chat_completion (prompt , image_url , success , re_content ):
64
77
global server
65
- server .start (timeout_seconds = 60 ) # vision model may take longer to load due to download size
66
- if image_url == "IMG_BASE64_URI_0" :
67
- image_url = IMG_BASE64_URI_0
78
+ server .start ()
68
79
res = server .make_request ("POST" , "/chat/completions" , data = {
69
80
"temperature" : 0.0 ,
70
81
"top_k" : 1 ,
71
82
"messages" : [
72
83
{"role" : "user" , "content" : [
73
84
{"type" : "text" , "text" : prompt },
74
85
{"type" : "image_url" , "image_url" : {
75
- "url" : image_url ,
86
+ "url" : get_img_url ( image_url ) ,
76
87
}},
77
88
]},
78
89
],
@@ -90,19 +101,22 @@ def test_vision_chat_completion(prompt, image_url, success, re_content):
90
101
"prompt, image_data, success, re_content" ,
91
102
[
92
103
# test model is trained on CIFAR-10, but it's quite dumb due to small size
93
- ("What is this: <__media__>\n " , IMG_BASE64_0 , True , "(cat)+" ),
94
- ("What is this: <__media__>\n " , IMG_BASE64_1 , True , "(frog)+" ),
104
+ ("What is this: <__media__>\n " , " IMG_BASE64_0" , True , "(cat)+" ),
105
+ ("What is this: <__media__>\n " , " IMG_BASE64_1" , True , "(frog)+" ),
95
106
("What is this: <__media__>\n " , "malformed" , False , None ), # non-image data
96
107
("What is this:\n " , "" , False , None ), # empty string
97
108
]
98
109
)
99
110
def test_vision_completion (prompt , image_data , success , re_content ):
100
111
global server
101
- server .start () # vision model may take longer to load due to download size
112
+ server .start ()
102
113
res = server .make_request ("POST" , "/completions" , data = {
103
114
"temperature" : 0.0 ,
104
115
"top_k" : 1 ,
105
- "prompt" : { JSON_PROMPT_STRING_KEY : prompt , JSON_MULTIMODAL_KEY : [ image_data ] },
116
+ "prompt" : {
117
+ JSON_PROMPT_STRING_KEY : prompt ,
118
+ JSON_MULTIMODAL_KEY : [ get_img_url (image_data ) ],
119
+ },
106
120
})
107
121
if success :
108
122
assert res .status_code == 200
@@ -116,17 +130,18 @@ def test_vision_completion(prompt, image_data, success, re_content):
116
130
"prompt, image_data, success" ,
117
131
[
118
132
# test model is trained on CIFAR-10, but it's quite dumb due to small size
119
- ("What is this: <__media__>\n " , IMG_BASE64_0 , True ), # exceptional, so that we don't cog up the log
120
- ("What is this: <__media__>\n " , IMG_BASE64_1 , True ),
133
+ ("What is this: <__media__>\n " , " IMG_BASE64_0" , True ),
134
+ ("What is this: <__media__>\n " , " IMG_BASE64_1" , True ),
121
135
("What is this: <__media__>\n " , "malformed" , False ), # non-image data
122
136
("What is this:\n " , "base64" , False ), # non-image data
123
137
]
124
138
)
125
139
def test_vision_embeddings (prompt , image_data , success ):
126
140
global server
127
- server .server_embeddings = True
128
- server .n_batch = 512
129
- server .start () # vision model may take longer to load due to download size
141
+ server .server_embeddings = True
142
+ server .n_batch = 512
143
+ server .start ()
144
+ image_data = get_img_url (image_data )
130
145
res = server .make_request ("POST" , "/embeddings" , data = {
131
146
"content" : [
132
147
{ JSON_PROMPT_STRING_KEY : prompt , JSON_MULTIMODAL_KEY : [ image_data ] },
0 commit comments