|
def create_server():
    """Initialize the shared test server from the tinyllama2 preset."""
    global server
    server = ServerPreset.tinyllama2()
    # Pin the port so every test in this module hits a predictable endpoint.
    server.server_port = 8080
11 | 12 |
|
12 | 13 |
|
13 | 14 | @pytest.mark.parametrize(
|
@@ -351,3 +352,165 @@ def test_logprobs_stream():
|
351 | 352 | assert token.top_logprobs is not None
|
352 | 353 | assert len(token.top_logprobs) > 0
|
353 | 354 | assert aggregated_text == output_text
|
| 355 | + |
| 356 | + |
def test_progress_feature_enabled():
    """With return_progress on, the stream must interleave progress updates
    with content chunks and the final update must be at (or near) 100%."""
    global server
    server.start()

    # A long prompt forces prompt processing to span multiple batches.
    prompt = "This is a comprehensive test prompt designed to verify the progress functionality thoroughly. " * 100

    stream = server.make_stream_request("POST", "/chat/completions", data={
        "max_tokens": 10,
        "messages": [
            {"role": "user", "content": prompt},
        ],
        "stream": True,
        "return_progress": True,
    })

    progress_events = []
    content_events = []

    for chunk in stream:
        delta = chunk["choices"][0].get("delta") or {}
        # Progress may arrive at the root of the chunk or inside the delta.
        if "prompt_processing" in chunk:
            progress_events.append(chunk["prompt_processing"])
        elif "prompt_processing" in delta:
            progress_events.append(delta["prompt_processing"])
        elif delta.get("content"):
            content_events.append(chunk)

    # At least one progress update must have been streamed.
    assert progress_events, "No progress responses received"

    # The last progress update must report completion.
    final = progress_events[-1]
    assert final["progress"] >= 0.99, f"Progress did not reach 100% (last: {final['progress']*100:.1f}%)"

    # Generation content must still arrive after the progress phase.
    assert content_events, "No content responses received"
| 397 | + |
| 398 | + |
def test_progress_feature_disabled():
    """With return_progress off, no progress data may appear in the stream."""
    global server
    server.start()

    # Same long prompt as the enabled test, so batching behavior matches.
    prompt = "This is a comprehensive test prompt designed to verify the progress functionality thoroughly. " * 100

    stream = server.make_stream_request("POST", "/chat/completions", data={
        "max_tokens": 10,
        "messages": [
            {"role": "user", "content": prompt},
        ],
        "stream": True,
        "return_progress": False,  # Disable progress
    })

    progress_events = []
    content_events = []

    for chunk in stream:
        delta = chunk["choices"][0].get("delta") or {}
        # Progress may arrive at the root of the chunk or inside the delta.
        if "prompt_processing" in chunk:
            progress_events.append(chunk["prompt_processing"])
        elif "prompt_processing" in delta:
            progress_events.append(delta["prompt_processing"])
        elif delta.get("content"):
            content_events.append(chunk)

    # The flag is off, so not a single progress update may slip through.
    assert len(progress_events) == 0, f"Progress responses received when disabled: {len(progress_events)}"

    # Normal generation output must still be present.
    assert content_events, "No content responses received"
| 435 | + |
| 436 | + |
def test_progress_feature_completion_endpoint():
    """Progress reporting must also work on the raw /completion endpoint."""
    global server
    server.start()

    # A long prompt forces prompt processing to span multiple batches.
    prompt = "This is a comprehensive test prompt designed to verify the progress functionality thoroughly. " * 100

    stream = server.make_stream_request("POST", "/completion", data={
        "prompt": prompt,
        "stream": True,
        "return_progress": True,
        "max_tokens": 10,
    })

    progress_events = []
    content_events = []

    for chunk in stream:
        # /completion puts both progress and content at the top level.
        if "prompt_processing" in chunk:
            progress_events.append(chunk["prompt_processing"])
        elif chunk.get("content"):
            content_events.append(chunk)

    # At least one progress update must have been streamed.
    assert progress_events, "No progress responses received from /completion endpoint"

    # The last progress update must report completion.
    final = progress_events[-1]
    assert final["progress"] >= 0.99, f"Progress did not reach 100% (last: {final['progress']*100:.1f}%)"

    # Generation content must still arrive after the progress phase.
    assert content_events, "No content responses received from /completion endpoint"
| 471 | + |
| 472 | + |
def test_progress_feature_with_different_batch_sizes():
    """Progress reporting must work regardless of how many batches the
    prompt spans, from a single short batch to many long ones.

    Fix: the original ended with `if progress_responses:` guarding the final
    progress check, but the assert a few lines earlier already guarantees the
    list is non-empty — the guard was dead code and has been removed so the
    100% check is unconditionally executed.
    """
    global server
    server.start()

    # Prompt lengths chosen to exercise one, a few, and many batches.
    test_cases = [
        ("Short prompt", "Short test prompt"),
        ("Medium prompt", "This is a medium length test prompt designed to test progress functionality. " * 20),
        ("Long prompt", "This is a comprehensive test prompt designed to verify the progress functionality thoroughly. " * 100),
    ]

    for test_name, prompt in test_cases:
        res = server.make_stream_request("POST", "/chat/completions", data={
            "max_tokens": 5,
            "messages": [
                {"role": "user", "content": prompt},
            ],
            "stream": True,
            "return_progress": True,
        })

        progress_responses = []
        content_responses = []

        for data in res:
            choice = data["choices"][0]
            # Progress can arrive at the root level or inside the delta.
            if "prompt_processing" in data:
                progress_responses.append(data["prompt_processing"])
            elif "delta" in choice and "prompt_processing" in choice["delta"]:
                progress_responses.append(choice["delta"]["prompt_processing"])
            elif "delta" in choice and "content" in choice["delta"] and choice["delta"]["content"]:
                content_responses.append(data)

        # Every prompt length must produce both progress and content.
        assert len(progress_responses) > 0, f"No progress responses for {test_name}"
        assert len(content_responses) > 0, f"No content responses for {test_name}"

        # Non-empty is guaranteed by the assert above, so check the final
        # update directly: it must report (near) 100% completion.
        last_progress = progress_responses[-1]
        assert last_progress["progress"] >= 0.99, f"Progress did not reach 100% for {test_name} (last: {last_progress['progress']*100:.1f}%)"
0 commit comments