From 3a70eda67f18d42e2f6c338f508f74660dafe585 Mon Sep 17 00:00:00 2001 From: mgoin Date: Thu, 1 May 2025 16:24:38 -0700 Subject: [PATCH 01/28] [V1] Support multiple kv connectors Signed-off-by: mgoin Signed-off-by: Nick Hill --- .../kv_transfer/kv_connector/factory.py | 5 + .../kv_transfer/kv_connector/v1/base.py | 7 +- .../kv_connector/v1/multi_connector.py | 95 +++++++++++++++++++ 3 files changed, 105 insertions(+), 2 deletions(-) create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 6532c101a4f6..715dc81662c7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -105,3 +105,8 @@ def create_connector_v1( "LMCacheConnectorV1", "vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector", "LMCacheConnectorV1") + +KVConnectorFactory.register_connector( + "MultiConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.multi_connector", + "MultiConnector") diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 95967d2ca919..7bc91faf4e4c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -22,7 +22,6 @@ import enum from abc import ABC, abstractmethod -from dataclasses import dataclass from typing import TYPE_CHECKING import torch @@ -47,11 +46,15 @@ class KVConnectorRole(enum.Enum): WORKER = 1 -@dataclass class KVConnectorMetadata: pass +class MultiKVConnectorMetadata(tuple[KVConnectorMetadata, ...], + KVConnectorMetadata): + pass + + class KVConnectorBase_V1(ABC): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py new file mode 100644 index 000000000000..813843e66a34 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -0,0 +1,95 @@ +# SPDX-License-Identifier: Apache-2.0 +import copy +from typing import TYPE_CHECKING + +import torch + +from vllm.config import KVTransferConfig, VllmConfig +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, + MultiKVConnectorMetadata) +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +class MultiConnector(KVConnectorBase_V1): + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + self._connectors = [] + ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get( + "connectors") + assert ktcs is not None + for ktc in ktcs: + ktc = KVTransferConfig(**ktc) + temp_config = copy.copy(vllm_config) + temp_config.kv_transfer_config = ktc + self._connectors.append( + KVConnectorFactory.create_connector_v1(temp_config, role)) + + # We are overriding the base class method here because we need to bind + # the metadata to each connector in the order of the connectors in the + # MultiKVConnectorMetadata. 
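+    # build_connector_meta() below constructs that tuple by iterating over
+    # self._connectors, so the zip() pairs each connector with the metadata
+    # it produced.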
+ def bind_connector_metadata( + self, connector_metadata: KVConnectorMetadata) -> None: + for c, cm in zip(self._connectors, connector_metadata): + c.bind_connector_metadata(cm) + + def clear_connector_metadata(self) -> None: + for c in self._connectors: + c.clear_connector_metadata() + + # ============================== + # Worker-side methods + # ============================== + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + for c in self._connectors: + c.start_load_kv(forward_context, **kwargs) + + def wait_for_layer_load(self, layer_name: str) -> None: + for c in self._connectors: + c.wait_for_layer_load(layer_name) + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + for c in self._connectors: + c.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs) + + def wait_for_save(self): + for c in self._connectors: + c.wait_for_save() + + # ============================== + # Scheduler-side methods + # ============================== + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> int: + return max( + c.get_num_new_matched_tokens(request, num_computed_tokens) + for c in self._connectors) + + def update_state_after_alloc(self, request: "Request", + num_external_tokens: int): + for c in self._connectors: + c.update_state_after_alloc(request, num_external_tokens) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata: + return MultiKVConnectorMetadata( + tuple( + c.build_connector_meta(scheduler_output) + for c in self._connectors)) From 89a57883f5ff65cafb06f50fa981df8b24e2cdd6 Mon Sep 17 00:00:00 2001 From: mgoin Date: Thu, 1 May 2025 23:29:03 +0000 Subject: [PATCH 02/28] Example script Signed-off-by: mgoin --- dual_storage.py | 66 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 dual_storage.py diff --git a/dual_storage.py b/dual_storage.py new file mode 100644 index 000000000000..763346efdde7 --- /dev/null +++ b/dual_storage.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + +context = "Hi " * 1000 +context2 = "Hey " * 500 +prompts = [ + context + "Hello, my name is", + context + "The capital of France is", + context2 + "Your name is", + context2 + "The capital of China is", +] + +sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) + +kv_transfer_config = KVTransferConfig(kv_connector="MultiConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "connectors": [ + { + "kv_connector": + "SharedStorageConnector", + "kv_role": "kv_both", + "kv_connector_extra_config": + { + "shared_storage_path": + "local_storage" + } + }, + { + "kv_connector": + "SharedStorageConnector", + "kv_role": "kv_both", + "kv_connector_extra_config": + { + "shared_storage_path": + "external_storage" + } + }, + ] + }) + +llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", + enforce_eager=True, + gpu_memory_utilization=0.8, + kv_transfer_config=kv_transfer_config) + +# 1ST generation (prefill instance) +outputs = llm.generate( + prompts, + sampling_params, +) + +new_prompts = [] +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + new_prompts.append(prompt + generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + +# Write new_prompts to output.txt +with open("output.txt", 
"w") as f: + for prompt in new_prompts: + f.write(prompt + "\n") +print(f"Saved {len(new_prompts)} prompts to output.txt") From 4fa62d42e3274128b4649230b28361d2e1d847bb Mon Sep 17 00:00:00 2001 From: mgoin Date: Fri, 2 May 2025 00:32:14 +0000 Subject: [PATCH 03/28] . Signed-off-by: mgoin --- dual_storage.py | 48 ++++++++++++++++++++++-------------------------- 1 file changed, 22 insertions(+), 26 deletions(-) diff --git a/dual_storage.py b/dual_storage.py index 763346efdde7..ea5e244dfdce 100644 --- a/dual_storage.py +++ b/dual_storage.py @@ -14,32 +14,28 @@ sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) -kv_transfer_config = KVTransferConfig(kv_connector="MultiConnector", - kv_role="kv_both", - kv_connector_extra_config={ - "connectors": [ - { - "kv_connector": - "SharedStorageConnector", - "kv_role": "kv_both", - "kv_connector_extra_config": - { - "shared_storage_path": - "local_storage" - } - }, - { - "kv_connector": - "SharedStorageConnector", - "kv_role": "kv_both", - "kv_connector_extra_config": - { - "shared_storage_path": - "external_storage" - } - }, - ] - }) +kv_transfer_config = KVTransferConfig( + kv_connector="MultiConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "connectors": [ + { + "kv_connector": "SharedStorageConnector", + "kv_role": "kv_both", + "kv_connector_extra_config": { + "shared_storage_path": "local_storage" + } + }, + { + "kv_connector": "SharedStorageConnector", + "kv_role": "kv_both", + "kv_connector_extra_config": { + "shared_storage_path": "external_storage" + } + }, + ] + }, +) llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", enforce_eager=True, From 15ca542a47995154d68918d10d6161646c592e56 Mon Sep 17 00:00:00 2001 From: mgoin Date: Fri, 2 May 2025 02:13:41 +0000 Subject: [PATCH 04/28] Add test Signed-off-by: mgoin --- .buildkite/test-pipeline.yaml | 1 + tests/v1/kv_transfer/test_multi_connector.py | 111 +++++++++++++++++++ 2 files changed, 112 insertions(+) create mode 100644 tests/v1/kv_transfer/test_multi_connector.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 84ee991f5659..ddadb9477623 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -209,6 +209,7 @@ steps: - pytest -v -s v1/worker - pytest -v -s v1/structured_output - pytest -v -s v1/spec_decode + - pytest -v -s v1/kv_transfer - pytest -v -s v1/test_serial_utils.py - pytest -v -s v1/test_stats.py - pytest -v -s v1/test_utils.py diff --git a/tests/v1/kv_transfer/test_multi_connector.py b/tests/v1/kv_transfer/test_multi_connector.py new file mode 100644 index 000000000000..c83dd6f66633 --- /dev/null +++ b/tests/v1/kv_transfer/test_multi_connector.py @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: Apache-2.0 +import filecmp +import shutil +from pathlib import Path + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + +MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" + +PROMPT_CONTEXT = "Hi " * 100 +PROMPTS = [ + PROMPT_CONTEXT + "Hello, my name is", + PROMPT_CONTEXT + "The capital of France is", +] + +SAMPLING_PARAMS = SamplingParams(temperature=0, max_tokens=20) + + +# Helper function to compare directories recursively +def _compare_directories(dir1: Path, dir2: Path) -> bool: + """Compares two directories recursively for identical content.""" + dcmp = filecmp.dircmp(dir1, dir2) + if dcmp.left_only or dcmp.right_only or dcmp.diff_files: + print(f"Differences found between {dir1} and {dir2}:") + print(f" Left only: {dcmp.left_only}") + print(f" Right only: 
{dcmp.right_only}") + print(f" Different files: {dcmp.diff_files}") + return False + for sub_dir in dcmp.common_dirs: + if not _compare_directories(dir1 / sub_dir, dir2 / sub_dir): + return False + return True + + +def test_multi_shared_storage_connector_consistency(): + """ + Tests that MultiConnector with two SharedStorageConnectors saves + identical KV cache data to separate storage locations. + """ + storage_1_path = Path("storage_1/") + storage_2_path = Path("storage_2/") + shutil.rmtree(storage_1_path, ignore_errors=True) + shutil.rmtree(storage_2_path, ignore_errors=True) + storage_1_path.mkdir() + storage_2_path.mkdir() + + # Configure MultiConnector with two SharedStorageConnectors + kv_transfer_config = KVTransferConfig( + kv_connector="MultiConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "connectors": [{ + "kv_connector": "SharedStorageConnector", + "kv_role": "kv_both", + "kv_connector_extra_config": { + "shared_storage_path": str(storage_1_path) + } + }, { + "kv_connector": "SharedStorageConnector", + "kv_role": "kv_both", + "kv_connector_extra_config": { + "shared_storage_path": str(storage_2_path) + } + }] + }, + ) + + llm = LLM( + model=MODEL_NAME, + enforce_eager=True, + gpu_memory_utilization=0.5, + kv_transfer_config=kv_transfer_config, + ) + # Run generation - this should trigger saving KV cache + _ = llm.generate(PROMPTS, SAMPLING_PARAMS) + + # --- Verification --- + + # Check that both storage directories were populated + local_subdirs = list(storage_1_path.iterdir()) + external_subdirs = list(storage_2_path.iterdir()) + + assert len( + local_subdirs + ) > 0, f"Local storage path {storage_1_path} is empty after generation." + assert len(external_subdirs) > 0, ( + f"External storage path {storage_2_path} is empty after generation.") + assert len(local_subdirs) == len(external_subdirs), ( + f"Mismatch in number of cache entries: " + f"Local={len(local_subdirs)}, External={len(external_subdirs)}") + + # The subdirectories should correspond to the prompt hashes + # Since prompts are the same, the hash directories should be the same name + local_subdir_names = sorted([d.name for d in local_subdirs]) + external_subdir_names = sorted([d.name for d in external_subdirs]) + assert local_subdir_names == external_subdir_names, ( + "Cache directory names do not match between local and external storage" + ) + + # Compare the contents of each corresponding cache directory + for subdir_name in local_subdir_names: + print(f"Comparing contents of cache directory: {subdir_name}") + assert _compare_directories(storage_1_path / subdir_name, + storage_2_path / subdir_name), \ + (f"Contents differ for cache directory '{subdir_name}' between " + f"{storage_1_path} and {storage_2_path}") + + # Clean up + shutil.rmtree(storage_1_path) + shutil.rmtree(storage_2_path) From cd5af126a6ae660156c9de0006d5b74a98e2c088 Mon Sep 17 00:00:00 2001 From: mgoin Date: Fri, 2 May 2025 02:30:43 +0000 Subject: [PATCH 05/28] make mypy happy Signed-off-by: mgoin --- vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 813843e66a34..e1d0f5c717e4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -41,6 +41,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): # MultiKVConnectorMetadata. 
def bind_connector_metadata( self, connector_metadata: KVConnectorMetadata) -> None: + assert isinstance(connector_metadata, MultiKVConnectorMetadata) for c, cm in zip(self._connectors, connector_metadata): c.bind_connector_metadata(cm) From 5a9a314dc11cb0710c418f8233f70240e9a62079 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 1 May 2025 23:00:24 -0700 Subject: [PATCH 06/28] move MultiKVConnectorMetadata to multi_connector.py Signed-off-by: Nick Hill --- vllm/distributed/kv_transfer/kv_connector/v1/base.py | 5 ----- .../kv_transfer/kv_connector/v1/multi_connector.py | 8 ++++++-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 7bc91faf4e4c..147d5339b850 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -50,11 +50,6 @@ class KVConnectorMetadata: pass -class MultiKVConnectorMetadata(tuple[KVConnectorMetadata, ...], - KVConnectorMetadata): - pass - - class KVConnectorBase_V1(ABC): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index e1d0f5c717e4..b44d032d59d4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -8,8 +8,7 @@ from vllm.distributed.kv_transfer.kv_connector.factory import ( KVConnectorFactory) from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, - MultiKVConnectorMetadata) + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) from vllm.logger import init_logger from vllm.v1.core.sched.output import SchedulerOutput @@ -21,6 +20,11 @@ logger = init_logger(__name__) +class MultiKVConnectorMetadata(tuple[KVConnectorMetadata, ...], + KVConnectorMetadata): + pass + + class MultiConnector(KVConnectorBase_V1): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): From 2f5e5378d473d5446e2ff83e002f820ae55d67c1 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 1 May 2025 23:18:06 -0700 Subject: [PATCH 07/28] minor simplifications Signed-off-by: Nick Hill --- .../kv_transfer/kv_connector/v1/multi_connector.py | 7 ++----- .../kv_connector/v1/shared_storage_connector.py | 3 +-- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index b44d032d59d4..2c002395f811 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -34,9 +34,8 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): "connectors") assert ktcs is not None for ktc in ktcs: - ktc = KVTransferConfig(**ktc) temp_config = copy.copy(vllm_config) - temp_config.kv_transfer_config = ktc + temp_config.kv_transfer_config = KVTransferConfig(**ktc) self._connectors.append( KVConnectorFactory.create_connector_v1(temp_config, role)) @@ -95,6 +94,4 @@ def build_connector_meta( self, scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata: return MultiKVConnectorMetadata( - tuple( - c.build_connector_meta(scheduler_output) - for c in self._connectors)) + c.build_connector_meta(scheduler_output) for c in self._connectors) diff --git 
a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index f91ffbc720e7..5200bb250285 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -132,8 +132,7 @@ def inject_kv_into_layer( dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) # Get the metadata - metadata: KVConnectorMetadata = \ - self._get_connector_metadata() + metadata: KVConnectorMetadata = self._get_connector_metadata() assert isinstance(metadata, SharedStorageConnectorMetadata) if metadata is None: From 014cb2c532065ca05e5f5438a21aaffdc188ba8a Mon Sep 17 00:00:00 2001 From: mgoin Date: Fri, 2 May 2025 14:12:34 +0000 Subject: [PATCH 08/28] Remove script Signed-off-by: mgoin --- dual_storage.py | 62 ------------------------------------------------- 1 file changed, 62 deletions(-) delete mode 100644 dual_storage.py diff --git a/dual_storage.py b/dual_storage.py deleted file mode 100644 index ea5e244dfdce..000000000000 --- a/dual_storage.py +++ /dev/null @@ -1,62 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from vllm import LLM, SamplingParams -from vllm.config import KVTransferConfig - -context = "Hi " * 1000 -context2 = "Hey " * 500 -prompts = [ - context + "Hello, my name is", - context + "The capital of France is", - context2 + "Your name is", - context2 + "The capital of China is", -] - -sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) - -kv_transfer_config = KVTransferConfig( - kv_connector="MultiConnector", - kv_role="kv_both", - kv_connector_extra_config={ - "connectors": [ - { - "kv_connector": "SharedStorageConnector", - "kv_role": "kv_both", - "kv_connector_extra_config": { - "shared_storage_path": "local_storage" - } - }, - { - "kv_connector": "SharedStorageConnector", - "kv_role": "kv_both", - "kv_connector_extra_config": { - "shared_storage_path": "external_storage" - } - }, - ] - }, -) - -llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", - enforce_eager=True, - gpu_memory_utilization=0.8, - kv_transfer_config=kv_transfer_config) - -# 1ST generation (prefill instance) -outputs = llm.generate( - prompts, - sampling_params, -) - -new_prompts = [] -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - new_prompts.append(prompt + generated_text) - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - -# Write new_prompts to output.txt -with open("output.txt", "w") as f: - for prompt in new_prompts: - f.write(prompt + "\n") -print(f"Saved {len(new_prompts)} prompts to output.txt") From 7370d83c851c6bc9615ef6356e9ca0797e9fe855 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 2 May 2025 11:28:24 -0700 Subject: [PATCH 09/28] michael inprogress Signed-off-by: Nick Hill --- .../kv_connector/v1/multi_connector.py | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 2c002395f811..b8ec36f800d3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -38,8 +38,11 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): temp_config.kv_transfer_config = KVTransferConfig(**ktc) self._connectors.append( KVConnectorFactory.create_connector_v1(temp_config, role)) + + 
# A mapping from request id to the connector that is assigned to it. + self._requests_to_connector: dict[str, KVConnectorBase_V1] = {} - # We are overriding the base class method here because we need to bind + # We must override the base class method here because we need to bind # the metadata to each connector in the order of the connectors in the # MultiKVConnectorMetadata. def bind_connector_metadata( @@ -81,14 +84,24 @@ def get_num_new_matched_tokens( request: "Request", num_computed_tokens: int, ) -> int: - return max( - c.get_num_new_matched_tokens(request, num_computed_tokens) - for c in self._connectors) + for c in self._connectors: + toks = c.get_num_new_matched_tokens(request, num_computed_tokens) + # The first connector that has new matched tokens will be assigned + # to this request. + if toks > 0: + self._requests_to_connector[request.req_id] = c + return toks + return 0 + def update_state_after_alloc(self, request: "Request", num_external_tokens: int): - for c in self._connectors: - c.update_state_after_alloc(request, num_external_tokens) + # If the request is not assigned to any connector, we do nothing. + if request.req_id not in self._requests_to_connector: + return + # We assume that the request is assigned to only one connector. + c = self._requests_to_connector[request.req_id] + c.update_state_after_alloc(request, num_external_tokens) def build_connector_meta( self, From df0a82bc8f39457f537d9ac840b38e3a06917d8d Mon Sep 17 00:00:00 2001 From: mgoin Date: Fri, 2 May 2025 21:28:42 +0000 Subject: [PATCH 10/28] Make sure we pop requests from connector dict Signed-off-by: mgoin --- .../kv_transfer/kv_connector/v1/multi_connector.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index b8ec36f800d3..74b1454834fc 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -38,7 +38,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): temp_config.kv_transfer_config = KVTransferConfig(**ktc) self._connectors.append( KVConnectorFactory.create_connector_v1(temp_config, role)) - + # A mapping from request id to the connector that is assigned to it. self._requests_to_connector: dict[str, KVConnectorBase_V1] = {} @@ -93,14 +93,13 @@ def get_num_new_matched_tokens( return toks return 0 - def update_state_after_alloc(self, request: "Request", num_external_tokens: int): # If the request is not assigned to any connector, we do nothing. if request.req_id not in self._requests_to_connector: return # We assume that the request is assigned to only one connector. 
- c = self._requests_to_connector[request.req_id] + c = self._requests_to_connector.pop(request.req_id) c.update_state_after_alloc(request, num_external_tokens) def build_connector_meta( From 842313507b1f1f178318dcdd4a990dc9f6242dd6 Mon Sep 17 00:00:00 2001 From: mgoin Date: Fri, 2 May 2025 21:45:48 +0000 Subject: [PATCH 11/28] req_id -> request_id Signed-off-by: mgoin --- .../kv_transfer/kv_connector/v1/multi_connector.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 74b1454834fc..e8857d6e3677 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -89,18 +89,19 @@ def get_num_new_matched_tokens( # The first connector that has new matched tokens will be assigned # to this request. if toks > 0: - self._requests_to_connector[request.req_id] = c + self._requests_to_connector[request.request_id] = c return toks return 0 def update_state_after_alloc(self, request: "Request", + block_ids: list[int], num_external_tokens: int): # If the request is not assigned to any connector, we do nothing. - if request.req_id not in self._requests_to_connector: + if request.request_id not in self._requests_to_connector: return # We assume that the request is assigned to only one connector. - c = self._requests_to_connector.pop(request.req_id) - c.update_state_after_alloc(request, num_external_tokens) + c = self._requests_to_connector.pop(request.request_id) + c.update_state_after_alloc(request, block_ids, num_external_tokens) def build_connector_meta( self, From 551fff18bb04013c100a89e8e2a2ebab2930d73e Mon Sep 17 00:00:00 2001 From: mgoin Date: Tue, 6 May 2025 12:05:26 +0000 Subject: [PATCH 12/28] Update with better test Signed-off-by: mgoin --- tests/v1/kv_transfer/test_multi_connector.py | 138 ++++++++++++++++++- 1 file changed, 133 insertions(+), 5 deletions(-) diff --git a/tests/v1/kv_transfer/test_multi_connector.py b/tests/v1/kv_transfer/test_multi_connector.py index c83dd6f66633..5614b31e57ac 100644 --- a/tests/v1/kv_transfer/test_multi_connector.py +++ b/tests/v1/kv_transfer/test_multi_connector.py @@ -1,10 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 import filecmp import shutil +import tempfile +from collections import defaultdict from pathlib import Path from vllm import LLM, SamplingParams -from vllm.config import KVTransferConfig +from vllm.config import KVTransferConfig, VllmConfig +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) +from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa + SharedStorageConnector) MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" @@ -17,6 +23,50 @@ SAMPLING_PARAMS = SamplingParams(temperature=0, max_tokens=20) +class TestSharedStorageConnector(SharedStorageConnector): + + def __init__(self, config: VllmConfig, role): + self.name = config.kv_transfer_config.kv_connector_extra_config["name"] + self._connector = SharedStorageConnector(config, role) + self.call_record: dict[str, int] = defaultdict(int) + # Use a unique temp file per connector + self._event_file = tempfile.gettempdir( + ) + f"/connector_{self.name}_events.log" + # Start with an empty file + with open(self._event_file, "w") as _: + pass + + def __getattribute__(self, name): + if name in ("_connector", "call_record", "name", "_event_file", + "__class__", "__dict__", 
"__getattribute__", + "__init__"): # avoid recursion + return object.__getattribute__(self, name) + if not hasattr(self._connector, name): + return object.__getattribute__(self, name) + attr = getattr(self._connector, name) + + if callable(attr): + + def wrapper(*args, **kwargs): + self.call_record[name] += 1 + # Log the event as a line to the file + try: + with open(self._event_file, "a") as f: + f.write(name + "\n") + except Exception as e: + print(f"[ERROR] Could not log event {name} " + f"for {self.name}: {e}") + return attr(*args, **kwargs) + + return wrapper + return attr + + +KVConnectorFactory.register_connector("TestSharedStorageConnector", + TestSharedStorageConnector.__module__, + TestSharedStorageConnector.__name__) + + # Helper function to compare directories recursively def _compare_directories(dir1: Path, dir2: Path) -> bool: """Compares two directories recursively for identical content.""" @@ -51,16 +101,18 @@ def test_multi_shared_storage_connector_consistency(): kv_role="kv_both", kv_connector_extra_config={ "connectors": [{ - "kv_connector": "SharedStorageConnector", + "kv_connector": "TestSharedStorageConnector", "kv_role": "kv_both", "kv_connector_extra_config": { - "shared_storage_path": str(storage_1_path) + "shared_storage_path": str(storage_1_path), + "name": "storage1", } }, { - "kv_connector": "SharedStorageConnector", + "kv_connector": "TestSharedStorageConnector", "kv_role": "kv_both", "kv_connector_extra_config": { - "shared_storage_path": str(storage_2_path) + "shared_storage_path": str(storage_2_path), + "name": "storage2", } }] }, @@ -106,6 +158,82 @@ def test_multi_shared_storage_connector_consistency(): (f"Contents differ for cache directory '{subdir_name}' between " f"{storage_1_path} and {storage_2_path}") + events = get_connector_events() + # get_num_new_matched_tokens will be called on each connector in turn. + # neither of them have hits so update_state_after_alloc won't be called. + assert events["storage1"][:3] == [ + 'get_num_new_matched_tokens', 'build_connector_meta', + 'bind_connector_metadata' + ] + assert events["storage2"][:3] == [ + 'get_num_new_matched_tokens', 'build_connector_meta', + 'bind_connector_metadata' + ] + + # Reset prefix cache or else we'll just get the tokens back from there. + llm.reset_prefix_cache() + + # Run generation again - this should trigger loading from the first + # connector. + _ = llm.generate(PROMPTS, SAMPLING_PARAMS) + + events = get_connector_events() + # get_num_new_matched_tokens will return new tokens from the first + # connector so update_state_after_alloc will be called once blocks + # are allocated for the first connector. + # get_num_new_matched_tokens *won't* be called on the second connector + # in this case. + assert events["storage1"][:4] == [ + 'get_num_new_matched_tokens', 'update_state_after_alloc', + 'build_connector_meta', 'bind_connector_metadata' + ] + assert events["storage2"][:2] == [ + 'build_connector_meta', 'bind_connector_metadata' + ] + + # Delete storage1 connector state + shutil.rmtree(storage_1_path) + + # Reset prefix cache or else we'll just get the tokens back from there. + llm.reset_prefix_cache() + + # Run generation again - this should trigger loading from the first + # connector. + _ = llm.generate(PROMPTS, SAMPLING_PARAMS) + + events = get_connector_events() + # get_num_new_matched_tokens will be called for the first connector but it + # won't have a hit so update_state_after_alloc won't be called. 
+ # get_num_new_matched_tokens will also be called on the second connector, + # but it should have a hit so update_state_after_alloc will be called. + assert events["storage1"][:3] == [ + 'get_num_new_matched_tokens', 'build_connector_meta', + 'bind_connector_metadata' + ] + assert events["storage2"][:4] == [ + 'get_num_new_matched_tokens', 'update_state_after_alloc', + 'build_connector_meta', 'bind_connector_metadata' + ] + # Clean up shutil.rmtree(storage_1_path) shutil.rmtree(storage_2_path) + + +def get_connector_events() -> dict[str, list[str]]: + # Read in connector events and reset the files. + import glob + event_files = glob.glob(tempfile.gettempdir() + "/connector_*_events.log") + connector_events = {} + for fname in event_files: + name = fname.split("connector_")[1].split("_events.log")[0] + try: + with open(fname, "r+") as f: + connector_events[name] = [ + line.strip() for line in f if line.strip() + ] + f.truncate(0) + except Exception as e: + print(f"[ERROR] Could not read connector events for {name}: {e}") + + return connector_events \ No newline at end of file From 10a26b6b96f05ec7182a653335370c802800cbc1 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 6 May 2025 11:41:13 -0700 Subject: [PATCH 13/28] add comment to test Signed-off-by: Nick Hill --- tests/v1/kv_transfer/test_multi_connector.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/v1/kv_transfer/test_multi_connector.py b/tests/v1/kv_transfer/test_multi_connector.py index 5614b31e57ac..64da0d79bf33 100644 --- a/tests/v1/kv_transfer/test_multi_connector.py +++ b/tests/v1/kv_transfer/test_multi_connector.py @@ -45,6 +45,8 @@ def __getattribute__(self, name): return object.__getattribute__(self, name) attr = getattr(self._connector, name) + # Intercept calls to the connector interface and write an event + # for each one to a file, which can be read back in the main test proc. if callable(attr): def wrapper(*args, **kwargs): @@ -236,4 +238,4 @@ def get_connector_events() -> dict[str, list[str]]: except Exception as e: print(f"[ERROR] Could not read connector events for {name}: {e}") - return connector_events \ No newline at end of file + return connector_events From 35f374841c0377d130021c23b7344d84e313c58d Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 9 May 2025 12:17:34 -0700 Subject: [PATCH 14/28] Handle other new methods and latest API changes Signed-off-by: Nick Hill --- .../kv_connector/v1/multi_connector.py | 70 +++++++++++++++++-- 1 file changed, 63 insertions(+), 7 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index e8857d6e3677..4c845f0dcfe4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import copy -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -15,6 +15,8 @@ if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.forward_context import ForwardContext + from vllm.sampling_params import KVTransferParams + from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request logger = init_logger(__name__) @@ -42,6 +44,15 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): # A mapping from request id to the connector that is assigned to it. 
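+        # Entries are added in get_num_new_matched_tokens() and popped
+        # again in update_state_after_alloc().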
self._requests_to_connector: dict[str, KVConnectorBase_V1] = {} + # Keeps track of *additional* remaining async saves (beyond 1) to be + # finished per request. Not needed for async loads since we only allow + # a single connector to load. + self._extra_async_saves: dict[str, int] = {} + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + for c in self._connectors: + c.register_kv_caches(kv_caches) + # We must override the base class method here because we need to bind # the metadata to each connector in the order of the connectors in the # MultiKVConnectorMetadata. @@ -76,6 +87,30 @@ def wait_for_save(self): for c in self._connectors: c.wait_for_save() + def get_finished(self) -> tuple[Optional[set[str]], Optional[set[str]]]: + finished_recving, finished_sending = set(), set() + for c in self._connectors: + recving, sending = c.get_finished() + if not recving and not sending: + continue + # Aggregate finished recving request ids. + finished_recving.update(recving or ()) + # Aggregate finished sending request ids - only include + # once we've drained the "extra" count (for cases where + # more than one connector is async-saving the same request). + for req_id in sending or (): + extra_pending = self._extra_async_saves.get(req_id) + if extra_pending is None: + finished_sending.add(req_id) + continue + assert extra_pending > 0 + if extra_pending == 1: + del self._extra_async_saves[req_id] + else: + self._extra_async_saves[req_id] = extra_pending - 1 + + return finished_recving or None, finished_sending or None + # ============================== # Scheduler-side methods # ============================== @@ -83,28 +118,49 @@ def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int, - ) -> int: + ) -> tuple[int, bool]: for c in self._connectors: - toks = c.get_num_new_matched_tokens(request, num_computed_tokens) + toks, load_async = c.get_num_new_matched_tokens( + request, num_computed_tokens) # The first connector that has new matched tokens will be assigned # to this request. if toks > 0: self._requests_to_connector[request.request_id] = c - return toks - return 0 + return toks, load_async + return 0, False def update_state_after_alloc(self, request: "Request", - block_ids: list[int], + blocks: "KVCacheBlocks", num_external_tokens: int): # If the request is not assigned to any connector, we do nothing. if request.request_id not in self._requests_to_connector: return # We assume that the request is assigned to only one connector. 
c = self._requests_to_connector.pop(request.request_id) - c.update_state_after_alloc(request, block_ids, num_external_tokens) + c.update_state_after_alloc(request, blocks, num_external_tokens) def build_connector_meta( self, scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata: return MultiKVConnectorMetadata( c.build_connector_meta(scheduler_output) for c in self._connectors) + + def request_finished( + self, + request: "Request", + blocks: "KVCacheBlocks", + ) -> tuple[bool, Optional["KVTransferParams"]]: + async_saves = 0 + kv_txfer_params = None + for c in self._connectors: + async_save, txfer_params = c.request_finished(request, blocks) + if async_save: + async_saves += 1 + if txfer_params is not None: + if kv_txfer_params is not None: + raise RuntimeError( + "Only one connector can produce KVTransferParams") + kv_txfer_params = txfer_params + if async_saves > 1: + self._extra_async_saves[request.request_id] = async_saves - 1 + return async_saves > 0, kv_txfer_params From ff4e0834355be2a7ea43a2481d1efd6d5d250781 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Sat, 10 May 2025 14:44:48 -0700 Subject: [PATCH 15/28] update get_finished() method with new arg Signed-off-by: Nick Hill --- .../kv_transfer/kv_connector/v1/multi_connector.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 4c845f0dcfe4..25ef6316c41b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -87,10 +87,12 @@ def wait_for_save(self): for c in self._connectors: c.wait_for_save() - def get_finished(self) -> tuple[Optional[set[str]], Optional[set[str]]]: + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: finished_recving, finished_sending = set(), set() for c in self._connectors: - recving, sending = c.get_finished() + recving, sending = c.get_finished(finished_req_ids) if not recving and not sending: continue # Aggregate finished recving request ids. 
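
For reference, a minimal end-to-end sketch of what these patches enable,
distilled from the dual_storage.py example script and the unit test earlier
in the series (the model name and storage paths are illustrative only):

    from vllm import LLM, SamplingParams
    from vllm.config import KVTransferConfig

    # Two SharedStorageConnectors stacked behind one MultiConnector: the
    # first connector reporting a cache hit is assigned the load for a
    # request, while saves fan out to every connector.
    kv_transfer_config = KVTransferConfig(
        kv_connector="MultiConnector",
        kv_role="kv_both",
        kv_connector_extra_config={
            "connectors": [{
                "kv_connector": "SharedStorageConnector",
                "kv_role": "kv_both",
                "kv_connector_extra_config": {
                    "shared_storage_path": "local_storage"
                }
            }, {
                "kv_connector": "SharedStorageConnector",
                "kv_role": "kv_both",
                "kv_connector_extra_config": {
                    "shared_storage_path": "external_storage"
                }
            }]
        },
    )
    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
              enforce_eager=True,
              kv_transfer_config=kv_transfer_config)
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0, max_tokens=20))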
From fc65a1897a56151a00d41ad506bc44ed9a9442ac Mon Sep 17 00:00:00 2001 From: mgoin Date: Mon, 12 May 2025 17:01:18 +0000 Subject: [PATCH 16/28] Move test to v1/kv_connector/unit Signed-off-by: mgoin --- .../v1/{kv_transfer => kv_connector/unit}/test_multi_connector.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/v1/{kv_transfer => kv_connector/unit}/test_multi_connector.py (100%) diff --git a/tests/v1/kv_transfer/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py similarity index 100% rename from tests/v1/kv_transfer/test_multi_connector.py rename to tests/v1/kv_connector/unit/test_multi_connector.py From e5e1191b89283e410da2ce01b1f0cff7f90b9134 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 12 May 2025 10:37:29 -0700 Subject: [PATCH 17/28] update KVTransferParams typing Signed-off-by: Nick Hill --- .../kv_transfer/kv_connector/v1/multi_connector.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 25ef6316c41b..5a53a66ab683 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import copy -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Optional import torch @@ -15,7 +15,6 @@ if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.forward_context import ForwardContext - from vllm.sampling_params import KVTransferParams from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request @@ -151,7 +150,7 @@ def request_finished( self, request: "Request", blocks: "KVCacheBlocks", - ) -> tuple[bool, Optional["KVTransferParams"]]: + ) -> tuple[bool, Optional[dict[str, Any]]]: async_saves = 0 kv_txfer_params = None for c in self._connectors: @@ -160,8 +159,10 @@ def request_finished( async_saves += 1 if txfer_params is not None: if kv_txfer_params is not None: + #TODO we can probably change this to merge the dicts here, + # checking for key clashes. 
raise RuntimeError( - "Only one connector can produce KVTransferParams") + "Only one connector can produce KV transfer params") kv_txfer_params = txfer_params if async_saves > 1: self._extra_async_saves[request.request_id] = async_saves - 1 From 8537d934eac5e414806a82c14748695766cf47a3 Mon Sep 17 00:00:00 2001 From: mgoin Date: Mon, 12 May 2025 17:57:05 +0000 Subject: [PATCH 18/28] Fix typing issue for get_finished Signed-off-by: mgoin --- .../distributed/kv_transfer/kv_connector/v1/multi_connector.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 5a53a66ab683..509619c7ec60 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -89,7 +89,8 @@ def wait_for_save(self): def get_finished( self, finished_req_ids: set[str] ) -> tuple[Optional[set[str]], Optional[set[str]]]: - finished_recving, finished_sending = set(), set() + finished_recving: set[str] = set() + finished_sending: set[str] = set() for c in self._connectors: recving, sending = c.get_finished(finished_req_ids) if not recving and not sending: From 9a767d1ac434aa3eb887e6315745637e59c4e1ba Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 13 May 2025 13:13:18 -0700 Subject: [PATCH 19/28] remove @dataclass from KVConnectorMetadata This was there originally but inadvertently dropped Signed-off-by: Nick Hill --- vllm/distributed/kv_transfer/kv_connector/v1/base.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 2ff61e8a400f..61e926b88f14 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -22,7 +22,6 @@ import enum from abc import ABC, abstractmethod -from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional import torch @@ -65,7 +64,6 @@ def from_raw_dict( return None -@dataclass class KVConnectorMetadata: """ Abstract Metadata used to communicate between the From 9459f0a204fedd576f5c8ab87c52b17fc71ade95 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 6 May 2025 16:35:00 -0700 Subject: [PATCH 20/28] [P/D Disagg][Benchmarking] One request at a time benchmarking for P/D (#79) * Benchmark one concurrent req Signed-off-by: Tyler Michael Smith * Updates Signed-off-by: Tyler Michael Smith * restore Signed-off-by: Tyler Michael Smith * Improve random requests, switch up initial test Signed-off-by: Tyler Michael Smith --------- Signed-off-by: Tyler Michael Smith Signed-off-by: Nick Hill --- benchmarks/benchmark_one_concurrent_req.py | 382 ++++++++++++++++++ .../nixl_integration/run_accuracy_test.sh | 1 + .../nixl_integration/toy_proxy_server.py | 67 ++- 3 files changed, 449 insertions(+), 1 deletion(-) create mode 100644 benchmarks/benchmark_one_concurrent_req.py diff --git a/benchmarks/benchmark_one_concurrent_req.py b/benchmarks/benchmark_one_concurrent_req.py new file mode 100644 index 000000000000..bfeed109663c --- /dev/null +++ b/benchmarks/benchmark_one_concurrent_req.py @@ -0,0 +1,382 @@ +# SPDX-License-Identifier: Apache-2.0 +import argparse +import asyncio +import logging +import random +import time +from dataclasses import dataclass +from typing import Optional + +import aiohttp # Import aiohttp +import numpy as np +from backend_request_func import 
RequestFuncInput, RequestFuncOutput
+from benchmark_dataset import RandomDataset, SampleRequest
+from tqdm import tqdm
+
+try:
+    from vllm.transformers_utils.tokenizer import get_tokenizer
+except ImportError:
+    from backend_request_func import get_tokenizer
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class BenchmarkMetrics:
+    completed: int
+    total_input: int
+    total_output: int
+    mean_ttft_ms: float
+    median_ttft_ms: float
+    std_ttft_ms: float
+    percentiles_ttft_ms: list[tuple[float, float]]
+    mean_itl_ms: float
+    median_itl_ms: float
+    std_itl_ms: float
+    percentiles_itl_ms: list[tuple[float, float]]
+    mean_e2el_ms: float
+    median_e2el_ms: float
+    std_e2el_ms: float
+    percentiles_e2el_ms: list[tuple[float, float]]
+
+
+async def reset_cache(reset_url: str):
+    """Sends a POST request to reset the prefix cache."""
+    logger.debug("Resetting prefix cache at %s", reset_url)
+    try:
+        async with (aiohttp.ClientSession() as session, session.post(reset_url)
+                    as response):
+            response.raise_for_status(
+            )  # Raise an exception for bad status codes (4xx or 5xx)
+            logger.debug("Prefix cache reset successful: %s", response.status)
+    except aiohttp.ClientConnectorError as e:
+        logger.error("Failed to connect to cache reset endpoint %s: %s",
+                     reset_url, e)
+    except aiohttp.ClientResponseError as e:
+        logger.error("Cache reset request failed with status %s: %s", e.status,
+                     e.message)
+    except Exception as e:
+        logger.error("An unexpected error occurred during cache reset: %s", e)
+
+
+async def sequential_benchmark(
+    backend: str,
+    api_url: str,
+    model_id: str,
+    tokenizer,
+    input_requests: list[SampleRequest],
+    request_func,
+    selected_percentiles: list[float],
+    cache_reset_url: Optional[str] = None,
+):
+    """
+    Benchmark that processes requests sequentially, waiting for each to
+    complete before starting the next one. Resets the prefix cache between
+    requests.
+    """
+    outputs = []
+
+    pbar = tqdm(total=len(input_requests))
+
+    # Small request to force a forward pass.
+    # Used for resetting the prefix cache.
+    dummy_req_input = RequestFuncInput(
+        model=model_id,
+        prompt="0",
+        api_url=api_url,
+        prompt_len=1,
+        output_len=1,
+    )
+
+    print("Starting initial single prompt test run...")
+    test_output = await request_func(request_func_input=dummy_req_input)
+    if not test_output.success:
+        raise ValueError(
+            "Initial test run failed - Please check your configuration. "
+            f"Error: {test_output.error}")
+    else:
+        print("Initial test run completed. Starting sequential benchmark...")
+
+    benchmark_start_time = time.perf_counter()
+
+    # Process requests sequentially
+    for request in input_requests:
+        prompt, prompt_len, output_len = (request.prompt, request.prompt_len,
+                                          request.expected_output_len)
+
+        logger.info("Sending request with len %s", request.prompt_len)
+        logger.debug("Request str: \"%s\"", request.prompt[:50])
+        request_start_time = time.perf_counter()
+
+        request_func_input = RequestFuncInput(
+            model=model_id,
+            prompt=prompt,
+            api_url=api_url,
+            prompt_len=prompt_len,
+            output_len=output_len,
+        )
+
+        output = await request_func(request_func_input=request_func_input)
+
+        request_end_time = time.perf_counter()
+        # Add timing information
+        if output.success and not hasattr(output, "latency"):
+            output.latency = request_end_time - request_start_time
+        logger.info("Finished request with latency %.4f s", output.latency)
+
+        outputs.append(output)
+        pbar.update(1)
+
+        # Reset the prefix cache after each request if configured, so the
+        # next request starts cold.
+        if cache_reset_url:
+            await request_func(request_func_input=dummy_req_input)
+            await reset_cache(cache_reset_url)
+
+    pbar.close()
+
+    benchmark_duration = time.perf_counter() - benchmark_start_time
+
+    # Calculate metrics
+    metrics = calculate_metrics(
+        input_requests=input_requests,
+        outputs=outputs,
+        dur_s=benchmark_duration,
+        tokenizer=tokenizer,
+        selected_percentiles=selected_percentiles,
+    )
+
+    print_results(metrics, benchmark_duration)
+
+    result = {
+        "duration":
+        benchmark_duration,
+        "completed":
+        metrics.completed,
+        "total_input_tokens":
+        metrics.total_input,
+        "total_output_tokens":
+        metrics.total_output,
+        "input_lens": [request.prompt_len for request in input_requests],
+        "output_lens":
+        [output.output_tokens if output.success else 0 for output in outputs],
+        "ttfts": [output.ttft for output in outputs if output.success],
+        "itls": [output.itl for output in outputs if output.success],
+        "generated_texts":
+        [output.generated_text for output in outputs if output.success],
+        "errors": [output.error for output in outputs if not output.success],
+    }
+
+    # Add summary statistics
+    for stat_name in ["ttft", "itl", "e2el"]:
+        for metric_name in ["mean", "median", "std"]:
+            result[f"{metric_name}_{stat_name}_ms"] = getattr(
+                metrics, f"{metric_name}_{stat_name}_ms")
+
+        for p, value in getattr(metrics, f"percentiles_{stat_name}_ms"):
+            p_word = str(int(p)) if int(p) == p else str(p)
+            result[f"p{p_word}_{stat_name}_ms"] = value
+
+    return result
+
+
+def calculate_metrics(
+    input_requests: list[SampleRequest],
+    outputs: list[RequestFuncOutput],
+    dur_s: float,
+    tokenizer,
+    selected_percentiles: list[float],
+) -> BenchmarkMetrics:
+    """Calculate benchmark metrics from results."""
+    total_input = 0
+    completed = 0
+    total_output = 0
+    ttfts = []
+    itls = []
+    e2els = []
+
+    for i, output in enumerate(outputs):
+        if output.success:
+            output_len = output.output_tokens
+
+            if not output_len:
+                # Use tokenizer to count output tokens if not provided
+                output_len = len(
+                    tokenizer(output.generated_text,
+                              add_special_tokens=False).input_ids)
+
+            total_output += output_len
+            total_input += input_requests[i].prompt_len
+
+            if hasattr(output, "ttft") and output.ttft is not None:
+                ttfts.append(output.ttft)
+
+            if hasattr(output, "itl") and output.itl:
+                # Ensure itl is a list of floats
+                if isinstance(output.itl, list):
+                    itls.extend(output.itl)
+                else:
+                    logger.warning(
+                        "Expected list for ITL but got %s. 
Appending as is.", + type(output.itl)) + itls.append(output.itl) + + if hasattr(output, "latency") and output.latency is not None: + e2els.append(output.latency) + + completed += 1 + + return BenchmarkMetrics( + completed=completed, + total_input=total_input, + total_output=total_output, + mean_ttft_ms=np.mean(ttfts or [0]) * 1000, + median_ttft_ms=np.median(ttfts or [0]) * 1000, + std_ttft_ms=np.std(ttfts or [0]) * 1000, + percentiles_ttft_ms=[(p, np.percentile(ttfts or [0], p) * 1000) + for p in selected_percentiles], + mean_itl_ms=np.mean(itls or [0]) * 1000, + median_itl_ms=np.median(itls or [0]) * 1000, + std_itl_ms=np.std(itls or [0]) * 1000, + percentiles_itl_ms=[(p, np.percentile(itls or [0], p) * 1000) + for p in selected_percentiles], + mean_e2el_ms=np.mean(e2els or [0]) * 1000, + median_e2el_ms=np.median(e2els or [0]) * 1000, + std_e2el_ms=np.std(e2els or [0]) * 1000, + percentiles_e2el_ms=[(p, np.percentile(e2els or [0], p) * 1000) + for p in selected_percentiles], + ) + + +def print_results(metrics: BenchmarkMetrics, benchmark_duration: float): + """Print benchmark results in a formatted way.""" + print("{s:{c}^{n}}".format(s=" Sequential Benchmark Result ", n=60, c="=")) + print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", + benchmark_duration)) + print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) + print("{:<40} {:<10}".format("Total generated tokens:", + metrics.total_output)) + + def print_metric_stats(metric_name, header): + print("{s:{c}^{n}}".format(s=header, n=60, c="-")) + print("{:<40} {:<10.2f}".format( + f"Mean {metric_name} (ms):", + getattr(metrics, f"mean_{metric_name.lower()}_ms"))) + print("{:<40} {:<10.2f}".format( + f"Median {metric_name} (ms):", + getattr(metrics, f"median_{metric_name.lower()}_ms"))) + + for p, value in getattr(metrics, + f"percentiles_{metric_name.lower()}_ms"): + p_word = str(int(p)) if int(p) == p else str(p) + print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", + value)) + + print_metric_stats("TTFT", "Time to First Token") + print_metric_stats("ITL", "Inter-token Latency") + print_metric_stats("E2EL", "End-to-end Latency") + print("=" * 60) + + +async def main_async(args): + # Import needed functions based on your setup + from backend_request_func import ASYNC_REQUEST_FUNCS + + backend = args.backend + model_id = args.model + tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model + + # Set up API URL + if args.base_url is not None: + api_url = f"{args.base_url}{args.endpoint}" + else: + api_url = f"http://{args.host}:{args.port}{args.endpoint}" + + # Set up Cache Reset URL + cache_reset_url = f"http://{args.host}:{args.port}/reset_prefix_cache" + logger.info("Prefix cache reset configured at: %s", cache_reset_url) + + # Get tokenizer + tokenizer = get_tokenizer(tokenizer_id, + trust_remote_code=args.trust_remote_code) + + # Get request function + if backend in ASYNC_REQUEST_FUNCS: + request_func = ASYNC_REQUEST_FUNCS[backend] + else: + raise ValueError(f"Unknown backend: {backend}") + + input_requests = RandomDataset().sample( + tokenizer=tokenizer, + num_requests=args.num_requests, + prefix_len=0, + input_len=args.input_len, + output_len=args.output_len, + range_ratio=0.0, + ) + + # Run benchmark + result = await sequential_benchmark( + backend=backend, + api_url=api_url, + model_id=model_id, + tokenizer=tokenizer, + input_requests=input_requests, + request_func=request_func, + 
selected_percentiles=[50, 90, 95, 99], + cache_reset_url=cache_reset_url, + ) + + return result + + +def main(args): + print(args) + random.seed(args.seed) + np.random.seed(args.seed) + + asyncio.run(main_async(args)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Sequential benchmark for LLM serving") + parser.add_argument("--backend", + type=str, + default="vllm", + help="Backend to use for requests") + parser.add_argument("--base-url", + type=str, + default=None, + help="Server base URL (overrides --host and --port)") + parser.add_argument("--host", type=str, default="127.0.0.1") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--endpoint", + type=str, + default="/v1/completions", + help="API endpoint") + parser.add_argument("--model", + type=str, + required=True, + help="Name of the model") + parser.add_argument("--tokenizer", + type=str, + help="Name of the tokenizer (defaults to model name)") + parser.add_argument("--num-requests", + type=int, + default=100, + help="Number of requests to process") + parser.add_argument("--input-len", + type=int, + default=128, + help="Input len for generated prompts") + parser.add_argument("--output-len", + type=int, + default=None, + help="Override output len for requests") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--trust-remote-code", + action="store_true", + help="Trust remote code from HuggingFace") + + args = parser.parse_args() + main(args) diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh index e90b72a7cf24..2b07e64c3c91 100755 --- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh @@ -4,6 +4,7 @@ set -xe # Models to run MODELS=( "Qwen/Qwen3-0.6B" + "deepseek-ai/deepseek-vl2-tiny" ) # Number of prefill and decode instances to create diff --git a/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py index 13071f581375..9eaffcffe08a 100644 --- a/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py +++ b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import argparse +import asyncio import itertools import os import uuid @@ -8,7 +9,7 @@ import httpx from fastapi import FastAPI, Request -from fastapi.responses import StreamingResponse +from fastapi.responses import JSONResponse, StreamingResponse from vllm.logger import init_logger @@ -195,6 +196,70 @@ async def stream_service_response(client_info: dict, endpoint: str, yield chunk +async def _forward_reset_cache(client_session: httpx.AsyncClient, host: str, + port: int) -> dict: + target_url = f"http://{host}:{port}/reset_prefix_cache" + + try: + response: httpx.Response = await client_session.post(target_url, + timeout=5.0) + + return { + "status_code": response.status_code, + "error_type": None, + "error_message": None, + } + except Exception as e: + logger.error("Exception occurred sending POST to %s: %s - %s", + target_url, e.__class__.__name__, str(e)) + return { + "status_code": None, + "error_type": e.__class__.__name__, + "error_message": str(e), + } + + +@app.post("/reset_prefix_cache") +async def reset_prefix_cache_on_all_servers(request: Request): + """ + Forwards a reset_prefix_cache request to all prefill and decode servers. 
+ """ + tasks = [] + + def add_reset_tasks_for_servers(server_list): + for server_info in server_list: + tasks.append( + _forward_reset_cache(server_info['client'], + server_info['host'], server_info['port'])) + + add_reset_tasks_for_servers(request.app.state.prefill_clients) + add_reset_tasks_for_servers(request.app.state.decode_clients) + + if not tasks: + return JSONResponse(content={ + "message": + "No prefill or decode servers configured to reset." + }, + status_code=200) + + all_results = await asyncio.gather(*tasks) + + num_prefill_servers = len(request.app.state.prefill_clients) + prefill_server_results = all_results[:num_prefill_servers] + decode_server_results = all_results[num_prefill_servers:] + + response_data = { + "message": + "Simple POST /reset_prefix_cache command forwarded to P/D workers.", + "prefill_servers_status": prefill_server_results, + "decode_servers_status": decode_server_results + } + all_downstream_ok = all( + result.get("error_type") is None for result in all_results) + status_code = 200 if all_downstream_ok else 207 # 207 Multi-Status + return JSONResponse(content=response_data, status_code=status_code) + + @app.post("/v1/completions") async def handle_completions(request: Request): try: From 3a63220173084262f121af6e9c4663ba261fdff3 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Sun, 11 May 2025 09:58:28 -0700 Subject: [PATCH 21/28] add LMCache async save support Requires version of LMCache with the corresponding changes Signed-off-by: Nick Hill --- .../kv_connector/v1/lmcache_connector.py | 36 +++++++++++++++++-- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index 2cb68dc1ff67..ebbf270e25a7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Optional import torch from lmcache.integration.vllm.vllm_v1_adapter import LMCacheConnectorV1Impl @@ -25,6 +25,9 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): super().__init__(vllm_config=vllm_config, role=role) self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self) + self.async_save_supported = hasattr(self._lmcache_engine, + "get_finished") + # ============================== # Worker-side methods # ============================== @@ -86,6 +89,14 @@ def wait_for_save(self): """ self._lmcache_engine.wait_for_save() + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + if not self.async_save_supported: + return None, None + + return self._lmcache_engine.get_finished(finished_req_ids) + # ============================== # Scheduler-side methods # ============================== @@ -104,8 +115,10 @@ def get_num_new_matched_tokens( computed tokens for this request Returns: - the number of tokens that can be loaded from the - external KV cache beyond what is already computed. + * the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + * true if external KV cache tokens will be loaded + asynchronously (between scheduler steps). 
""" return self._lmcache_engine.get_num_new_matched_tokens( request, num_computed_tokens), False @@ -131,3 +144,20 @@ def build_connector_meta( scheduler_output (SchedulerOutput): the scheduler output object. """ return self._lmcache_engine.build_connector_meta(scheduler_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Called when a request has finished, before its blocks are freed. + + Returns: + * True if the request is being saved/sent asynchronously and blocks + should not be freed until the request_id is returned from + get_finished(). + * Optional KVTransferParams to be included in the request outputs + returned by the engine. + """ + return self.async_save_supported, None From 6b8a8f8d5ec59d5b5dbc8f6d23a6a7a54406d378 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 13 May 2025 09:55:56 -0700 Subject: [PATCH 22/28] handle disabled lmcache async save Signed-off-by: Nick Hill --- .../kv_transfer/kv_connector/v1/lmcache_connector.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index ebbf270e25a7..3886f94cd97a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -160,4 +160,7 @@ def request_finished( * Optional KVTransferParams to be included in the request outputs returned by the engine. """ - return self.async_save_supported, None + if not self.async_save_supported: + return False, None + + return self._lmcache_engine.request_finished(request, block_ids), None From f5bc7489e7d23f2dc867456f9b180255fa36c5f2 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 13 May 2025 11:43:04 -0700 Subject: [PATCH 23/28] fix Signed-off-by: Nick Hill --- .../kv_transfer/kv_connector/v1/lmcache_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index 3886f94cd97a..eff7435c0a0a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -163,4 +163,4 @@ def request_finished( if not self.async_save_supported: return False, None - return self._lmcache_engine.request_finished(request, block_ids), None + return self._lmcache_engine.request_finished(request, block_ids) From ed1af74ef9cd6d2409ff6595233119eb03c58c03 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 14 May 2025 17:41:44 +0000 Subject: [PATCH 24/28] ifxed Signed-off-by: rshaw@neuralmagic.com --- benchmarks/benchmark_one_concurrent_req.py | 183 +++++++++++---------- 1 file changed, 94 insertions(+), 89 deletions(-) diff --git a/benchmarks/benchmark_one_concurrent_req.py b/benchmarks/benchmark_one_concurrent_req.py index bfeed109663c..ecfeb249bfe8 100644 --- a/benchmarks/benchmark_one_concurrent_req.py +++ b/benchmarks/benchmark_one_concurrent_req.py @@ -9,9 +9,10 @@ import aiohttp # Import aiohttp import numpy as np +from tqdm import tqdm + from backend_request_func import RequestFuncInput, RequestFuncOutput from benchmark_dataset import RandomDataset, SampleRequest -from tqdm import tqdm try: from vllm.transformers_utils.tokenizer import get_tokenizer @@ -44,17 +45,18 @@ async def reset_cache(reset_url: str): """Sends a POST 
From ed1af74ef9cd6d2409ff6595233119eb03c58c03 Mon Sep 17 00:00:00 2001
From: "rshaw@neuralmagic.com"
Date: Wed, 14 May 2025 17:41:44 +0000
Subject: [PATCH 24/28] fixed

Signed-off-by: rshaw@neuralmagic.com
---
 benchmarks/benchmark_one_concurrent_req.py | 183 +++++++++++----------
 1 file changed, 94 insertions(+), 89 deletions(-)

diff --git a/benchmarks/benchmark_one_concurrent_req.py b/benchmarks/benchmark_one_concurrent_req.py
index bfeed109663c..ecfeb249bfe8 100644
--- a/benchmarks/benchmark_one_concurrent_req.py
+++ b/benchmarks/benchmark_one_concurrent_req.py
@@ -9,9 +9,10 @@
 
 import aiohttp  # Import aiohttp
 import numpy as np
+from tqdm import tqdm
+
 from backend_request_func import RequestFuncInput, RequestFuncOutput
 from benchmark_dataset import RandomDataset, SampleRequest
-from tqdm import tqdm
 
 try:
     from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -44,17 +45,18 @@ async def reset_cache(reset_url: str):
     """Sends a POST request to reset the prefix cache."""
     logger.debug("Resetting prefix cache at %s", reset_url)
     try:
-        async with (aiohttp.ClientSession() as session, session.post(reset_url)
-                    as response):
-            response.raise_for_status(
-            )  # Raise an exception for bad status codes (4xx or 5xx)
+        async with (
+            aiohttp.ClientSession() as session,
+            session.post(reset_url) as response,
+        ):
+            response.raise_for_status()  # Raise an exception for bad status codes
             logger.debug("Prefix cache reset successful: %s", response.status)
     except aiohttp.ClientConnectorError as e:
-        logger.error("Failed to connect to cache reset endpoint %s: %s}",
-                     reset_url, e)
+        logger.error("Failed to connect to cache reset endpoint %s: %s", reset_url, e)
     except aiohttp.ClientResponseError as e:
-        logger.error("Cache reset request failed with status %s: %s", e.status,
-                     e.message)
+        logger.error(
+            "Cache reset request failed with status %s: %s", e.status, e.message
+        )
     except Exception as e:
         logger.error("An unexpected error occurred during cache reset: %s", e)
@@ -91,8 +93,9 @@ async def sequential_benchmark(
     test_output = await request_func(request_func_input=dummy_req_input)
     if not test_output.success:
         raise ValueError(
-            "Initial test run failed - Please check your configuration. "
-            "Error: %s", test_output.error)
+            "Initial test run failed - Please check your configuration. "
+            f"Error: {test_output.error}"
+        )
     else:
         print("Initial test run completed. Starting sequential benchmark...")
@@ -100,11 +103,14 @@
 
     # Process requests sequentially
     for request in input_requests:
-        prompt, prompt_len, output_len = (request.prompt, request.prompt_len,
-                                          request.expected_output_len)
+        prompt, prompt_len, output_len = (
+            request.prompt,
+            request.prompt_len,
+            request.expected_output_len,
+        )
 
         logger.info("Sending request with len %s", request.prompt_len)
-        logger.debug("Request str: \"%s\"", request.prompt[:50])
+        logger.debug('Request str: "%s"', request.prompt[:50])
         request_start_time = time.perf_counter()
 
         request_func_input = RequestFuncInput(
@@ -147,21 +153,19 @@
     print_results(metrics, benchmark_duration)
 
     result = {
-        "duration":
-        benchmark_duration,
-        "completed":
-        metrics.completed,
-        "total_input_tokens":
-        metrics.total_input,
-        "total_output_tokens":
-        metrics.total_output,
+        "duration": benchmark_duration,
+        "completed": metrics.completed,
+        "total_input_tokens": metrics.total_input,
+        "total_output_tokens": metrics.total_output,
         "input_lens": [request.prompt_len for request in input_requests],
-        "output_lens":
-        [output.output_tokens if output.success else 0 for output in outputs],
+        "output_lens": [
+            output.output_tokens if output.success else 0 for output in outputs
+        ],
         "ttfts": [output.ttft for output in outputs if output.success],
         "itls": [output.itl for output in outputs if output.success],
-        "generated_texts":
-        [output.generated_text for output in outputs if output.success],
+        "generated_texts": [
+            output.generated_text for output in outputs if output.success
+        ],
         "errors": [output.error for output in outputs if not output.success],
     }
 
@@ -169,7 +173,8 @@
     for stat_name in ["ttft", "itl", "e2el"]:
         for metric_name in ["mean", "median", "std"]:
             result[f"{metric_name}_{stat_name}_ms"] = getattr(
-                metrics, f"{metric_name}_{stat_name}_ms")
+                metrics, f"{metric_name}_{stat_name}_ms"
+            )
 
             for p, value in getattr(metrics, f"percentiles_{stat_name}_ms"):
                 p_word = str(int(p)) if int(p) == p else str(p)
@@ -200,8 +205,8 @@ def 
calculate_metrics( if not output_len: # Use tokenizer to count output tokens if not provided output_len = len( - tokenizer(output.generated_text, - add_special_tokens=False).input_ids) + tokenizer(output.generated_text, add_special_tokens=False).input_ids + ) total_output += output_len total_input += input_requests[i].prompt_len @@ -216,7 +221,8 @@ def calculate_metrics( else: logger.warning( "Expected list for ITL but got %s. Appending as is.", - type(output.itl)) + type(output.itl), + ) itls.append(output.itl) if hasattr(output, "latency") and output.latency is not None: @@ -231,18 +237,21 @@ def calculate_metrics( mean_ttft_ms=np.mean(ttfts or [0]) * 1000, median_ttft_ms=np.median(ttfts or [0]) * 1000, std_ttft_ms=np.std(ttfts or [0]) * 1000, - percentiles_ttft_ms=[(p, np.percentile(ttfts or [0], p) * 1000) - for p in selected_percentiles], + percentiles_ttft_ms=[ + (p, np.percentile(ttfts or [0], p) * 1000) for p in selected_percentiles + ], mean_itl_ms=np.mean(itls or [0]) * 1000, median_itl_ms=np.median(itls or [0]) * 1000, std_itl_ms=np.std(itls or [0]) * 1000, - percentiles_itl_ms=[(p, np.percentile(itls or [0], p) * 1000) - for p in selected_percentiles], + percentiles_itl_ms=[ + (p, np.percentile(itls or [0], p) * 1000) for p in selected_percentiles + ], mean_e2el_ms=np.mean(e2els or [0]) * 1000, median_e2el_ms=np.median(e2els or [0]) * 1000, std_e2el_ms=np.std(e2els or [0]) * 1000, - percentiles_e2el_ms=[(p, np.percentile(e2els or [0], p) * 1000) - for p in selected_percentiles], + percentiles_e2el_ms=[ + (p, np.percentile(e2els or [0], p) * 1000) for p in selected_percentiles + ], ) @@ -250,26 +259,28 @@ def print_results(metrics: BenchmarkMetrics, benchmark_duration: float): """Print benchmark results in a formatted way.""" print("{s:{c}^{n}}".format(s=" Sequential Benchmark Result ", n=60, c="=")) print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) - print("{:<40} {:<10.2f}".format("Benchmark duration (s):", - benchmark_duration)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) - print("{:<40} {:<10}".format("Total generated tokens:", - metrics.total_output)) + print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output)) def print_metric_stats(metric_name, header): print("{s:{c}^{n}}".format(s=header, n=60, c="-")) - print("{:<40} {:<10.2f}".format( - f"Mean {metric_name} (ms):", - getattr(metrics, f"mean_{metric_name.lower()}_ms"))) - print("{:<40} {:<10.2f}".format( - f"Median {metric_name} (ms):", - getattr(metrics, f"median_{metric_name.lower()}_ms"))) - - for p, value in getattr(metrics, - f"percentiles_{metric_name.lower()}_ms"): + print( + "{:<40} {:<10.2f}".format( + f"Mean {metric_name} (ms):", + getattr(metrics, f"mean_{metric_name.lower()}_ms"), + ) + ) + print( + "{:<40} {:<10.2f}".format( + f"Median {metric_name} (ms):", + getattr(metrics, f"median_{metric_name.lower()}_ms"), + ) + ) + + for p, value in getattr(metrics, f"percentiles_{metric_name.lower()}_ms"): p_word = str(int(p)) if int(p) == p else str(p) - print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", - value)) + print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value)) print_metric_stats("TTFT", "Time to First Token") print_metric_stats("ITL", "Inter-token Latency") @@ -296,8 +307,7 @@ async def main_async(args): logger.info("Prefix cache reset configured at: %s", cache_reset_url) # Get tokenizer - tokenizer = 
get_tokenizer(tokenizer_id, - trust_remote_code=args.trust_remote_code) + tokenizer = get_tokenizer(tokenizer_id, trust_remote_code=args.trust_remote_code) # Get request function if backend in ASYNC_REQUEST_FUNCS: @@ -338,45 +348,40 @@ def main(args): if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Sequential benchmark for LLM serving") - parser.add_argument("--backend", - type=str, - default="vllm", - help="Backend to use for requests") - parser.add_argument("--base-url", - type=str, - default=None, - help="Server base URL (overrides --host and --port)") + parser = argparse.ArgumentParser(description="Sequential benchmark for LLM serving") + parser.add_argument( + "--backend", type=str, default="vllm", help="Backend to use for requests" + ) + parser.add_argument( + "--base-url", + type=str, + default=None, + help="Server base URL (overrides --host and --port)", + ) parser.add_argument("--host", type=str, default="127.0.0.1") parser.add_argument("--port", type=int, default=8000) - parser.add_argument("--endpoint", - type=str, - default="/v1/completions", - help="API endpoint") - parser.add_argument("--model", - type=str, - required=True, - help="Name of the model") - parser.add_argument("--tokenizer", - type=str, - help="Name of the tokenizer (defaults to model name)") - parser.add_argument("--num-requests", - type=int, - default=100, - help="Number of requests to process") - parser.add_argument("--input-len", - type=int, - default=128, - help="Input len for generated prompts") - parser.add_argument("--output-len", - type=int, - default=None, - help="Override output len for requests") + parser.add_argument( + "--endpoint", type=str, default="/v1/completions", help="API endpoint" + ) + parser.add_argument("--model", type=str, required=True, help="Name of the model") + parser.add_argument( + "--tokenizer", type=str, help="Name of the tokenizer (defaults to model name)" + ) + parser.add_argument( + "--num-requests", type=int, default=100, help="Number of requests to process" + ) + parser.add_argument( + "--input-len", type=int, default=128, help="Input len for generated prompts" + ) + parser.add_argument( + "--output-len", type=int, default=None, help="Override output len for requests" + ) parser.add_argument("--seed", type=int, default=42) - parser.add_argument("--trust-remote-code", - action="store_true", - help="Trust remote code from HuggingFace") + parser.add_argument( + "--trust-remote-code", + action="store_true", + help="Trust remote code from HuggingFace", + ) args = parser.parse_args() main(args) From b6826ef90bd982199e160bc966d8279895df1e95 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 15 May 2025 09:12:14 -0700 Subject: [PATCH 25/28] [BugFix] Fix ordering of KVConnector finished send/rcv sets Signed-off-by: Nick Hill --- vllm/distributed/kv_transfer/kv_connector/v1/base.py | 3 ++- .../kv_transfer/kv_connector/v1/multi_connector.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 9fdb5340f0e2..ef4460a592bd 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -183,7 +183,8 @@ def get_finished( finished generating tokens. Returns: - ids of requests that have finished asynchronous (recving, sending). + ids of requests that have finished asynchronous transfer, + tuple of (sending/saving ids, recving/loading ids). 
The finished saves/sends req ids must belong to a set provided in a call to this method (this call or a prior one). """ diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 509619c7ec60..bf54a093311a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -89,10 +89,10 @@ def wait_for_save(self): def get_finished( self, finished_req_ids: set[str] ) -> tuple[Optional[set[str]], Optional[set[str]]]: - finished_recving: set[str] = set() finished_sending: set[str] = set() + finished_recving: set[str] = set() for c in self._connectors: - recving, sending = c.get_finished(finished_req_ids) + sending, recving = c.get_finished(finished_req_ids) if not recving and not sending: continue # Aggregate finished recving request ids. @@ -111,7 +111,7 @@ def get_finished( else: self._extra_async_saves[req_id] = extra_pending - 1 - return finished_recving or None, finished_sending or None + return finished_sending or None, finished_recving or None # ============================== # Scheduler-side methods From 3032553c2eeb9e63e020a7037ff4aef12daedf68 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 16 May 2025 07:52:26 -0700 Subject: [PATCH 26/28] [BugFix] Fix multi async save in MultiConnector (#90) The MultiKVConnector impl keeps track of cases where multiple connectors are async saving the same request, but this state needs to be shared from the scheduler side to the worker side. Signed-off-by: Nick Hill --- .../kv_connector/v1/multi_connector.py | 23 ++++++++++++++----- 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index bf54a093311a..cc61a6e99cc2 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import copy +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional import torch @@ -21,9 +22,10 @@ logger = init_logger(__name__) -class MultiKVConnectorMetadata(tuple[KVConnectorMetadata, ...], - KVConnectorMetadata): - pass +@dataclass +class MultiKVConnectorMetadata(KVConnectorMetadata): + metadata: tuple[KVConnectorMetadata, ...] + extra_async_saves: Optional[dict[str, int]] = None class MultiConnector(KVConnectorBase_V1): @@ -46,6 +48,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): # Keeps track of *additional* remaining async saves (beyond 1) to be # finished per request. Not needed for async loads since we only allow # a single connector to load. + # Propagated from scheduler to worker side via the connector metadata. 
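
A condensed sketch of the handoff this patch adds may help: the per-request counters accumulated on the scheduler side are piggybacked on the connector metadata and merged into the worker-side instance. WorkerSide below is a hypothetical stand-in for the worker-role MultiConnector; only the dataclass shape mirrors the diff.

from dataclasses import dataclass
from typing import Optional


@dataclass
class MultiKVConnectorMetadata:
    # Shape mirrors the dataclass introduced above: per-connector metadata
    # plus the piggybacked async-save counters.
    metadata: tuple
    extra_async_saves: Optional[dict[str, int]] = None


class WorkerSide:
    def __init__(self) -> None:
        self._extra_async_saves: dict[str, int] = {}

    def bind_connector_metadata(self, md: MultiKVConnectorMetadata) -> None:
        # Merge the scheduler-side counters so the worker also knows how
        # many additional async saves remain for each request.
        if md.extra_async_saves:
            self._extra_async_saves.update(md.extra_async_saves)


# The scheduler side drains its counters into the metadata for one step:
md = MultiKVConnectorMetadata(metadata=(), extra_async_saves={"req-1": 1})
worker = WorkerSide()
worker.bind_connector_metadata(md)
assert worker._extra_async_saves == {"req-1": 1}
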
self._extra_async_saves: dict[str, int] = {} def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): @@ -58,7 +61,10 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): def bind_connector_metadata( self, connector_metadata: KVConnectorMetadata) -> None: assert isinstance(connector_metadata, MultiKVConnectorMetadata) - for c, cm in zip(self._connectors, connector_metadata): + if connector_metadata.extra_async_saves: + self._extra_async_saves.update( + connector_metadata.extra_async_saves) + for c, cm in zip(self._connectors, connector_metadata.metadata): c.bind_connector_metadata(cm) def clear_connector_metadata(self) -> None: @@ -144,8 +150,13 @@ def update_state_after_alloc(self, request: "Request", def build_connector_meta( self, scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata: - return MultiKVConnectorMetadata( - c.build_connector_meta(scheduler_output) for c in self._connectors) + metadata = MultiKVConnectorMetadata(metadata=tuple( + c.build_connector_meta(scheduler_output) + for c in self._connectors)) + if self._extra_async_saves: + metadata.extra_async_saves = self._extra_async_saves + self._extra_async_saves = {} + return metadata def request_finished( self, From 16bed574db1fd6272eaad1d593f80761bae49ee7 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 16 May 2025 07:52:53 -0700 Subject: [PATCH 27/28] [BugFix] Fix handling of num_computed_tokens with connector (#91) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [BugFix] Fix handling of num_computed_tokens with connector https://github.com/vllm-project/vllm/pull/18001 changed the behaviour subtly and broke some multi-connector cases. This change ensures we don't call the connector get_num_new_matched_tokens method a second time for a given request after an async load has completed. Signed-off-by: Nick Hill * fix linting Signed-off-by: Nick Hill * handle full cache hit on P/D decode worker case Signed-off-by: Nick Hill * fix comment wording Co-authored-by: Nicolò Lucchesi --------- Signed-off-by: Nick Hill Co-authored-by: Nicolò Lucchesi --- .../kv_connector/v1/nixl_connector.py | 16 ++++++---- vllm/v1/core/sched/scheduler.py | 29 ++++++++++++------- 2 files changed, 29 insertions(+), 16 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index abd1ea2bea82..086dbeb90f34 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -208,7 +208,17 @@ def get_num_new_matched_tokens( rounded_num_prompt_tokens = round_down( len(request.prompt_token_ids), self.block_size) count = max(rounded_num_prompt_tokens - num_computed_tokens, 0) - return count, count > 0 + if count > 0: + return count, True + + # NOTE: if count is 0 here, we have less than block_size + # tokens to pull after subtracting the local prefix cache hit. + # The remote only sends fully computed blocks, so there is + # nothing to transfer but we still need to notify the + # prefill worker so that the remote blocks are freed. + if all(p in params for p in ("remote_engine_id", "remote_host", + "remote_port")): + self._reqs_need_recv[request.request_id] = (request, []) # No remote prefill for this request. 
return 0, False @@ -224,10 +234,6 @@ def update_state_after_alloc(self, request: "Request", num_external_tokens, params) if params is not None and params.get("do_remote_prefill"): - # NOTE(rob): if prompt < block_size, no remote blocks - # since the remote only sends fully computed blocks, so - # skip recving for this request. num_external_tokens - # should be 0 if there are no remote blocks. if params.get("remote_block_ids"): if all(p in params for p in ("remote_engine_id", "remote_host", "remote_port")): diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index f338e4ba1440..69afc066439b 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -345,32 +345,38 @@ def schedule(self) -> SchedulerOutput: skipped_waiting_requests.appendleft(request) continue + num_external_computed_tokens = 0 + load_kv_async = False + # Get already-cached tokens. if num_prealloc_computed_tokens == 0: new_computed_blocks, num_native_computed_tokens = \ self.kv_cache_manager.get_computed_blocks( request) + + # Get externally-cached tokens if using a KVConnector. + if self.connector is not None: + num_external_computed_tokens, load_kv_async = ( + self.connector.get_num_new_matched_tokens( + request, num_native_computed_tokens)) + + # Total computed tokens (local + external). + num_computed_tokens = (num_native_computed_tokens + + num_external_computed_tokens) else: # P/D: skip checking prefix cache if loaded from remote kvs. new_computed_blocks = KVCacheBlocks.create_empty() num_native_computed_tokens = 0 - # Get externally-cached tokens if using a KVConnector. - num_external_computed_tokens, load_kv_async = ( - (0, False) if self.connector is None else - self.connector.get_num_new_matched_tokens( - request, num_native_computed_tokens)) - - # Total computed tokens (local + external). - num_computed_tokens = (num_native_computed_tokens + - num_external_computed_tokens + - num_prealloc_computed_tokens) + # Total computed tokens (allocated in prior step). + num_computed_tokens = num_prealloc_computed_tokens encoder_inputs_to_schedule = None new_encoder_budget = encoder_budget # P/D: loading remote KV, do not allocate for new work. if load_kv_async: + assert num_external_computed_tokens > 0 num_new_tokens = 0 # Number of tokens to be scheduled. else: @@ -411,7 +417,8 @@ def schedule(self) -> SchedulerOutput: # KVConnector: update internal state after allocation. # This information is used to determine if a load is # needed for this request. 
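
The hunk that follows gates update_state_after_alloc() on num_external_computed_tokens rather than on the mere presence of a connector. Combined with the hunk above, the corrected flow condenses to the sketch below; the function and attribute names are simplified stand-ins for the scheduler internals, not the real code.

def resolve_computed_tokens(request, connector, kv_cache_manager,
                            num_prealloc_computed_tokens: int):
    # First scheduling visit: consult the local prefix cache, then ask the
    # connector at most once for externally cached tokens.
    num_external_computed_tokens, load_kv_async = 0, False
    if num_prealloc_computed_tokens == 0:
        blocks, num_native = kv_cache_manager.get_computed_blocks(request)
        if connector is not None:
            num_external_computed_tokens, load_kv_async = (
                connector.get_num_new_matched_tokens(request, num_native))
        num_computed_tokens = num_native + num_external_computed_tokens
    else:
        # The async load already finished in a prior step: reuse the
        # recorded count and do not query the connector a second time.
        blocks = None  # stand-in for KVCacheBlocks.create_empty()
        num_computed_tokens = num_prealloc_computed_tokens
    return (blocks, num_computed_tokens,
            num_external_computed_tokens, load_kv_async)
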
- if self.connector is not None: + if num_external_computed_tokens: + assert self.connector is not None self.connector.update_state_after_alloc( request, new_computed_blocks + new_blocks, From db9a82dd8b7b85af3d0a6956bae92e1a5d986413 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Date: Sat, 17 May 2025 23:12:26 -0400 Subject: [PATCH 28/28] [Bugfix][P/D] Fix Preemption + Prefix Cache Bug (#92) * updated Signed-off-by: rshaw@neuralmagic.com * cleanup issue Signed-off-by: rshaw@neuralmagic.com --------- Signed-off-by: rshaw@neuralmagic.com --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 3 ++- vllm/v1/core/sched/scheduler.py | 7 +++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 086dbeb90f34..e7f1f2a3a6b7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -690,7 +690,8 @@ def _read_blocks( # just notify P worker that we have the blocks we need. num_local_blocks = len(local_block_ids) if num_local_blocks == 0: - self.nixl_wrapper.send_notif(dst_engine_id, + agent_name = self._remote_agents[dst_engine_id] + self.nixl_wrapper.send_notif(agent_name, notif_msg=request_id.encode("utf-8")) return diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 69afc066439b..a9d85e534115 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -411,6 +411,13 @@ def schedule(self) -> SchedulerOutput: delay_cache_blocks=load_kv_async, ) if new_blocks is None: + # P/D: if the request is recved on this step, + # then we need to free the kv cache blocks + if num_prealloc_computed_tokens > 0: + assert request.num_computed_tokens != 0 + self.kv_cache_manager.free(request) + request.num_computed_tokens = 0 + # The request cannot be scheduled. break
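
The scheduler-side half of this last patch handles preemption: if allocate_slots() fails for a request whose KV was already pulled from a remote prefill worker, the pre-allocated blocks must be released and the computed-token count rewound, otherwise a later retry could treat stale blocks as a prefix-cache hit. A minimal sketch of that branch, with hypothetical names for the surrounding scheduler state:

def handle_failed_allocation(request, kv_cache_manager,
                             num_prealloc_computed_tokens: int) -> None:
    # Condensed restatement of the new branch above; not the real
    # scheduler code.
    if num_prealloc_computed_tokens > 0:
        # Blocks were allocated in an earlier step while KV was received
        # asynchronously; free them and rewind so the next scheduling
        # attempt recomputes the prefix-cache state from scratch.
        assert request.num_computed_tokens != 0
        kv_cache_manager.free(request)
        request.num_computed_tokens = 0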