Skip to content

Commit 1cd40fa

Browse files
ruisearch42 authored and googlercolin committed
[P/D] Move FakeNixlWrapper to test dir (vllm-project#21328)
Signed-off-by: Rui Qiao <[email protected]>
1 parent 7c83c84 commit 1cd40fa

File tree

4 files changed

+140
-115
lines changed

4 files changed

+140
-115
lines changed

tests/v1/kv_connector/unit/test_nixl_connector.py

Lines changed: 136 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,15 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

4+
import contextlib
5+
import inspect
46
import os
57
import tempfile
68
import textwrap
79
import time
10+
import uuid
11+
from collections import defaultdict
12+
from typing import Optional
813
from unittest.mock import patch
914

1015
import pytest
@@ -16,30 +21,118 @@
1621
KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
1722
NixlConnectorWorker)
1823
from vllm.forward_context import ForwardContext
19-
from vllm.mocks.mock_nixl_connector import FakeNixlWrapper
2024
from vllm.sampling_params import SamplingParams
2125

2226
from .utils import create_request, create_scheduler, create_vllm_config
2327

2428

25-
def _make_stub_pkg() -> str:
26-
"""Return a directory that makes
27-
`from nixl._api import nixl_agent` resolve to our FakeNixlWrapper."""
28-
td = tempfile.mkdtemp()
29-
pkg_root = os.path.join(td, "nixl", "_api")
30-
os.makedirs(pkg_root, exist_ok=True)
29+
class FakeNixlWrapper:
30+
"""Mock implementation of NixlWrapper for testing.
3131
32-
stub = textwrap.dedent("""\
33-
# Forward the real FakeNixlWrapper that the driver already defined.
34-
print("In fake package")
35-
from vllm.mocks.mock_nixl_connector import FakeNixlWrapper as nixl_agent
36-
""")
37-
with open(os.path.join(pkg_root, "__init__.py"), "w") as f:
38-
f.write(stub)
32+
We don't inherit from nixl._api.nixl_agent because nixl may not be
33+
installed.
34+
35+
Note: The complete source of this class is also used in the
36+
`_make_fake_nixl_pkg` function to create a fake nixl package
37+
for Ray workers.
38+
"""
39+
40+
AGENT_METADATA = b"fake_agent_metadata"
41+
REMOTE_AGENT_NAME = "remote_agent"
42+
43+
def __init__(self, agent_name: str, *args, **kwargs):
44+
self._cycles_before_xfer_done = 0
45+
self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict(
46+
lambda: 0)
47+
48+
def get_reg_descs(self, caches_data, memory_type: str) -> list:
49+
return [str(uuid.uuid4()) for _ in caches_data]
50+
51+
def register_memory(self, descs) -> None:
52+
pass
53+
54+
def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
55+
return [str(uuid.uuid4()) for _ in blocks_data]
56+
57+
def prep_xfer_dlist(self, agent_name: str, descs: list) -> int:
58+
return uuid.uuid4().int
59+
60+
def get_agent_metadata(self) -> bytes:
61+
return self.AGENT_METADATA
62+
63+
def add_remote_agent(self, agent_metadata: bytes) -> str:
64+
return self.REMOTE_AGENT_NAME
65+
66+
def get_new_notifs(self) -> dict[str, list[bytes]]:
67+
# Used to collect done_sending, which we don't test yet.
68+
return {}
69+
70+
def check_xfer_state(self, handle: int) -> str:
71+
if self._check_xfer_state_cycles[
72+
handle] >= self._cycles_before_xfer_done:
73+
return "DONE"
74+
self._check_xfer_state_cycles[handle] += 1
75+
return "PROC"
76+
77+
def release_xfer_handle(self, handle: int) -> None:
78+
pass
79+
80+
def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
81+
pass
82+
83+
def make_prepped_xfer(self,
84+
xfer_type: str,
85+
local_xfer_side_handle: int,
86+
local_block_descs_ids: list[int],
87+
remote_xfer_side_handle: int,
88+
remote_block_descs_ids: list[int],
89+
notif_msg: Optional[bytes] = None) -> int:
90+
return uuid.uuid4().int
3991

