Commit 3c3635d
server : speed up tests (#15836)
* server : speed up tests
* clean up
* restore timeout_seconds in some places
* flake8
* explicit offline
1 parent 61bdfd5 · commit 3c3635d

6 files changed, +90 −50 lines changed

scripts/tool_bench.py

Lines changed: 2 additions & 2 deletions

@@ -53,7 +53,7 @@
 sys.path.insert(0, Path(__file__).parent.parent.as_posix())
 if True:
     from tools.server.tests.utils import ServerProcess
-    from tools.server.tests.unit.test_tool_call import TIMEOUT_SERVER_START, do_test_calc_result, do_test_hello_world, do_test_weather
+    from tools.server.tests.unit.test_tool_call import do_test_calc_result, do_test_hello_world, do_test_weather


 @contextmanager
@@ -335,7 +335,7 @@ def elapsed():
     # server.debug = True

     with scoped_server(server):
-        server.start(timeout_seconds=TIMEOUT_SERVER_START)
+        server.start(timeout_seconds=15 * 60)
         for ignore_chat_grammar in [False]:
             run(
                 server,
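
The benchmark keeps its generous 15-minute start timeout, now written inline as `15 * 60`, because it loads real models that may need to be downloaded first; only the shared `TIMEOUT_SERVER_START` constant (removed from `test_tool_call.py` below) is gone. The `scoped_server` helper used above is not shown in this diff; a minimal sketch of the pattern, assuming the server object exposes a `stop()` method as `ServerProcess` does in these tests, would be:

from contextlib import contextmanager

@contextmanager
def scoped_server(server):
    # Hand the server to the caller and guarantee it is torn down
    # even if the benchmark body raises.
    try:
        yield server
    finally:
        server.stop()  # assumed cleanup method; the real helper may differ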

tools/server/tests/unit/test_basic.py

Lines changed: 6 additions & 0 deletions

@@ -5,6 +5,12 @@
 server = ServerPreset.tinyllama2()


+@pytest.fixture(scope="session", autouse=True)
+def do_something():
+    # this will be run once per test session, before any tests
+    ServerPreset.load_all()
+
+
 @pytest.fixture(autouse=True)
 def create_server():
     global server
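
The new session-scoped, autouse fixture runs exactly once, before any test in the session, so the cost of fetching and caching every model preset is paid up front rather than inside the first test that needs each model; this is presumably what the "explicit offline" bullet in the commit message refers to, since later `server.start()` calls no longer need the network. A self-contained sketch of how session-scoped and function-scoped autouse fixtures interleave, using hypothetical names rather than the repo's `ServerPreset`:

import pytest

calls = []

@pytest.fixture(scope="session", autouse=True)
def preload_once():
    # Runs a single time, before the first test in the session.
    calls.append("preload")

@pytest.fixture(autouse=True)
def per_test():
    # Runs before every test, after the session-scoped fixture.
    calls.append("per-test")

def test_first():
    assert calls == ["preload", "per-test"]

def test_second():
    assert calls == ["preload", "per-test", "per-test"]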

tools/server/tests/unit/test_template.py

Lines changed: 3 additions & 6 deletions

@@ -14,14 +14,11 @@

 server: ServerProcess

-TIMEOUT_SERVER_START = 15*60
-
 @pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
     server.model_alias = "tinyllama-2"
-    server.server_port = 8081
     server.n_slots = 1


@@ -45,7 +42,7 @@ def test_reasoning_budget(template_name: str, reasoning_budget: int | None, expe
     server.jinja = True
     server.reasoning_budget = reasoning_budget
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()

     res = server.make_request("POST", "/apply-template", data={
         "messages": [
@@ -68,7 +65,7 @@ def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]):
     global server
     server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()

     res = server.make_request("POST", "/apply-template", data={
         "messages": [
@@ -91,7 +88,7 @@ def test_add_generation_prompt(template_name: str, expected_generation_prompt: s
     global server
     server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()

     res = server.make_request("POST", "/apply-template", data={
         "messages": [

tools/server/tests/unit/test_tool_call.py

Lines changed: 10 additions & 10 deletions

@@ -12,7 +12,7 @@

 server: ServerProcess

-TIMEOUT_SERVER_START = 15*60
+TIMEOUT_START_SLOW = 15 * 60  # this is needed for real model tests
 TIMEOUT_HTTP_REQUEST = 60

 @pytest.fixture(autouse=True)
@@ -124,7 +124,7 @@ def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict,
     server.jinja = True
     server.n_predict = n_predict
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
     do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED, temperature=0.0, top_k=1, top_p=1.0)


@@ -168,7 +168,7 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
     server.jinja = True
     server.n_predict = n_predict
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
     do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED)


@@ -240,7 +240,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
     body = server.make_any_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
@@ -295,7 +295,7 @@ def test_completion_without_tool_call_fast(template_name: str, n_predict: int, t
     server.n_predict = n_predict
     server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
     do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)


@@ -317,7 +317,7 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
     server.n_predict = n_predict
     server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
     do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)


@@ -377,7 +377,7 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] |
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
     do_test_weather(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)


@@ -436,7 +436,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
     do_test_calc_result(server, result_override, n_predict, stream=stream == CompletionMode.STREAMED)


@@ -524,7 +524,7 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
     body = server.make_any_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
@@ -597,7 +597,7 @@ def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | Non
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)

     do_test_hello_world(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
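
The rename makes the intent explicit: the default `server.start()` timeout is enough for the tiny test models, while `TIMEOUT_START_SLOW` is reserved for the `_slow` and real-model tests that may download gigabytes before the server reports ready. A minimal sketch of the underlying readiness-polling pattern, with a hypothetical default value (the real one lives in `tools/server/tests/utils.py`):

import time

DEFAULT_START_TIMEOUT = 30    # hypothetical default for tiny local models
TIMEOUT_START_SLOW = 15 * 60  # generous limit for tests that fetch real models

def wait_until_ready(is_ready, timeout_seconds: int = DEFAULT_START_TIMEOUT) -> None:
    # Poll a readiness probe until it succeeds or the deadline passes.
    deadline = time.monotonic() + timeout_seconds
    while not is_ready():
        if time.monotonic() > deadline:
            raise TimeoutError(f"server not ready after {timeout_seconds}s")
        time.sleep(0.5)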

tools/server/tests/unit/test_vision_api.py

Lines changed: 46 additions & 31 deletions

@@ -5,18 +5,31 @@

 server: ServerProcess

-IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
-IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
-
-response = requests.get(IMG_URL_0)
-response.raise_for_status() # Raise an exception for bad status codes
-IMG_BASE64_URI_0 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
-IMG_BASE64_0 = base64.b64encode(response.content).decode("utf-8")
-
-response = requests.get(IMG_URL_1)
-response.raise_for_status() # Raise an exception for bad status codes
-IMG_BASE64_URI_1 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
-IMG_BASE64_1 = base64.b64encode(response.content).decode("utf-8")
+def get_img_url(id: str) -> str:
+    IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
+    IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
+    if id == "IMG_URL_0":
+        return IMG_URL_0
+    elif id == "IMG_URL_1":
+        return IMG_URL_1
+    elif id == "IMG_BASE64_URI_0":
+        response = requests.get(IMG_URL_0)
+        response.raise_for_status() # Raise an exception for bad status codes
+        return "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
+    elif id == "IMG_BASE64_0":
+        response = requests.get(IMG_URL_0)
+        response.raise_for_status() # Raise an exception for bad status codes
+        return base64.b64encode(response.content).decode("utf-8")
+    elif id == "IMG_BASE64_URI_1":
+        response = requests.get(IMG_URL_1)
+        response.raise_for_status() # Raise an exception for bad status codes
+        return "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
+    elif id == "IMG_BASE64_1":
+        response = requests.get(IMG_URL_1)
+        response.raise_for_status() # Raise an exception for bad status codes
+        return base64.b64encode(response.content).decode("utf-8")
+    else:
+        return id

 JSON_MULTIMODAL_KEY = "multimodal_data"
 JSON_PROMPT_STRING_KEY = "prompt_string"
@@ -28,7 +41,7 @@ def create_server():

 def test_models_supports_multimodal_capability():
     global server
-    server.start() # vision model may take longer to load due to download size
+    server.start()
     res = server.make_request("GET", "/models", data={})
     assert res.status_code == 200
     model_info = res.body["models"][0]
@@ -38,7 +51,7 @@ def test_models_supports_multimodal_capability():

 def test_v1_models_supports_multimodal_capability():
     global server
-    server.start() # vision model may take longer to load due to download size
+    server.start()
     res = server.make_request("GET", "/v1/models", data={})
     assert res.status_code == 200
     model_info = res.body["models"][0]
@@ -50,10 +63,10 @@ def test_v1_models_supports_multimodal_capability():
     "prompt, image_url, success, re_content",
     [
         # test model is trained on CIFAR-10, but it's quite dumb due to small size
-        ("What is this:\n", IMG_URL_0, True, "(cat)+"),
-        ("What is this:\n", "IMG_BASE64_URI_0", True, "(cat)+"), # exceptional, so that we don't cog up the log
-        ("What is this:\n", IMG_URL_1, True, "(frog)+"),
-        ("Test test\n", IMG_URL_1, True, "(frog)+"), # test invalidate cache
+        ("What is this:\n", "IMG_URL_0", True, "(cat)+"),
+        ("What is this:\n", "IMG_BASE64_URI_0", True, "(cat)+"),
+        ("What is this:\n", "IMG_URL_1", True, "(frog)+"),
+        ("Test test\n", "IMG_URL_1", True, "(frog)+"), # test invalidate cache
         ("What is this:\n", "malformed", False, None),
         ("What is this:\n", "https://google.com/404", False, None), # non-existent image
         ("What is this:\n", "https://ggml.ai", False, None), # non-image data
@@ -62,17 +75,15 @@ def test_v1_models_supports_multimodal_capability():
 )
 def test_vision_chat_completion(prompt, image_url, success, re_content):
     global server
-    server.start(timeout_seconds=60) # vision model may take longer to load due to download size
-    if image_url == "IMG_BASE64_URI_0":
-        image_url = IMG_BASE64_URI_0
+    server.start()
     res = server.make_request("POST", "/chat/completions", data={
         "temperature": 0.0,
         "top_k": 1,
         "messages": [
             {"role": "user", "content": [
                 {"type": "text", "text": prompt},
                 {"type": "image_url", "image_url": {
-                    "url": image_url,
+                    "url": get_img_url(image_url),
                 }},
             ]},
         ],
@@ -90,19 +101,22 @@ def test_vision_chat_completion(prompt, image_url, success, re_content):
     "prompt, image_data, success, re_content",
     [
         # test model is trained on CIFAR-10, but it's quite dumb due to small size
-        ("What is this: <__media__>\n", IMG_BASE64_0, True, "(cat)+"),
-        ("What is this: <__media__>\n", IMG_BASE64_1, True, "(frog)+"),
+        ("What is this: <__media__>\n", "IMG_BASE64_0", True, "(cat)+"),
+        ("What is this: <__media__>\n", "IMG_BASE64_1", True, "(frog)+"),
         ("What is this: <__media__>\n", "malformed", False, None), # non-image data
         ("What is this:\n", "", False, None), # empty string
     ]
 )
 def test_vision_completion(prompt, image_data, success, re_content):
     global server
-    server.start() # vision model may take longer to load due to download size
+    server.start()
     res = server.make_request("POST", "/completions", data={
         "temperature": 0.0,
         "top_k": 1,
-        "prompt": { JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },
+        "prompt": {
+            JSON_PROMPT_STRING_KEY: prompt,
+            JSON_MULTIMODAL_KEY: [ get_img_url(image_data) ],
+        },
     })
     if success:
         assert res.status_code == 200
@@ -116,17 +130,18 @@ def test_vision_completion(prompt, image_data, success, re_content):
     "prompt, image_data, success",
     [
         # test model is trained on CIFAR-10, but it's quite dumb due to small size
-        ("What is this: <__media__>\n", IMG_BASE64_0, True), # exceptional, so that we don't cog up the log
-        ("What is this: <__media__>\n", IMG_BASE64_1, True),
+        ("What is this: <__media__>\n", "IMG_BASE64_0", True),
+        ("What is this: <__media__>\n", "IMG_BASE64_1", True),
         ("What is this: <__media__>\n", "malformed", False), # non-image data
         ("What is this:\n", "base64", False), # non-image data
    ]
 )
 def test_vision_embeddings(prompt, image_data, success):
     global server
-    server.server_embeddings=True
-    server.n_batch=512
-    server.start() # vision model may take longer to load due to download size
+    server.server_embeddings = True
+    server.n_batch = 512
+    server.start()
+    image_data = get_img_url(image_data)
     res = server.make_request("POST", "/embeddings", data={
         "content": [
             { JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },
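
Replacing the module-level downloads with `get_img_url` means the images are fetched only when a test actually uses them, not at import time during collection of the whole suite; the test parameters are now plain string IDs, which also keeps base64 blobs out of the logs. One trade-off is that parametrized tests reusing the same ID re-download the image; a possible refinement (not part of this commit) would be to memoize the fetch:

import base64
import functools

import requests

@functools.lru_cache(maxsize=None)
def fetch_base64(url: str) -> str:
    # Download each URL at most once per process; repeated
    # parametrized tests then reuse the cached encoding.
    response = requests.get(url)
    response.raise_for_status()
    return base64.b64encode(response.content).decode("utf-8")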
