|
1 | 1 | # SPDX-License-Identifier: Apache-2.0
|
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
3 | 3 |
|
| 4 | +import contextlib |
| 5 | +import inspect |
4 | 6 | import os
|
5 | 7 | import tempfile
|
6 | 8 | import textwrap
|
7 | 9 | import time
|
| 10 | +import uuid |
| 11 | +from collections import defaultdict |
| 12 | +from typing import Optional |
8 | 13 | from unittest.mock import patch
|
9 | 14 |
|
10 | 15 | import pytest
|
|
16 | 21 | KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
|
17 | 22 | NixlConnectorWorker)
|
18 | 23 | from vllm.forward_context import ForwardContext
|
19 |
| -from vllm.mocks.mock_nixl_connector import FakeNixlWrapper |
20 | 24 | from vllm.sampling_params import SamplingParams
|
21 | 25 |
|
22 | 26 | from .utils import create_request, create_scheduler, create_vllm_config
|
23 | 27 |
|
24 | 28 |
|
25 |
| -def _make_stub_pkg() -> str: |
26 |
| - """Return a directory that makes |
27 |
| - `from nixl._api import nixl_agent` resolve to our FakeNixlWrapper.""" |
28 |
| - td = tempfile.mkdtemp() |
29 |
| - pkg_root = os.path.join(td, "nixl", "_api") |
30 |
| - os.makedirs(pkg_root, exist_ok=True) |
| 29 | +class FakeNixlWrapper: |
| 30 | + """Mock implementation of NixlWrapper for testing. |
31 | 31 |
|
32 |
| - stub = textwrap.dedent("""\ |
33 |
| - # Forward the real FakeNixlWrapper that the driver already defined. |
34 |
| - print("In fake package") |
35 |
| - from vllm.mocks.mock_nixl_connector import FakeNixlWrapper as nixl_agent |
36 |
| - """) |
37 |
| - with open(os.path.join(pkg_root, "__init__.py"), "w") as f: |
38 |
| - f.write(stub) |
| 32 | + We don't inherit from nixl._api.nixl_agent because nixl may not be |
| 33 | + installed. |
| 34 | + |
| 35 | + Note: The complete source of this class is also used in the |
| 36 | + `_make_fake_nixl_pkg` function to create a fake nixl package |
| 37 | + for Ray workers. |
| 38 | + """ |
| 39 | + |
| 40 | + AGENT_METADATA = b"fake_agent_metadata" |
| 41 | + REMOTE_AGENT_NAME = "remote_agent" |
| 42 | + |
| 43 | + def __init__(self, agent_name: str, *args, **kwargs): |
| 44 | + self._cycles_before_xfer_done = 0 |
| 45 | + self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict( |
| 46 | + lambda: 0) |
| 47 | + |
| 48 | + def get_reg_descs(self, caches_data, memory_type: str) -> list: |
| 49 | + return [str(uuid.uuid4()) for _ in caches_data] |
| 50 | + |
| 51 | + def register_memory(self, descs) -> None: |
| 52 | + pass |
| 53 | + |
| 54 | + def get_xfer_descs(self, blocks_data, memory_type: str) -> list: |
| 55 | + return [str(uuid.uuid4()) for _ in blocks_data] |
| 56 | + |
| 57 | + def prep_xfer_dlist(self, agent_name: str, descs: list) -> int: |
| 58 | + return uuid.uuid4().int |
| 59 | + |
| 60 | + def get_agent_metadata(self) -> bytes: |
| 61 | + return self.AGENT_METADATA |
| 62 | + |
| 63 | + def add_remote_agent(self, agent_metadata: bytes) -> str: |
| 64 | + return self.REMOTE_AGENT_NAME |
| 65 | + |
| 66 | + def get_new_notifs(self) -> dict[str, list[bytes]]: |
| 67 | + # Used to collect done_sending, which we don't test yet. |
| 68 | + return {} |
| 69 | + |
| 70 | + def check_xfer_state(self, handle: int) -> str: |
| 71 | + if self._check_xfer_state_cycles[ |
| 72 | + handle] >= self._cycles_before_xfer_done: |
| 73 | + return "DONE" |
| 74 | + self._check_xfer_state_cycles[handle] += 1 |
| 75 | + return "PROC" |
| 76 | + |
| 77 | + def release_xfer_handle(self, handle: int) -> None: |
| 78 | + pass |
| 79 | + |
| 80 | + def send_notif(self, agent_name: str, notif_msg: bytes) -> None: |
| 81 | + pass |
| 82 | + |
| 83 | + def make_prepped_xfer(self, |
| 84 | + xfer_type: str, |
| 85 | + local_xfer_side_handle: int, |
| 86 | + local_block_descs_ids: list[int], |
| 87 | + remote_xfer_side_handle: int, |
| 88 | + remote_block_descs_ids: list[int], |
| 89 | + notif_msg: Optional[bytes] = None) -> int: |
| 90 | + return uuid.uuid4().int |
39 | 91 |
|
40 |
| - # touch parent package |
41 |
| - open(os.path.join(td, "nixl", "__init__.py"), "w").close() |
42 |
| - return td |
| 92 | + def transfer(self, handle: int) -> str: |
| 93 | + return "PROC" |
| 94 | + |
| 95 | + ############################################################ |
| 96 | + # Follow are for changing the behavior during testing. |
| 97 | + ############################################################ |
| 98 | + |
| 99 | + def set_cycles_before_xfer_done(self, cycles: int): |
| 100 | + """Set the number of cycles before a transfer is considered done.""" |
| 101 | + self._cycles_before_xfer_done = cycles |
| 102 | + |
| 103 | + |
| 104 | +@contextlib.contextmanager |
| 105 | +def _make_fake_nixl_pkg(): |
| 106 | + """Context manager that creates a temporary package making |
| 107 | + `from nixl._api import nixl_agent` resolve to our FakeNixlWrapper. |
| 108 | + |
| 109 | + Automatically cleans up the temporary directory when done. |
| 110 | + """ |
| 111 | + with tempfile.TemporaryDirectory() as td: |
| 112 | + pkg_root = os.path.join(td, "nixl", "_api") |
| 113 | + os.makedirs(pkg_root, exist_ok=True) |
| 114 | + |
| 115 | + # Get the source code of FakeNixlWrapper class and dedent it |
| 116 | + fake_nixl_source = inspect.getsource(FakeNixlWrapper) |
| 117 | + fake_nixl_source = textwrap.dedent(fake_nixl_source) |
| 118 | + |
| 119 | + stub = f"""\ |
| 120 | +# Copy of FakeNixlWrapper implementation for Ray workers |
| 121 | +import uuid |
| 122 | +from collections import defaultdict |
| 123 | +from typing import Optional |
| 124 | +
|
| 125 | +{fake_nixl_source} |
| 126 | +
|
| 127 | +# Export as nixl_agent |
| 128 | +nixl_agent = FakeNixlWrapper |
| 129 | +""" |
| 130 | + with open(os.path.join(pkg_root, "__init__.py"), "w") as f: |
| 131 | + f.write(stub) |
| 132 | + |
| 133 | + # touch parent package |
| 134 | + open(os.path.join(td, "nixl", "__init__.py"), "w").close() |
| 135 | + yield td |
43 | 136 |
|
44 | 137 |
|
45 | 138 | def test_basic_interface():
|
@@ -351,27 +444,37 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
|
351 | 444 | kv_connector="NixlConnector",
|
352 | 445 | kv_role="kv_both",
|
353 | 446 | )
|
| 447 | + llm_kwargs = { |
| 448 | + "model": model_name, |
| 449 | + "enforce_eager": True, |
| 450 | + "gpu_memory_utilization": 0.5, |
| 451 | + "kv_transfer_config": kv_transfer_config, |
| 452 | + "distributed_executor_backend": distributed_executor_backend, |
| 453 | + } |
| 454 | + |
354 | 455 | timeout = 6
|
355 | 456 | monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
356 | 457 | monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", str(timeout))
|
357 | 458 |
|
358 |
| - # Build runtime_env only if we’re using Ray |
| 459 | + # Build runtime_env only if we're using Ray |
359 | 460 | if distributed_executor_backend == "ray":
|
360 |
| - runtime_env = { |
361 |
| - "working_dir": _make_stub_pkg(), # ship stub package |
362 |
| - "env_vars": { |
363 |
| - "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout), |
364 |
| - }, |
365 |
| - } |
366 |
| - ray.init(runtime_env=runtime_env) |
367 |
| - |
368 |
| - llm = LLM( |
369 |
| - model=model_name, |
370 |
| - enforce_eager=True, |
371 |
| - gpu_memory_utilization=0.5, |
372 |
| - kv_transfer_config=kv_transfer_config, |
373 |
| - distributed_executor_backend=distributed_executor_backend, |
374 |
| - ) |
| 461 | + with _make_fake_nixl_pkg() as working_dir: |
| 462 | + runtime_env = { |
| 463 | + "working_dir": working_dir, # ship fake nixl package |
| 464 | + "env_vars": { |
| 465 | + "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout), |
| 466 | + }, |
| 467 | + } |
| 468 | + ray.init(runtime_env=runtime_env) |
| 469 | + |
| 470 | + _run_abort_timeout_test(llm_kwargs, timeout) |
| 471 | + else: |
| 472 | + _run_abort_timeout_test(llm_kwargs, timeout) |
| 473 | + |
| 474 | + |
| 475 | +def _run_abort_timeout_test(llm_kwargs: dict, timeout: int): |
| 476 | + """Helper function to run the abort timeout test logic.""" |
| 477 | + llm = LLM(**llm_kwargs) |
375 | 478 | remote_prefill_opts = {
|
376 | 479 | "do_remote_decode": True,
|
377 | 480 | "do_remote_prefill": False,
|
|
0 commit comments