40-
# touch parent package
41-
open(os.path.join(td, "nixl", "__init__.py"), "w").close()
42-
return td
92+
def transfer(self, handle: int) -> str:
93+
return "PROC"
94+
95+
############################################################
96+
# The following are for changing the behavior during testing.
97+
############################################################
98+
99+
def set_cycles_before_xfer_done(self, cycles: int):
100+
"""Set the number of cycles before a transfer is considered done."""
101+
self._cycles_before_xfer_done = cycles
102+
103+
104+
@contextlib.contextmanager
105+
def _make_fake_nixl_pkg():
106+
"""Context manager that creates a temporary package making
107+
`from nixl._api import nixl_agent` resolve to our FakeNixlWrapper.
108+
109+
Automatically cleans up the temporary directory when done.
110+
"""
111+
with tempfile.TemporaryDirectory() as td:
112+
pkg_root = os.path.join(td, "nixl", "_api")
113+
os.makedirs(pkg_root, exist_ok=True)
114+
115+
# Get the source code of FakeNixlWrapper class and dedent it
116+
fake_nixl_source = inspect.getsource(FakeNixlWrapper)
117+
fake_nixl_source = textwrap.dedent(fake_nixl_source)
118+
119+
stub = f"""\
120+
# Copy of FakeNixlWrapper implementation for Ray workers
121+
import uuid
122+
from collections import defaultdict
123+
from typing import Optional
124+
125+
{fake_nixl_source}
126+
127+
# Export as nixl_agent
128+
nixl_agent = FakeNixlWrapper
129+
"""
130+
with open(os.path.join(pkg_root, "__init__.py"), "w") as f:
131+
f.write(stub)
132+
133+
# touch parent package
134+
open(os.path.join(td, "nixl", "__init__.py"), "w").close()
135+
yield td
43136

44137

45138
def test_basic_interface():
@@ -351,27 +444,37 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
351444
kv_connector="NixlConnector",
352445
kv_role="kv_both",
353446
)
447+
llm_kwargs = {
448+
"model": model_name,
449+
"enforce_eager": True,
450+
"gpu_memory_utilization": 0.5,
451+
"kv_transfer_config": kv_transfer_config,
452+
"distributed_executor_backend": distributed_executor_backend,
453+
}
454+
354455
timeout = 6
355456
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
356457
monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", str(timeout))
357458

358-
# Build runtime_env only if were using Ray
459+
# Build runtime_env only if we're using Ray
359460
if distributed_executor_backend == "ray":
360-
runtime_env = {
361-
"working_dir": _make_stub_pkg(), # ship stub package
362-
"env_vars": {
363-
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout),
364-
},
365-
}
366-
ray.init(runtime_env=runtime_env)
367-
368-
llm = LLM(
369-
model=model_name,
370-
enforce_eager=True,
371-
gpu_memory_utilization=0.5,
372-
kv_transfer_config=kv_transfer_config,
373-
distributed_executor_backend=distributed_executor_backend,
374-
)
461+
with _make_fake_nixl_pkg() as working_dir:
462+
runtime_env = {
463+
"working_dir": working_dir, # ship fake nixl package
464+
"env_vars": {
465+
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout),
466+
},
467+
}
468+
ray.init(runtime_env=runtime_env)
469+
470+
_run_abort_timeout_test(llm_kwargs, timeout)
471+
else:
472+
_run_abort_timeout_test(llm_kwargs, timeout)
473+
474+
475+
def _run_abort_timeout_test(llm_kwargs: dict, timeout: int):
476+
"""Helper function to run the abort timeout test logic."""
477+
llm = LLM(**llm_kwargs)
375478
remote_prefill_opts = {
376479
"do_remote_decode": True,
377480
"do_remote_prefill": False,

vllm/distributed/kv_transfer/kv_connector/utils.py

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -120,8 +120,8 @@ class KVOutputAggregator:
120120
output corresponding to Rank 0 for scheduler."""
121121

122122
def __init__(self, world_size: int):
123-
# Complete transfer tracker. Used by to track finished requests
124-
# [req_id -> n_finished_workers]
123+
# Complete transfer tracker. Used to track finished requests
124+
# [req_id -> n_remaining_workers]
125125
self._recv_remaining_count = defaultdict[str, int](lambda: world_size)
126126
self._send_remaining_count = defaultdict[str, int](lambda: world_size)
127127

@@ -134,12 +134,10 @@ def update_finished_set(req_ids: Optional[set[str]],
134134
remaining_count_dict: dict[str, int],
135135
finished_set: set[str]) -> None:
136136
for req_id in req_ids or ():
137-
new_count = remaining_count_dict[req_id] - 1
138-
if new_count == 0:
137+
remaining_count_dict[req_id] -= 1
138+
if remaining_count_dict[req_id] == 0:
139139
finished_set.add(req_id)
140140
del remaining_count_dict[req_id]
141-
else:
142-
remaining_count_dict[req_id] = new_count
143141

144142
finished_sending = set[str]()
145143
finished_recving = set[str]()

vllm/mocks/__init__.py

Whitespace-only changes.

vllm/mocks/mock_nixl_connector.py

Lines changed: 0 additions & 76 deletions
This file was deleted.

0 commit comments

Comments
 (0)