From 7de18d541b0da661685d481d7306cbe5e9f7960b Mon Sep 17 00:00:00 2001 From: TJian Date: Mon, 12 May 2025 00:14:30 +0800 Subject: [PATCH 01/24] [BUG] [ROCm] [MLA] Fix variable name bug due to change in variable name in PR #17483 (#17961) Signed-off-by: tjtanaa --- vllm/v1/attention/backends/mla/rocm_aiter_mla.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index f46010d757af..3abb185c5b8f 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -98,17 +98,17 @@ def _get_paged_kv_tensors( paged_kv_last_page_len, ) - def _build_decode(self, block_table: torch.Tensor, + def _build_decode(self, block_table_tensor: torch.Tensor, seq_lens: torch.Tensor) -> AiterMLADecodeMetadata: ( paged_kv_indices, paged_kv_indptr, paged_last_page_len, - ) = self._get_paged_kv_tensors(block_table, seq_lens) + ) = self._get_paged_kv_tensors(block_table_tensor, seq_lens) attn_metadata = AiterMLADecodeMetadata( - block_table=block_table, + block_table=block_table_tensor, seq_lens=seq_lens, paged_kv_indptr=paged_kv_indptr, paged_kv_indices=paged_kv_indices, From 021c16c7caaa6886248f1d048edbcdb678415964 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Mon, 12 May 2025 08:56:30 +0800 Subject: [PATCH 02/24] [Model] Broadcast Ovis2 implementation to fit Ovis1.6 (#17861) Signed-off-by: Isotr0py <2037008807@qq.com> --- docs/source/models/supported_models.md | 6 +- examples/offline_inference/vision_language.py | 21 +- .../vision_language_multi_image.py | 22 +- tests/conftest.py | 8 +- .../multimodal/generation/test_common.py | 27 +- .../generation/vlm_utils/model_utils.py | 17 +- .../multimodal/processing/test_common.py | 5 +- tests/models/registry.py | 6 +- vllm/entrypoints/chat_utils.py | 2 +- vllm/model_executor/models/aimv2.py | 127 +-------- .../models/{ovis2.py => ovis.py} | 240 +++++++++++++++--- vllm/model_executor/models/registry.py | 2 +- vllm/transformers_utils/configs/__init__.py | 2 +- .../configs/{ovis2.py => ovis.py} | 13 + .../transformers_utils/processors/__init__.py | 2 +- .../processors/{ovis2.py => ovis.py} | 42 ++- 16 files changed, 330 insertions(+), 212 deletions(-) rename vllm/model_executor/models/{ovis2.py => ovis.py} (59%) rename vllm/transformers_utils/configs/{ovis2.py => ovis.py} (93%) rename vllm/transformers_utils/processors/{ovis2.py => ovis.py} (94%) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 8c6e7b04de85..48fc24f3447a 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -1045,10 +1045,10 @@ Specified using `--task generate`. * * ✅︎ * ✅︎ -- * `Ovis2ForConditionalGeneration`^ - * Ovis2 +- * `Ovis` + * Ovis2, Ovis1.6 * T + I+ - * `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis2-2B`, etc. + * `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. 
* * * ✅︎ diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 5c173ab1abb9..c54f328c7a38 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -725,8 +725,8 @@ def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData: ) -# Ovis2 -def run_ovis2(questions: list[str], modality: str) -> ModelRequestData: +# Ovis +def run_ovis(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" model_name = "AIDC-AI/Ovis2-1B" @@ -737,15 +737,18 @@ def run_ovis2(questions: list[str], modality: str) -> ModelRequestData: max_num_seqs=2, trust_remote_code=True, dtype="half", - hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]}, limit_mm_per_prompt={modality: 1}, ) - placeholder = "\n" - prompts = [("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" - f"<|im_start|>user\n{placeholder}" - f"{question}<|im_end|>\n" - "<|im_start|>assistant\n") for question in questions] + tokenizer = AutoTokenizer.from_pretrained(model_name, + trust_remote_code=True) + messages = [[{ + 'role': 'user', + 'content': f"\n{question}" + }] for question in questions] + prompts = tokenizer.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) return ModelRequestData( engine_args=engine_args, @@ -1069,7 +1072,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData: "llama4": run_llama4, "molmo": run_molmo, "NVLM_D": run_nvlm_d, - "ovis2": run_ovis2, + "ovis": run_ovis, "paligemma": run_paligemma, "paligemma2": run_paligemma2, "phi3_v": run_phi3v, diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index 48d590b05b06..20a8e635e322 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -436,8 +436,8 @@ def load_nvlm_d(question: str, image_urls: list[str]) -> ModelRequestData: ) -# Ovis2 -def load_ovis2(question: str, image_urls: list[str]) -> ModelRequestData: +# Ovis +def load_ovis(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "AIDC-AI/Ovis2-1B" engine_args = EngineArgs( @@ -447,15 +447,17 @@ def load_ovis2(question: str, image_urls: list[str]) -> ModelRequestData: trust_remote_code=True, dtype="half", limit_mm_per_prompt={"image": len(image_urls)}, - hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]}, ) - placeholder = '\n'.join( - [f'Image {i+1}: ' for i in range(len(image_urls))]) + '\n' - prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" - f"<|im_start|>user\n{placeholder}" - f"{question}<|im_end|>\n" - "<|im_start|>assistant\n") + placeholders = "\n".join(f"Image-{i}: \n" + for i, _ in enumerate(image_urls, start=1)) + messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}] + + tokenizer = AutoTokenizer.from_pretrained(model_name, + trust_remote_code=True) + prompt = tokenizer.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) return ModelRequestData( engine_args=engine_args, @@ -713,7 +715,7 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData: "mistral3": load_mistral3, "mllama": load_mllama, "NVLM_D": load_nvlm_d, - "ovis2": load_ovis2, + "ovis": load_ovis, "phi3_v": load_phi3v, "phi4_mm": load_phi4mm, "pixtral_hf": load_pixtral_hf, diff --git a/tests/conftest.py b/tests/conftest.py index 
fa979f1093be..c5700179c228 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -355,10 +355,16 @@ def __init__( **model_kwargs, ) + # in case some unquantized custom models are not in same dtype + if (getattr(model, "quantization_method", None) is None + and any(p.dtype != self.dtype + for p in model.parameters())): + model = model.to(dtype=self.dtype) + if (getattr(model, "quantization_method", None) != "bitsandbytes" and len({p.device for p in model.parameters()}) < 2): - model = model.to(self.device) + model = model.to(device=self.device) self.model = model diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index 6e915a9f6005..dead2edc4fa3 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -476,6 +476,31 @@ max_num_seqs=2, patch_hf_runner=model_utils.molmo_patch_hf_runner, ), + "ovis1_6-gemma2": VLMTestInfo( + models=["AIDC-AI/Ovis1.6-Gemma2-9B"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"user\n{img_prompt}\nmodel\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "\n", # noqa: E501 + max_model_len=4096, + max_num_seqs=2, + dtype="half", + # use sdpa mode for hf runner since ovis2 didn't work with flash_attn + hf_model_kwargs={"llm_attn_implementation": "sdpa"}, + patch_hf_runner=model_utils.ovis_patch_hf_runner, + marks=[large_gpu_mark(min_gb=32)], + ), + "ovis1_6": VLMTestInfo( + models=["AIDC-AI/Ovis1.6-Llama3.2-3B"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful and honest multimodal assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "\n", # noqa: E501 + max_model_len=4096, + max_num_seqs=2, + dtype="half", + # use sdpa mode for hf runner since ovis2 didn't work with flash_attn + hf_model_kwargs={"llm_attn_implementation": "sdpa"}, + patch_hf_runner=model_utils.ovis_patch_hf_runner, + ), "ovis2": VLMTestInfo( models=["AIDC-AI/Ovis2-1B"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), @@ -486,7 +511,7 @@ dtype="half", # use sdpa mode for hf runner since ovis2 didn't work with flash_attn hf_model_kwargs={"llm_attn_implementation": "sdpa"}, - patch_hf_runner=model_utils.ovis2_patch_hf_runner, + patch_hf_runner=model_utils.ovis_patch_hf_runner, ), "phi3v": VLMTestInfo( models=["microsoft/Phi-3.5-vision-instruct"], diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py index f0f4ed989241..e31408d6063f 100644 --- a/tests/models/multimodal/generation/vlm_utils/model_utils.py +++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py @@ -678,12 +678,8 @@ def _generate(self, max_new_tokens=None, do_sample=None, **kwargs): return hf_model -def ovis2_patch_hf_runner(hf_model: HfRunner) -> HfRunner: +def ovis_patch_hf_runner(hf_model: HfRunner) -> HfRunner: """Patches and returns an instance of the HfRunner to use for Ovis2.""" - hf_model.model.visual_tokenizer.to(hf_model.dtype) - hf_model.model.vte.to(hf_model.dtype) - hf_model.model.llm.to(hf_model.dtype) - hf_model.model.get_output_embeddings = lambda: \ hf_model.model.llm.get_output_embeddings() @@ -691,7 +687,16 @@ def processor(*args, text="", images=None, **kwargs): text_tokenizer = 
hf_model.model.get_text_tokenizer() images = [images] if isinstance(images, Image) else images - text = text.split("<|im_start|>user\n")[1].split("<|im_end|>\n")[0] + prompt_start_and_end = { + "qwen2": ("<|im_start|>user\n", "<|im_end|>\n"), + "llama": + ("<|start_header_id|>user<|end_header_id|>\n\n", "<|eot_id|>"), + "gemma2": ("user\n", "\n"), + } + for start, end in prompt_start_and_end.values(): + if start in text and end in text: + text = text.split(start)[1].split(end)[0] + break prompt, input_ids, pixel_values = hf_model.model.preprocess_inputs( text_or_conversations=text, images=images) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 772a2db3e48a..e6b70a4438e9 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -146,7 +146,8 @@ def _test_processing_correctness_hf( batch_idx: int, ignore_mm_keys: Optional[set[str]] = None, ): - if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"): + if model_config.hf_config.model_type in ("mllama", "ovis", "ultravox", + "whisper"): # For some multimodal models, tokenizer will always add bos_token # at the beginning of prompt by default, causing hf_processor outputs # incorrect token ids. So we need use `add_special_tokens=False` here @@ -274,6 +275,8 @@ def _test_processing_correctness_mistral( "allenai/Molmo-7B-D-0924", "allenai/Molmo-7B-O-0924", "nvidia/NVLM-D-72B", + "AIDC-AI/Ovis1.6-Gemma2-9B", + "AIDC-AI/Ovis1.6-Llama3.2-3B", "AIDC-AI/Ovis2-1B", "google/paligemma-3b-mix-224", "google/paligemma2-3b-ft-docci-448", diff --git a/tests/models/registry.py b/tests/models/registry.py index a1f2edac02b9..683d15d508ec 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -355,9 +355,9 @@ def check_available_online( max_transformers_version="4.48", transformers_version_reason="Use of deprecated imports which have been removed.", # noqa: E501 extras={"phi3.5": "microsoft/Phi-3.5-vision-instruct"}), # noqa: E501 - "Ovis2ForConditionalGeneration": _HfExamplesInfo("AIDC-AI/Ovis2-1B", - trust_remote_code=True, - hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]}), # noqa: E501 + "Ovis": _HfExamplesInfo("AIDC-AI/Ovis2-1B", trust_remote_code=True, + extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B", + "1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}), # noqa: E501 "Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct", trust_remote_code=True), "PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409", # noqa: E501 diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 38fe98572178..db43b2dd295d 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -512,7 +512,7 @@ def _placeholder_str(self, modality: ModalityStr, hf_config.image_token_index) if model_type in ("aya_vision", "chameleon", "deepseek_vl_v2", - "internvl_chat", "ovis2", "skywork_chat", + "internvl_chat", "ovis", "skywork_chat", "NVLM_D", "h2ovl_chat", "idefics3", "smolvlm"): return "" if model_type in ("mllama", "llama4"): diff --git a/vllm/model_executor/models/aimv2.py b/vllm/model_executor/models/aimv2.py index 730e770dc3d6..aefd6c973755 100644 --- a/vllm/model_executor/models/aimv2.py +++ b/vllm/model_executor/models/aimv2.py @@ -5,129 +5,14 @@ from typing import Optional import torch -from torch import nn, softmax +import torch.nn as nn from torch.nn import functional as F -from torch.nn.functional import 
gumbel_softmax, pad from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.transformers_utils.configs.ovis2 import (AIMv2Config, - Aimv2VisualTokenizerConfig) - -IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, - -305] # kept for vocab prefixed tokens - - -def st_argmax(y_soft: torch.Tensor, dim: int): # straight-through softmax - index = y_soft.max(dim, keepdim=True)[1] - y_hard = torch.zeros_like( - y_soft, memory_format=torch.legacy_contiguous_format).scatter_( - dim, index, 1.0) - ret = y_hard - y_soft.detach() + y_soft - return ret - - -class Aimv2VisualTokenizer(torch.nn.Module): - - def __init__(self, - config: Aimv2VisualTokenizerConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - **kwargs): - super().__init__() - self.config = config - self.backbone = AIMv2Model( - config=config.backbone_config, # noqa - quant_config=quant_config, - prefix=f"{prefix}.visual_tokenizer") - # reserved tokens for IMAGE_INDICATORS - head_dim = config.vocab_size - len(IMAGE_INDICATOR_IDS) - self.head = torch.nn.Sequential( - ReplicatedLinear( - config.backbone_config.hidden_size * config.hidden_stride * - config.hidden_stride, - head_dim, - bias=False, - ), torch.nn.LayerNorm(head_dim)) - - @property - def dtype(self): - return self.backbone.dtype - - @property - def device(self): - return self.backbone.device - - def tokenize(self, logits): - if self.config.tokenize_function == 'softmax': - tokens = softmax(logits, dim=-1) - elif self.config.tokenize_function == 'gumbel_argmax': - tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True) - elif self.config.tokenize_function == 'st_argmax': - tokens = st_argmax(logits, dim=-1) - else: - raise ValueError( - 'Invalid `max_type`, expected softmax or gumbel_argmax ' - f'or st_argmax, but got {self.config.tokenize_function}') - return tokens - - def encode(self, pixel_values): - features = self.backbone(pixel_values) - if self.config.drop_cls_token: - features = features[:, 1:, :] - - # merge number of `hidden_stride * hidden_stride` hidden states together - # to reduce token sequence length - # e.g., for hidden_stride=2, this leads to a token length reduction: - # 1024 -> 256 for aimv2 - if self.config.hidden_stride > 1: - # this `d` maybe different from the above `d`` - n, L, d = features.shape - sqrt_l = int(L**0.5) - assert sqrt_l**2 == L, ( - "The token sequence length should be a perfect square.") - features = features.reshape(n, sqrt_l, sqrt_l, d) - pl = (self.config.hidden_stride - - (sqrt_l % - self.config.hidden_stride)) % self.config.hidden_stride - features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0) - sqrt_l += pl - features = features.reshape(n, sqrt_l // self.config.hidden_stride, - self.config.hidden_stride, - sqrt_l // self.config.hidden_stride, - self.config.hidden_stride, d) - # [n, sqrt_l/hs, sqrt_l/hs, hs, hs, d] - features = features.permute(0, 1, 3, 2, 4, 5) - # [n, sqrt_l/hs, sqrt_l/hs, hs*hs*d] - features = features.flatten(3) - # [n, sqrt_l/hs*sqrt_l/hs, hs*hs*d] - features = features.reshape( - n, -1, - self.config.hidden_stride * self.config.hidden_stride * d) - - return features - - def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: - """[BatchSize, ImageShape] -> [BatchSize, Token, VocabSize]""" - features = self.encode(pixel_values) - logits, _ = self.head[0]( - features) # we spllit the sequncial here for not 
throwing an error - logits = self.head[1](logits) - tokens = self.tokenize(logits) - # tokens' shape is [BatchSize, #Token, VocabSize-5], so padding with - # [BatchSize, #Token, 5], after which, tokens' shape should become - # [BatchSize, #Token, VocabSize] - batch_size, token_len, _ = tokens.shape - padding_tensor = torch.zeros(size=(batch_size, token_len, - len(IMAGE_INDICATOR_IDS)), - dtype=tokens.dtype, - device=tokens.device, - layout=tokens.layout, - requires_grad=False) - tokens = torch.cat((tokens, padding_tensor), dim=2) - return tokens +from vllm.transformers_utils.configs.ovis import AIMv2Config class AIMv2SwiGLUFFN(nn.Module): @@ -302,14 +187,6 @@ def __init__(self, quant_config=quant_config, prefix=f"{prefix}.trunk") - @property - def dtype(self): - return self.trunk.blocks[0].attn.qkv.weight.dtype - - @property - def device(self): - return self.trunk.blocks[0].attn.qkv.device - def forward( self, pixel_values: torch.Tensor, diff --git a/vllm/model_executor/models/ovis2.py b/vllm/model_executor/models/ovis.py similarity index 59% rename from vllm/model_executor/models/ovis2.py rename to vllm/model_executor/models/ovis.py index 67cc86e7fc82..5204c751216f 100644 --- a/vllm/model_executor/models/ovis2.py +++ b/vllm/model_executor/models/ovis.py @@ -15,17 +15,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch Ovis2 model.""" +""" PyTorch Ovis model.""" +import math from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) import torch import torch.nn as nn from torch import Tensor -from transformers import BatchFeature +from torch.nn.functional import gumbel_softmax, pad, softmax +from transformers import BaseImageProcessor, BatchFeature from vllm.config import VllmConfig -from vllm.model_executor.models.aimv2 import Aimv2VisualTokenizer +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.models.aimv2 import AIMv2Model +from vllm.model_executor.models.siglip import SiglipVisionModel from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix) @@ -38,19 +44,160 @@ BaseProcessingInfo, PromptReplacement) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.configs.ovis2 import OvisConfig -from vllm.transformers_utils.processors.ovis2 import OvisProcessor +from vllm.transformers_utils.configs.ovis import (BaseVisualTokenizerConfig, + OvisConfig) +from vllm.transformers_utils.processors.ovis import OvisProcessor from .interfaces import MultiModalEmbeddings, SupportsMultiModal from .utils import merge_multimodal_embeddings # Cannot find the following number from hf config. 
IMAGE_TOKEN = "" -IMAGE_PAD_TOKEN_ID = 151655 -NUMBER_OF_TOKEN_TO_RESERVE_FOR_SEGMENT = 256 +IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305] +IMAGE_PAD_TOKEN_MAP = { + "gemma2": "", + "llama": "<|reserved_special_token_0|>", + "qwen2": "<|image_pad|>", +} +IMAGE_PAD_TOKEN_ID_MAP = { + "gemma2": 7, + "llama": 128002, + "qwen2": 151655, +} -class Ovis2ImagePatchInputs(TypedDict): + +def st_argmax(y_soft: torch.Tensor, dim: int): # straight-through softmax + index = y_soft.argmax(dim, keepdim=True) + return torch.zeros_like( + y_soft, + memory_format=torch.legacy_contiguous_format, + ).scatter_(dim, index, 1.0) + + +class VisualTokenizer(torch.nn.Module): + + def __init__( + self, + config: BaseVisualTokenizerConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.backbone = self._init_backbone( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.backbone", + ) + # reserved tokens for IMAGE_INDICATORS + head_dim = config.vocab_size - len(IMAGE_INDICATOR_IDS) + self.head = torch.nn.Sequential( + ReplicatedLinear( + config.backbone_config.hidden_size * config.hidden_stride * + config.hidden_stride, + head_dim, + bias=False, + return_bias=False, + ), torch.nn.LayerNorm(head_dim)) + + def _init_backbone( + self, + config: BaseVisualTokenizerConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + model_type = config.backbone_config.model_type + if model_type == "aimv2": + return AIMv2Model( + config=config.backbone_config, + quant_config=quant_config, + prefix=prefix, + ) + elif model_type == "siglip_vision_model": + return SiglipVisionModel( + config=config.backbone_config, + quant_config=quant_config, + prefix=prefix, + ) + raise ValueError( + f"Unsupported visual tokenizer model_type: {model_type}") + + @property + def dtype(self): + return next(self.head.parameters()).dtype + + @property + def device(self): + return next(self.head.parameters()).device + + def tokenize(self, logits): + if self.config.tokenize_function == 'softmax': + tokens = softmax(logits, dim=-1) + elif self.config.tokenize_function == 'gumbel_argmax': + tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True) + elif self.config.tokenize_function == 'st_argmax': + tokens = st_argmax(logits, dim=-1) + else: + raise ValueError( + 'Invalid `max_type`, expected softmax or gumbel_argmax ' + f'or st_argmax, but got {self.config.tokenize_function}') + return tokens + + def encode(self, pixel_values): + features = self.backbone(pixel_values) + if self.config.drop_cls_token: + features = features[:, 1:, :] + + # merge number of `hidden_stride * hidden_stride` hidden states together + # to reduce token sequence length + # e.g., for hidden_stride=2, this leads to a token length reduction: + # 1024 -> 256 for aimv2 + if self.config.hidden_stride > 1: + # this `d` maybe different from the above `d`` + n, L, d = features.shape + sqrt_l = int(L**0.5) + assert sqrt_l**2 == L, ( + "The token sequence length should be a perfect square.") + features = features.reshape(n, sqrt_l, sqrt_l, d) + pl = (self.config.hidden_stride - + (sqrt_l % + self.config.hidden_stride)) % self.config.hidden_stride + features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0) + sqrt_l += pl + features = features.reshape(n, sqrt_l // self.config.hidden_stride, + self.config.hidden_stride, + sqrt_l // self.config.hidden_stride, + self.config.hidden_stride, d) + # [n, sqrt_l/hs, sqrt_l/hs, hs, hs, d] + features = 
features.permute(0, 1, 3, 2, 4, 5) + # [n, sqrt_l/hs, sqrt_l/hs, hs*hs*d] + features = features.flatten(3) + # [n, sqrt_l/hs*sqrt_l/hs, hs*hs*d] + features = features.reshape( + n, -1, + self.config.hidden_stride * self.config.hidden_stride * d) + + return features + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + """[BatchSize, ImageShape] -> [BatchSize, Token, VocabSize]""" + features = self.encode(pixel_values) + logits = self.head(features) + tokens = self.tokenize(logits) + # tokens' shape is [BatchSize, #Token, VocabSize-5], so padding with + # [BatchSize, #Token, 5], after which, tokens' shape should become + # [BatchSize, #Token, VocabSize] + tokens = torch.nn.functional.pad( + tokens, + (0, len(IMAGE_INDICATOR_IDS)), + mode="constant", + value=0, + ) + return tokens + + +class OvisImagePatchInputs(TypedDict): type: Literal["image_patches"] flat_data: torch.Tensor """ @@ -92,31 +239,50 @@ def dtype(self): return self.weight.dtype -class Ovis2ProcessingInfo(BaseProcessingInfo): +class OvisProcessingInfo(BaseProcessingInfo): def get_hf_config(self): return self.ctx.get_hf_config(OvisConfig) def get_hf_processor(self, **kwargs): - return self.ctx.get_hf_processor(OvisProcessor) + return self.ctx.get_hf_processor( + OvisProcessor, + image_pad_token=self.get_image_pad_token(), + image_segment_len=self.get_image_segment_len(), + ) - def get_image_processor(self) -> OvisProcessor: + def get_image_segment_len(self) -> int: + visual_tokenizer_config = self.get_hf_config().visual_tokenizer_config + image_size = visual_tokenizer_config.backbone_config.image_size + patch_size = visual_tokenizer_config.backbone_config.patch_size + hidden_stride = visual_tokenizer_config.hidden_stride + patch_grid_length = math.ceil(image_size / patch_size) + assert patch_grid_length % hidden_stride == 0, ( + f"patch_grid_length {patch_grid_length} is not divisible by " + f"hidden_stride {hidden_stride}") + # minus 1 for presented image token + return (patch_grid_length // hidden_stride)**2 - 1 + + def get_image_pad_token(self) -> str: + hf_text_config = self.get_hf_config().get_text_config() + text_model_type = hf_text_config.model_type + return IMAGE_PAD_TOKEN_MAP.get(text_model_type) + + def get_image_processor(self) -> BaseImageProcessor: return self.get_hf_processor().image_processor # type: ignore def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return { # 32k is model token limit at the moment - "image": - self.get_hf_config().multimodal_max_length // - ((9 + 1) * NUMBER_OF_TOKEN_TO_RESERVE_FOR_SEGMENT) - } + return {"image": None} def get_image_size_with_most_features(self) -> ImageSize: - image_processor = self.get_image_processor() - return ImageSize(width=image_processor.size['shortest_edge'] * 9 * 2, - height=image_processor.size['shortest_edge'] * 9 * 2) + height, width = self.get_hf_processor().get_image_size() + hs = self.get_hf_config().visual_tokenizer_config.hidden_stride + # NOTE(Isotr0py): 9 is `max_partion` hardcoded in original code + # https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/modeling_ovis.py#L96 + return ImageSize(width=width * hs * 9, height=height * hs * 9) -class Ovis2DummyInputsBuilder(BaseDummyInputsBuilder[Ovis2ProcessingInfo]): +class OvisDummyInputsBuilder(BaseDummyInputsBuilder[OvisProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -141,7 +307,7 @@ def get_dummy_mm_data( return mm_data -class Ovis2MultiModalProcessor(BaseMultiModalProcessor[Ovis2ProcessingInfo]): 
+class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]): def image_indicators_to_visual_tokens( self, @@ -165,9 +331,9 @@ def _call_hf_processor( mm_kwargs: Mapping[str, object], ) -> BatchFeature: if not mm_data: - # # Avoid warning from HF logger for text-only input - prompt_ids = self.info.get_tokenizer().encode(prompt) - # prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) nope + # Avoid warning from HF logger for text-only input + tokenizer = self.info.get_tokenizer() + prompt_ids = tokenizer.encode(prompt, add_special_tokens=False) return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") processed_outputs = super()._call_hf_processor( @@ -226,10 +392,10 @@ def get_replacement_ovis(item_idx): ] -@MULTIMODAL_REGISTRY.register_processor(Ovis2MultiModalProcessor, - info=Ovis2ProcessingInfo, - dummy_inputs=Ovis2DummyInputsBuilder) -class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal): +@MULTIMODAL_REGISTRY.register_processor(OvisMultiModalProcessor, + info=OvisProcessingInfo, + dummy_inputs=OvisDummyInputsBuilder) +class Ovis(nn.Module, SupportsMultiModal): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -242,24 +408,25 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix(prefix, "llm"), ) - self.visual_tokenizer = Aimv2VisualTokenizer( + self.visual_tokenizer = VisualTokenizer( config=config.visual_tokenizer_config, quant_config=quant_config, prefix=f"{prefix}.visual_tokenizer", - image_processor_name_or_path=config.visual_tokenizer_config. - backbone_config.name_or_path, ) self.vte = VisualEmbedding( self.config.visual_tokenizer_config.vocab_size, self.config.hidden_size) + text_model_type = self.config.get_text_config().model_type + self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type] + # TODO(Isotr0py): PP support # self.make_empty_intermediate_tensors = ( # self.language_model.make_empty_intermediate_tensors) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Ovis2ImagePatchInputs]: + self, **kwargs: object) -> Optional[OvisImagePatchInputs]: pixel_values = kwargs.pop("pixel_values", None) indicator_tokens = kwargs.pop("indicator_tokens", None) @@ -275,7 +442,7 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of indicator_tokens. 
" f"Got type: {type(pixel_values)}") - return Ovis2ImagePatchInputs( + return OvisImagePatchInputs( type="image_patches", flat_data=flatten_bn(flatten_bn(pixel_values), concat=True), patches_per_image=[ @@ -288,7 +455,7 @@ def _parse_and_validate_image_input( raise AssertionError("This line should be unreachable.") def _process_image_input( - self, image_input: Ovis2ImagePatchInputs) -> MultiModalEmbeddings: + self, image_input: OvisImagePatchInputs) -> MultiModalEmbeddings: image_patches_flat = image_input["flat_data"] patches_per_image = image_input["patches_per_image"] indicator_tokens = image_input["indicator_tokens"] @@ -338,7 +505,7 @@ def get_input_embeddings( if multimodal_embeddings is not None: inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, multimodal_embeddings, - [IMAGE_PAD_TOKEN_ID]) + self.image_pad_token_id) return inputs_embeds def forward( @@ -375,8 +542,7 @@ def compute_logits( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.llm.logits_processor(self.llm.lm_head, hidden_states, - sampling_metadata) + logits = self.llm.compute_logits(hidden_states, sampling_metadata) return logits def load_weights(self, weights: Iterable[Tuple[str, diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index aef4566193c8..c5414e129dd1 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -195,7 +195,7 @@ "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"), # noqa: E501 "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"), "NVLM_D": ("nvlm_d", "NVLM_D_Model"), - "Ovis2ForConditionalGeneration": ("ovis2", "Ovis2ForConditionalGeneration"), + "Ovis": ("ovis", "Ovis"), "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501 diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index db3efafeef96..ed10c22c84f0 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -23,7 +23,7 @@ from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.nemotron import NemotronConfig from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config -from vllm.transformers_utils.configs.ovis2 import OvisConfig +from vllm.transformers_utils.configs.ovis import OvisConfig from vllm.transformers_utils.configs.skyworkr1v import SkyworkR1VChatConfig from vllm.transformers_utils.configs.solar import SolarConfig from vllm.transformers_utils.configs.telechat2 import Telechat2Config diff --git a/vllm/transformers_utils/configs/ovis2.py b/vllm/transformers_utils/configs/ovis.py similarity index 93% rename from vllm/transformers_utils/configs/ovis2.py rename to vllm/transformers_utils/configs/ovis.py index 437a16e778c2..0ec224214f06 100644 --- a/vllm/transformers_utils/configs/ovis2.py +++ b/vllm/transformers_utils/configs/ovis.py @@ -123,6 +123,19 @@ def __init__(self, **kwargs): self.backbone_kwargs['num_hidden_layers'] = self.depths[0] +class SiglipVisualTokenizerConfig(BaseVisualTokenizerConfig): + model_type = "siglip_visual_tokenizer" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if self.drop_cls_token: + self.drop_cls_token = False + if self.depths: + assert len(self.depths) 
== 1 + self.backbone_kwargs['num_hidden_layers'] = self.depths[0] + + +AutoConfig.register("siglip_visual_tokenizer", SiglipVisualTokenizerConfig) AutoConfig.register("aimv2_visual_tokenizer", Aimv2VisualTokenizerConfig) diff --git a/vllm/transformers_utils/processors/__init__.py b/vllm/transformers_utils/processors/__init__.py index 2e9cf3e4d90b..2bd9ab1f099b 100644 --- a/vllm/transformers_utils/processors/__init__.py +++ b/vllm/transformers_utils/processors/__init__.py @@ -2,6 +2,6 @@ from vllm.transformers_utils.processors.deepseek_vl2 import ( DeepseekVLV2Processor) -from vllm.transformers_utils.processors.ovis2 import OvisProcessor +from vllm.transformers_utils.processors.ovis import OvisProcessor __all__ = ["DeepseekVLV2Processor", "OvisProcessor"] diff --git a/vllm/transformers_utils/processors/ovis2.py b/vllm/transformers_utils/processors/ovis.py similarity index 94% rename from vllm/transformers_utils/processors/ovis2.py rename to vllm/transformers_utils/processors/ovis.py index a633256ec12c..48e786792cf5 100644 --- a/vllm/transformers_utils/processors/ovis2.py +++ b/vllm/transformers_utils/processors/ovis.py @@ -22,6 +22,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import List, Union import PIL @@ -32,7 +33,7 @@ Unpack) from transformers.tokenization_utils_base import PreTokenizedInput, TextInput -__all__ = [ 'OvisProcessor'] +__all__ = ['OvisProcessor'] IGNORE_ID = -100 class OvisProcessorKwargs(ProcessingKwargs, total=False): # type: ignore[call-arg] @@ -64,18 +65,29 @@ class OvisProcessor(ProcessorMixin): """ attributes = ["image_processor", "tokenizer"] - valid_kwargs = ["chat_template"] + valid_kwargs = ["chat_template", "image_pad_token", "image_segement_len"] image_processor_class = "AutoImageProcessor" - tokenizer_class = "Qwen2Tokenizer" + tokenizer_class = "AutoTokenizer" - def __init__(self, image_processor=None, tokenizer=None, chat_template=None, image_pad_token=None, **kwargs): + def __init__( + self, + image_processor=None, + tokenizer=None, + chat_template=None, + image_pad_token=None, + image_segment_len=255, + **kwargs, + ): self.image_token = "" - self.image_pad_token = "<|image_pad|>" if image_pad_token is None else image_pad_token + self.image_pad_token = image_pad_token + self.image_segment_len = image_segment_len super().__init__(image_processor, tokenizer, chat_template=chat_template) - self.image_pad_token_id = self.tokenizer.get_vocab()[self.image_pad_token] - self.extra_special_tokens = { + @cached_property + def extra_special_tokens(self): + image_pad_token_id = self.tokenizer.get_vocab()[self.image_pad_token] + extra_special_tokens = { "image_token": -200, "image_atom": -300, "image_start": -301, @@ -83,8 +95,9 @@ def __init__(self, image_processor=None, tokenizer=None, chat_template=None, ima "image_col_sep": -303, "image_row_sep": -304, "image_end": -305, - 'image_pad': self.image_pad_token_id, + 'image_pad': image_pad_token_id, } + return extra_special_tokens def __call__( self, @@ -224,8 +237,14 @@ def _tokenize_with_image_symbol(self, text_list: list[str]) -> torch.LongTensor: return torch.tensor(batch_token_ids, dtype=torch.long) def get_image_size(self): - height = self.image_processor.crop_size["height"] - width = self.image_processor.crop_size["width"] + size = self.image_processor.size + if 'shortest_edge' in size: + width = height = size['shortest_edge'] 
+ elif "height" in size and "width" in size: + width = size['width'] + height = size['height'] + else: + raise ValueError( "Can't parse image size from image_processor config.") return height, width def get_token_value(self, tok): @@ -259,8 +278,7 @@ def construct_image_placeholders(self, grid): for token in image_placeholders: padded_placeholder_tokens.append(image_padding_token_id) if token == image_atom_token_id: - # Add 255 padding tokens after each image atom token - padded_placeholder_tokens.extend([image_padding_token_id] * 255) + padded_placeholder_tokens.extend([image_padding_token_id] * self.image_segment_len) return padded_placeholder_tokens def preprocess_image(self, image: PIL.Image.Image, max_partition, covering_threshold, convert_to_rgb, return_tensors): From d45fe333fb8d3ab73d73a6458e3cde73f14f0d7e Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 12 May 2025 09:02:39 +0800 Subject: [PATCH 03/24] [misc] add instructions on how to install nvshmem/pplx/deepep (#17964) Signed-off-by: youkaichao --- tools/ep_kernels/README.md | 27 +++++++ tools/ep_kernels/install_python_libraries.sh | 77 ++++++++++++++++++++ tools/ep_kernels/install_system_drivers.sh | 24 ++++++ tools/ep_kernels/install_system_libraries.sh | 18 +++++ 4 files changed, 146 insertions(+) create mode 100644 tools/ep_kernels/README.md create mode 100644 tools/ep_kernels/install_python_libraries.sh create mode 100644 tools/ep_kernels/install_system_drivers.sh create mode 100644 tools/ep_kernels/install_system_libraries.sh diff --git a/tools/ep_kernels/README.md b/tools/ep_kernels/README.md new file mode 100644 index 000000000000..5c98e999da33 --- /dev/null +++ b/tools/ep_kernels/README.md @@ -0,0 +1,27 @@ +Large-scale cluster-level expert parallel, as described in the [DeepSeek-V3 Technical Report](http://arxiv.org/abs/2412.19437), is an efficient way to deploy sparse MoE models with many experts. However, such deployment requires many components beyond a normal Python package, including system package support and system driver support. It is impossible to bundle all these components into a Python package. + +Here we break down the requirements in 3 steps: +1. Build and install the Python libraries (both [pplx-kernels](https://github.com/ppl-ai/pplx-kernels) and [DeepEP](https://github.com/deepseek-ai/DeepEP)), including necessary dependencies like NVSHMEM. This step does not require any privileged access. Any user can do this. +2. Build and install the system libraries (GDR Copy). This step requires root access. You can do it inside a Docker container so that they can be shipped as a single image. +3. Build and install the system drivers (GDR Copy, and necessary modifications to NVIDIA driver to enable IBGDA). This step requires root access, and must be done on the host machine. + +2 and 3 are necessary for multi-node deployment. + +All scripts accept a positional argument as workspace path for staging the build, defaulting to `$(pwd)/ep_kernels_workspace`. 
+ +# Usage + +## Single-node + +```bash +bash install_python_libraries.sh +``` + +## Multi-node + +```bash +bash install_python_libraries.sh +sudo bash install_system_libraries.sh +sudo bash install_system_drivers.sh +sudo reboot # Reboot is required to load the new driver +``` diff --git a/tools/ep_kernels/install_python_libraries.sh b/tools/ep_kernels/install_python_libraries.sh new file mode 100644 index 000000000000..e5632f4b5875 --- /dev/null +++ b/tools/ep_kernels/install_python_libraries.sh @@ -0,0 +1,77 @@ +set -ex + +# prepare workspace directory +WORKSPACE=$1 +if [ -z "$WORKSPACE" ]; then + export WORKSPACE=$(pwd)/ep_kernels_workspace +fi + +if [ ! -d "$WORKSPACE" ]; then + mkdir -p $WORKSPACE +fi + +# install dependencies if not installed +pip3 install cmake torch ninja + +# build gdrcopy, required by nvshmem +pushd $WORKSPACE +wget https://github.com/NVIDIA/gdrcopy/archive/refs/tags/v2.4.4.tar.gz +mkdir -p gdrcopy_src +tar -xvf v2.4.4.tar.gz -C gdrcopy_src --strip-components=1 +pushd gdrcopy_src +make -j$(nproc) +make prefix=$WORKSPACE/gdrcopy_install install +popd + +# build nvshmem +pushd $WORKSPACE +mkdir -p nvshmem_src +wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.2.5/source/nvshmem_src_3.2.5-1.txz +tar -xvf nvshmem_src_3.2.5-1.txz -C nvshmem_src --strip-components=1 +pushd nvshmem_src +wget https://github.com/deepseek-ai/DeepEP/raw/main/third-party/nvshmem.patch +git init +git apply -vvv nvshmem.patch + +# assume CUDA_HOME is set correctly +export GDRCOPY_HOME=$WORKSPACE/gdrcopy_install +export NVSHMEM_SHMEM_SUPPORT=0 +export NVSHMEM_UCX_SUPPORT=0 +export NVSHMEM_USE_NCCL=0 +export NVSHMEM_IBGDA_SUPPORT=1 +export NVSHMEM_PMIX_SUPPORT=0 +export NVSHMEM_TIMEOUT_DEVICE_POLLING=0 +export NVSHMEM_USE_GDRCOPY=1 +export NVSHMEM_IBRC_SUPPORT=1 + +# remove MPI dependency +export NVSHMEM_BUILD_TESTS=0 +export NVSHMEM_BUILD_EXAMPLES=0 +export NVSHMEM_MPI_SUPPORT=0 + +cmake -S . -B $WORKSPACE/nvshmem_build/ -DCMAKE_INSTALL_PREFIX=$WORKSPACE/nvshmem_install + +cd $WORKSPACE/nvshmem_build/ +make -j$(nproc) +make install + +popd + +export CMAKE_PREFIX_PATH=$WORKSPACE/nvshmem_install:$CMAKE_PREFIX_PATH + +# build and install pplx, require pytorch installed +pushd $WORKSPACE +git clone https://github.com/ppl-ai/pplx-kernels +cd pplx-kernels +# see https://github.com/pypa/pip/issues/9955#issuecomment-838065925 +# PIP_NO_BUILD_ISOLATION=0 disables build isolation +PIP_NO_BUILD_ISOLATION=0 TORCH_CUDA_ARCH_LIST=9.0a+PTX pip install -vvv -e . +popd + +# build and install deepep, require pytorch installed +pushd $WORKSPACE +git clone https://github.com/deepseek-ai/DeepEP +cd DeepEP +export NVSHMEM_DIR=$WORKSPACE/nvshmem_install +PIP_NO_BUILD_ISOLATION=0 pip install -vvv -e . +popd diff --git a/tools/ep_kernels/install_system_drivers.sh b/tools/ep_kernels/install_system_drivers.sh new file mode 100644 index 000000000000..8b0669ef404f --- /dev/null +++ b/tools/ep_kernels/install_system_drivers.sh @@ -0,0 +1,24 @@ +set -ex + +# prepare workspace directory +WORKSPACE=$1 +if [ -z "$WORKSPACE" ]; then + export WORKSPACE=$(pwd)/ep_kernels_workspace +fi + +if [ ! 
-d "$WORKSPACE" ]; then + mkdir -p $WORKSPACE +fi + +# build and install gdrcopy driver +pushd $WORKSPACE +cd gdrcopy_src +./insmod.sh +# run gdrcopy_copybw to test the installation +$WORKSPACE/gdrcopy_install/bin/gdrcopy_copybw + +# turn on IBGDA +echo 'options nvidia NVreg_EnableStreamMemOPs=1 NVreg_RegistryDwords="PeerMappingOverride=1;"' | tee -a /etc/modprobe.d/nvidia.conf +update-initramfs -u + +echo "Please reboot the system to apply the changes" diff --git a/tools/ep_kernels/install_system_libraries.sh b/tools/ep_kernels/install_system_libraries.sh new file mode 100644 index 000000000000..c148d5443900 --- /dev/null +++ b/tools/ep_kernels/install_system_libraries.sh @@ -0,0 +1,18 @@ +set -ex + +# prepare workspace directory +WORKSPACE=$1 +if [ -z "$WORKSPACE" ]; then + export WORKSPACE=$(pwd)/ep_kernels_workspace +fi + +if [ ! -d "$WORKSPACE" ]; then + mkdir -p $WORKSPACE +fi + +# build and install gdrcopy system packages +pushd $WORKSPACE +cd gdrcopy_src/packages +apt install devscripts -y +CUDA=${CUDA_HOME:-/usr/local/cuda} ./build-deb-packages.sh +dpkg -i *.deb From 08bf7840780980c7568c573c70a6a8db94fd45ff Mon Sep 17 00:00:00 2001 From: Cheng Kuan Yong Jason Date: Mon, 12 May 2025 09:06:10 +0800 Subject: [PATCH 04/24] [Bugfix] validate grammar and throw 400 error instead of crashing the engine when xgrammar validation fails (#17623) Signed-off-by: Jason Cheng Co-authored-by: Russell Bryant --- .../openai/test_chat_completion.py | 137 ++++++++++++++++++ .../v1/entrypoints/openai/test_completion.py | 94 ++++++++++++ vllm/v1/engine/processor.py | 4 +- vllm/v1/structured_output/backend_xgrammar.py | 6 + 4 files changed, 240 insertions(+), 1 deletion(-) create mode 100644 tests/v1/entrypoints/openai/test_chat_completion.py diff --git a/tests/v1/entrypoints/openai/test_chat_completion.py b/tests/v1/entrypoints/openai/test_chat_completion.py new file mode 100644 index 000000000000..c650ccd0ccd7 --- /dev/null +++ b/tests/v1/entrypoints/openai/test_chat_completion.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 + +import openai # use the official client for correctness check +import pytest +import pytest_asyncio + +from tests.utils import RemoteOpenAIServer + +# any model with a chat template defined in tokenizer_config should work here +MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" + + +@pytest.fixture(scope="module") +def default_server_args(): + return [ + # use half precision for speed and memory savings in CI environment + "--max-model-len", + "2048", + "--max-num-seqs", + "128", + "--enforce-eager", + ] + + +@pytest.fixture(scope="module") +def server(default_server_args): + with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_invalid_json_schema(client: openai.AsyncOpenAI, + model_name: str) -> None: + invalid_json_schema = { + "$defs": { + "CarType": { + "enum": ["sedan", "SUV", "Truck", "Coupe"], + "title": "CarType", + "type": "string", + } + }, + "properties": { + "brand": { + "title": "Brand", + "type": "string" + }, + "model": { + "title": "Model", + "type": "string" + }, + "car_type": { + "$ref": "#/$defs/CarType" + }, + "foo": "bar", + }, + "required": ["brand", "model", "car_type"], + "title": "CarDescription", + "type": "object", + } + prompt = ("Generate a JSON with the brand, model and 
car_type of" + "the most iconic car from the 90's") + with pytest.raises((openai.BadRequestError, openai.APIError)): + await client.chat.completions.create( + model=model_name, + messages=[{ + "role": "user", + "content": prompt, + }], + extra_body={"guided_json": invalid_json_schema}, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_invalid_regex(client: openai.AsyncOpenAI, model_name: str): + prompt = ("Generate an email address for Alan Turing, who works in Enigma." + "End in .com and new line. Example result:" + "alan.turing@enigma.com\n") + + with pytest.raises((openai.BadRequestError, openai.APIError)): + await client.chat.completions.create( + model=model_name, + messages=[{ + "role": "user", + "content": prompt, + }], + extra_body={ + "guided_regex": r"[.*", + "stop": ["\n"] + }, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str): + invalid_simplified_sql_grammar = """ + root ::= select_statementinvalidsyntax + + select_statement ::= "SELECT " column " from " table " where " condition + + column ::= "col_1 " | "col_2 " + + table ::= "table_1 " | "table_2 " + + condition ::= column "= " number + + number ::= "1 " | "2 " + """ + + prompt = ("Generate an SQL query to show the 'username' and 'email'" + "from the 'users' table.") + with pytest.raises((openai.BadRequestError, openai.APIError)): + await client.chat.completions.create( + model=model_name, + messages=[{ + "role": "user", + "content": prompt, + }], + extra_body={"guided_grammar": invalid_simplified_sql_grammar}, + ) diff --git a/tests/v1/entrypoints/openai/test_completion.py b/tests/v1/entrypoints/openai/test_completion.py index 57ca99e1f68c..3ffc54f520b4 100644 --- a/tests/v1/entrypoints/openai/test_completion.py +++ b/tests/v1/entrypoints/openai/test_completion.py @@ -584,3 +584,97 @@ async def test_echo_logprob_completion(client: openai.AsyncOpenAI, assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1 assert len(logprobs.tokens) > 5 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_invalid_json_schema(client: openai.AsyncOpenAI, + model_name: str) -> None: + invalid_json_schema = { + "$defs": { + "CarType": { + "enum": ["sedan", "SUV", "Truck", "Coupe"], + "title": "CarType", + "type": "string", + } + }, + "properties": { + "brand": { + "title": "Brand", + "type": "string" + }, + "model": { + "title": "Model", + "type": "string" + }, + "car_type": { + "$ref": "#/$defs/CarType" + }, + "foo": "bar", + }, + "required": ["brand", "model", "car_type"], + "title": "CarDescription", + "type": "object", + } + prompt = ("Generate a JSON with the brand, model and car_type of" + "the most iconic car from the 90's") + with pytest.raises((openai.BadRequestError, openai.APIError)): + await client.completions.create( + model=model_name, + prompt=prompt, + extra_body={"guided_json": invalid_json_schema}, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_invalid_regex(client: openai.AsyncOpenAI, model_name: str): + prompt = ("Generate an email address for Alan Turing, who works in Enigma." + "End in .com and new line. 
Example result:" + "alan.turing@enigma.com\n") + + with pytest.raises((openai.BadRequestError, openai.APIError)): + await client.completions.create( + model=model_name, + prompt=prompt, + extra_body={ + "guided_regex": r"[.*", + "stop": ["\n"] + }, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str): + invalid_simplified_sql_grammar = """ + root ::= select_statementinvalidsyntax + + select_statement ::= "SELECT " column " from " table " where " condition + + column ::= "col_1 " | "col_2 " + + table ::= "table_1 " | "table_2 " + + condition ::= column "= " number + + number ::= "1 " | "2 " + """ + + prompt = ("Generate an SQL query to show the 'username' and 'email'" + "from the 'users' table.") + with pytest.raises((openai.BadRequestError, openai.APIError)): + await client.completions.create( + model=model_name, + prompt=prompt, + extra_body={"guided_grammar": invalid_simplified_sql_grammar}, + ) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 2aa19f8bbb57..66be88738535 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -188,8 +188,10 @@ def _validate_structured_output(self, params: SamplingParams) -> None: validate_xgrammar_grammar(params) params.guided_decoding.backend = "xgrammar" except ValueError: - # The request includes some jsonschema feature(s) that + # The request either failed validation + # or includes some jsonschema feature(s) that # are not supported in xgrammar. Fall back to guidance. + validate_guidance_grammar(params, tokenizer=None) params.guided_decoding.backend = "guidance" # Remember that this backend was set automatically params.guided_decoding.backend_was_auto = True diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index c82a3cab2fa3..baa478bc63bd 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -282,6 +282,12 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None: else: schema = gd_params.json + try: + xgr.Grammar.from_json_schema(schema) + except Exception as err: + raise ValueError("Failed to transform json schema into a grammar: " + f"{err}") from err + if has_xgrammar_unsupported_json_features(schema): raise ValueError("The provided JSON schema contains features not " "supported by xgrammar.") From ada50aa2952fd0a7c645d75c9db472030131ddc7 Mon Sep 17 00:00:00 2001 From: Reid <61492567+reidliu41@users.noreply.github.com> Date: Mon, 12 May 2025 12:58:02 +0800 Subject: [PATCH 05/24] [bugfix] fix the wrong parser (#17958) Signed-off-by: reidliu41 Co-authored-by: reidliu41 --- vllm/entrypoints/cli/collect_env.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/cli/collect_env.py b/vllm/entrypoints/cli/collect_env.py index d5f9f7e729f0..810ecfdf71c3 100644 --- a/vllm/entrypoints/cli/collect_env.py +++ b/vllm/entrypoints/cli/collect_env.py @@ -4,12 +4,11 @@ from vllm.collect_env import main as collect_env_main from vllm.entrypoints.cli.types import CLISubcommand -from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.utils import FlexibleArgumentParser class CollectEnvSubcommand(CLISubcommand): - """The `serve` subcommand for the vLLM CLI. """ + """The `collect-env` subcommand for the vLLM CLI. 
""" def __init__(self): self.name = "collect-env" @@ -23,12 +22,12 @@ def cmd(args: argparse.Namespace) -> None: def subparser_init( self, subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: - serve_parser = subparsers.add_parser( + collect_env_parser = subparsers.add_parser( "collect-env", help="Start collecting environment information.", description="Start collecting environment information.", usage="vllm collect-env") - return make_arg_parser(serve_parser) + return collect_env_parser def cmd_init() -> list[CLISubcommand]: From 19a3c78d1ff8c6cfab078227bad1c8cb79f47887 Mon Sep 17 00:00:00 2001 From: Li Wang Date: Mon, 12 May 2025 12:58:23 +0800 Subject: [PATCH 06/24] [Bugfix] Fix pydantic.errors.PydanticUserError (#17962) Signed-off-by: wangli --- vllm/entrypoints/openai/serving_engine.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 37134cfb3da3..f1d907f519c5 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,18 +1,24 @@ # SPDX-License-Identifier: Apache-2.0 import json +import sys import time from collections.abc import (AsyncGenerator, Iterable, Iterator, Mapping, Sequence) from concurrent.futures.thread import ThreadPoolExecutor from http import HTTPStatus from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional, - TypedDict, TypeVar, Union) + TypeVar, Union) from fastapi import Request from pydantic import BaseModel, ConfigDict, Field from starlette.datastructures import Headers +if sys.version_info >= (3, 12): + from typing import TypedDict +else: + from typing_extensions import TypedDict + import vllm.envs as envs from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient From 430783018cbfad69c6ff3a52479bf7b556b65247 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Sun, 11 May 2025 21:58:33 -0700 Subject: [PATCH 07/24] [Bugfix][TPU] Use np array when updating cache slot_mapping (#17971) Signed-off-by: Siyuan Liu --- vllm/v1/worker/tpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 983f8707a245..687dabee2290 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -531,7 +531,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): np.add(block_numbers * self.block_size, block_offsets, out=self.input_batch.block_table. - slot_mapping_cpu[:total_num_scheduled_tokens]) + slot_mapping_np[:total_num_scheduled_tokens]) # Prepare the attention metadata. 
self.query_start_loc_np[0] = 0 From 891b9d33de7ee7b3ee95b9bd7bb8a9cffae0e08c Mon Sep 17 00:00:00 2001 From: Brayden Zhong Date: Mon, 12 May 2025 01:55:53 -0400 Subject: [PATCH 08/24] [Fix] Benchmark `"EngineClient" has no attribute "model_config"` (#17976) Signed-off-by: Brayden Zhong --- benchmarks/benchmark_throughput.py | 7 ++++--- vllm/benchmarks/throughput.py | 5 +++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 1f65277e1bfe..cd6c76ad6096 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -146,9 +146,10 @@ async def run_vllm_async( async with build_async_engine_client_from_engine_args( engine_args, disable_frontend_multiprocessing) as llm: + model_config = await llm.get_model_config() assert all( - llm.model_config.max_model_len >= (request.prompt_len + - request.expected_output_len) + model_config.max_model_len >= (request.prompt_len + + request.expected_output_len) for request in requests), ( "Please ensure that max_model_len is greater than the sum of" " prompt_len and expected_output_len for all requests.") @@ -599,7 +600,7 @@ def validate_args(args): "--lora-path", type=str, default=None, - help="Path to the lora adapters to use. This can be an absolute path, " + help="Path to the LoRA adapters to use. This can be an absolute path, " "a relative path, or a Hugging Face model identifier.") parser.add_argument( "--prefix-len", diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py index b3e24911cc98..13110a8b4db3 100644 --- a/vllm/benchmarks/throughput.py +++ b/vllm/benchmarks/throughput.py @@ -148,9 +148,10 @@ async def run_vllm_async( async with build_async_engine_client_from_engine_args( engine_args, disable_frontend_multiprocessing) as llm: + model_config = await llm.get_model_config() assert all( - llm.model_config.max_model_len >= (request.prompt_len + - request.expected_output_len) + model_config.max_model_len >= (request.prompt_len + + request.expected_output_len) for request in requests), ( "Please ensure that max_model_len is greater than the sum of" " prompt_len and expected_output_len for all requests.") From 3a5ea751292664265bdd0dd22da86d725e457816 Mon Sep 17 00:00:00 2001 From: Xu Wenqing <121550081+Xu-Wenqing@users.noreply.github.com> Date: Mon, 12 May 2025 15:45:21 +0800 Subject: [PATCH 09/24] [Feature] Support DeepSeekV3 Function Call (#17784) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 许文卿 Signed-off-by: Xu Wenqing --- docs/source/features/tool_calling.md | 7 + examples/tool_chat_template_deepseekv3.jinja | 96 +++++ .../openai/tool_parsers/__init__.py | 3 +- .../tool_parsers/deepseekv3_tool_parser.py | 368 ++++++++++++++++++ 4 files changed, 473 insertions(+), 1 deletion(-) create mode 100644 examples/tool_chat_template_deepseekv3.jinja create mode 100644 vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py diff --git a/docs/source/features/tool_calling.md b/docs/source/features/tool_calling.md index f3b808b3d2b7..2795b769345e 100644 --- a/docs/source/features/tool_calling.md +++ b/docs/source/features/tool_calling.md @@ -236,6 +236,13 @@ For Qwen2.5, the chat template in tokenizer_config.json has already included sup Flags: `--tool-call-parser hermes` +### DeepSeek-V3 Models (`deepseek_v3`) + +Supported models: +* `deepseek-ai/DeepSeek-V3-0324` + +Flags: `--tool-call-parser deepseek_v3 --chat-template 
examples/tool_chat_template_deepseekv3.jinja` + ### Models with Pythonic Tool Calls (`pythonic`) A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models. diff --git a/examples/tool_chat_template_deepseekv3.jinja b/examples/tool_chat_template_deepseekv3.jinja new file mode 100644 index 000000000000..36f3781439ed --- /dev/null +++ b/examples/tool_chat_template_deepseekv3.jinja @@ -0,0 +1,96 @@ +{% if not add_generation_prompt is defined %} + {% set add_generation_prompt = false %} +{% endif %} + +{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true, is_last_user=false) %} + +{%- for message in messages %} + {%- if message['role'] == 'system' %} + {%- if ns.is_first_sp %} + {% set ns.system_prompt = ns.system_prompt + message['content'] %} + {% set ns.is_first_sp = false %} + {%- else %} + {% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %} + {%- endif %} + {%- endif %} +{%- endfor %} + +{{ bos_token }} +{{ ns.system_prompt }} +{%- if tools %} + {{"\n\n# Tools\n\nYou may call one or more functions to assist with the user query." }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{"\n\n\n"}} + + {{"For function call returns, you should first print <|tool▁calls▁begin|>"}} + + {{"For each function call, you should return object like:\n" }} + {{"<|tool▁call▁begin|>function<|tool▁sep|>\n```json\n\n```<|tool▁call▁end|>"}} + + {{"At the end of function call returns, you should print <|tool▁calls▁end|><|end▁of▁sentence|>"}} +{%- endif %} + +{%- for message in messages %} + {%- if message['role'] == 'user' %} + {%- set ns.is_tool = false -%} + {%- set ns.is_first = false -%} + {%- set ns.is_last_user = true -%} + {{'<|User|>' + message['content'] + '<|Assistant|>'}} + {%- endif %} + + {%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %} + {%- set ns.is_last_user = false -%} + {%- if ns.is_tool %} + {{'<|tool▁outputs▁end|>'}} + {%- endif %} + {%- set ns.is_first = false %} + {%- set ns.is_tool = false -%} + {%- set ns.is_output_first = true %} + + {%- for tool in message['tool_calls'] %} + {%- if not ns.is_first %} + {%- if message['content'] is none %} + {{'<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}} + {%- else %} + {{message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}} + {%- endif %} + {%- set ns.is_first = true -%} + {%- else %} + {{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}} + {%- endif %} + {%- endfor %} + {{'<|tool▁calls▁end|><|end▁of▁sentence|>'}} + {%- endif %} + {%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none)%} + {%- set ns.is_last_user = false -%} + {%- if ns.is_tool %} + {{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}} + {%- set 
ns.is_tool = false -%} + {%- else %} + {% set content = message['content'] %} + {{content + '<|end▁of▁sentence|>'}} + {%- endif %} + {%- endif %} + + {%- if message['role'] == 'tool' %} + {%- set ns.is_last_user = false -%} + {%- set ns.is_tool = true -%} + {%- if ns.is_output_first %} + {{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} + {%- set ns.is_output_first = false %} + {%- else %} + {{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} + {%- endif %} + {%- endif %} +{%- endfor -%} + +{% if ns.is_tool %} + {{'<|tool▁outputs▁end|>'}} +{% endif %} + +{% if add_generation_prompt and not ns.is_last_user and not ns.is_tool %} + {{'<|Assistant|>'}} +{% endif %} diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index b81dc4e7ad7b..f7c7112b124f 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from .abstract_tool_parser import ToolParser, ToolParserManager +from .deepseekv3_tool_parser import DeepSeekV3ToolParser from .granite_20b_fc_tool_parser import Granite20bFCToolParser from .granite_tool_parser import GraniteToolParser from .hermes_tool_parser import Hermes2ProToolParser @@ -15,5 +16,5 @@ "ToolParser", "ToolParserManager", "Granite20bFCToolParser", "GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser", - "PythonicToolParser", "Phi4MiniJsonToolParser" + "PythonicToolParser", "Phi4MiniJsonToolParser", "DeepSeekV3ToolParser" ] diff --git a/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py new file mode 100644 index 000000000000..bd8e87e4cee8 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py @@ -0,0 +1,368 @@ +# SPDX-License-Identifier: Apache-2.0 + +import re +from collections.abc import Sequence +from typing import Union + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("deepseek_v3") +class DeepSeekV3ToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + self.current_tool_name_sent: bool = False + self.prev_tool_call_arr: list[dict] = [] + self.current_tool_id: int = -1 + self.streamed_args_for_tool: list[str] = ( + []) # map what has been streamed for each tool so far to a list + + self.tool_calls_start_token: str = "<|tool▁calls▁begin|>" + self.tool_calls_end_token: str = "<|tool▁calls▁end|>" + + self.tool_call_start_token: str = "<|tool▁call▁begin|>" + self.tool_call_end_token: str = "<|tool▁call▁end|>" + + self.tool_call_regex = re.compile( + r"<|tool▁call▁begin|>(?P.*)<|tool▁sep|>(?P.*)\n```json\n(?P.*)\n```<|tool▁call▁end|>" + ) + + self.stream_tool_call_portion_regex = re.compile( + r"(?P.*)<|tool▁sep|>(?P.*)\n```json\n(?P.*[^\n`])" + ) + + self.stream_tool_call_name_regex = re.compile( + r"(?P.*)<|tool▁sep|>(?P.*)\n") + + if not self.model_tokenizer: + raise 
ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction.") + self.tool_calls_start_token_id = self.vocab.get( + self.tool_calls_start_token) + self.tool_calls_end_token_id = self.vocab.get( + self.tool_calls_end_token) + + self.tool_call_start_token_id = self.vocab.get( + self.tool_call_start_token) + self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) + + if (self.tool_calls_start_token_id is None + or self.tool_calls_end_token_id is None): + raise RuntimeError( + "DeepSeek-V3 Tool parser could not locate tool call start/end " + "tokens in the tokenizer!") + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + + # sanity check; avoid unnecessary processing + if self.tool_calls_start_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + else: + try: + # there are two possible captures - between tags, or between a + # tag and end-of-string so the result of + # findall is an array of tuples where one is a function call and + # the other is None + function_call_tuples = self.tool_call_regex.findall( + model_output) + + tool_calls = [] + for match in function_call_tuples: + tool_type, function_name, function_args = match + tool_calls.append( + ToolCall( + type=tool_type, + function=FunctionCall(name=function_name, + arguments=function_args), + )) + + content = model_output[:model_output. + find(self.tool_calls_start_token)] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if content else None, + ) + + except Exception: + logger.exception( + "Error in extracting tool call from response.") + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + logger.debug("delta_text: %s", delta_text) + logger.debug("delta_token_ids: %s", delta_token_ids) + # check to see if we should be streaming a tool call - is there a + if self.tool_calls_start_token_id not in current_token_ids: + logger.debug("No tool call tokens found!") + return DeltaMessage(content=delta_text) + delta_text = delta_text.replace(self.tool_calls_start_token, + "").replace(self.tool_calls_end_token, + "") + try: + + # figure out where we are in the parsing by counting tool call + # start & end tags + prev_tool_start_count = previous_token_ids.count( + self.tool_call_start_token_id) + prev_tool_end_count = previous_token_ids.count( + self.tool_call_end_token_id) + cur_tool_start_count = current_token_ids.count( + self.tool_call_start_token_id) + cur_tool_end_count = current_token_ids.count( + self.tool_call_end_token_id) + tool_call_portion = None + text_portion = None + + # case: if we're generating text, OR rounding out a tool call + if (cur_tool_start_count == cur_tool_end_count + and prev_tool_end_count == cur_tool_end_count + and self.tool_call_end_token not in delta_text): + logger.debug("Generating text content! 
skipping tool parsing.") + return DeltaMessage(content=delta_text) + + if self.tool_call_end_token in delta_text: + logger.debug("tool_call_end_token in delta_text") + full_text = current_text + delta_text + tool_call_portion = full_text.split( + self.tool_call_start_token)[-1].split( + self.tool_call_end_token)[0].rstrip() + delta_text = delta_text.split( + self.tool_call_end_token)[0].rstrip() + text_portion = delta_text.split( + self.tool_call_end_token)[-1].lstrip() + + # case -- we're starting a new tool call + if (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count): + if len(delta_token_ids) > 1: + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + else: + tool_call_portion = None + delta = None + + text_portion = None + + # set cursors and state appropriately + self.current_tool_id += 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("Starting on a new tool %s", self.current_tool_id) + + # case -- we're updating an existing tool call + elif (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count): + + # get the portion of the text that's the tool call + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + text_portion = None + + # case -- the current tool call is being closed. + elif (cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count >= prev_tool_end_count): + if self.prev_tool_call_arr is None or len( + self.prev_tool_call_arr) == 0: + logger.debug( + "attempting to close tool call, but no tool call") + return None + diff = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments") + if diff: + diff = (diff.encode("utf-8").decode("unicode_escape") + if diff is str else diff) + if '"}' not in delta_text: + return None + end_loc = delta_text.rindex('"}') + diff = delta_text[:end_loc] + '"}' + logger.debug( + "Finishing tool and found diff that had not " + "been streamed yet: %s", + diff, + ) + self.streamed_args_for_tool[self.current_tool_id] += diff + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump(exclude_none=True), + ) + ]) + + # case -- otherwise we're just generating text + else: + text = delta_text.replace(self.tool_call_start_token, "") + text = text.replace(self.tool_call_end_token, "") + delta = DeltaMessage(tool_calls=[], content=text) + return delta + + current_tool_call = dict() + if tool_call_portion: + current_tool_call_matches = ( + self.stream_tool_call_portion_regex.match( + tool_call_portion)) + if current_tool_call_matches: + tool_type, tool_name, tool_args = ( + current_tool_call_matches.groups()) + current_tool_call["name"] = tool_name + current_tool_call["arguments"] = tool_args + else: + current_tool_call_name_matches = ( + self.stream_tool_call_name_regex.match( + tool_call_portion)) + if current_tool_call_name_matches: + tool_type, tool_name = ( + current_tool_call_name_matches.groups()) + current_tool_call["name"] = tool_name + current_tool_call["arguments"] = "" + else: + logger.debug("Not enough token") + return None + + # case - we haven't sent the tool name yet. If it's available, send + # it. otherwise, wait until it's available. 
+ if not self.current_tool_name_sent: + if current_tool_call is None: + return None + function_name: Union[str, None] = current_tool_call.get("name") + if function_name: + self.current_tool_name_sent = True + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True), + ) + ]) + else: + return None + + # case -- otherwise, send the tool call delta + + # if the tool call portion is None, send the delta as text + if tool_call_portion is None: + # if there's text but not tool calls, send that - + # otherwise None to skip chunk + delta = (DeltaMessage( + content=delta_text) if text_portion is not None else None) + return delta + + # now, the nitty-gritty of tool calls + # now we have the portion to parse as tool call. + + logger.debug("Trying to parse current tool call with ID %s", + self.current_tool_id) + + # if we're starting a new tool call, push an empty object in as + # a placeholder for the arguments + if len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + + # main logic for tool parsing here - compare prev. partially-parsed + # JSON to the current partially-parsed JSON + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments") + cur_arguments = current_tool_call.get("arguments") + + logger.debug("diffing old arguments: %s", prev_arguments) + logger.debug("against new ones: %s", cur_arguments) + + # case -- no arguments have been created yet. skip sending a delta. + if not cur_arguments and not prev_arguments: + logger.debug("Skipping text %s - no arguments", delta_text) + delta = None + + # case -- prev arguments are defined, but non are now. + # probably impossible, but not a fatal error - just keep going + elif not cur_arguments and prev_arguments: + logger.error("should be impossible to have arguments reset " + "mid-call. skipping streaming anything.") + delta = None + + # case -- we now have the first info about arguments available from + # autocompleting the JSON + elif cur_arguments and not prev_arguments: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=cur_arguments).model_dump( + exclude_none=True), + ) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] = cur_arguments + + # last case -- we have an update to existing arguments. + elif cur_arguments and prev_arguments: + if (isinstance(delta_text, str) + and cur_arguments != prev_arguments + and len(cur_arguments) > len(prev_arguments) + and cur_arguments.startswith(prev_arguments)): + delta_arguments = cur_arguments[len(prev_arguments):] + logger.debug("got diff %s", delta_text) + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=delta_arguments).model_dump( + exclude_none=True), + ) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] = cur_arguments + else: + delta = None + + # handle saving the state for the current tool into + # the "prev" list for use in diffing for the next iteration + if self.current_tool_id == len(self.prev_tool_call_arr) - 1: + self.prev_tool_call_arr[ + self.current_tool_id] = current_tool_call + else: + self.prev_tool_call_arr.append(current_tool_call) + + return delta + + except Exception: + logger.exception("Error trying to handle streaming tool call.") + return None # do not stream a delta. 
skip this token ID. From 9fbf2bfbd509845fc37139e1ec51f60e41af0815 Mon Sep 17 00:00:00 2001 From: Aaruni Aggarwal <47731267+AaruniAggarwal@users.noreply.github.com> Date: Mon, 12 May 2025 13:41:55 +0530 Subject: [PATCH 10/24] Correcting testcases in builkite job for IBM Power (#17675) Signed-off-by: Aaruni Aggarwal --- .buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh b/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh index 5d863dd82e9b..077bd9914907 100755 --- a/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh @@ -32,9 +32,12 @@ function cpu_tests() { set -e pip install pytest pytest-asyncio einops peft Pillow soundfile transformers_stream_generator matplotlib pip install sentence-transformers datamodel_code_generator - pytest -v -s tests/models/embedding/language/test_cls_models.py::test_classification_models[float-jason9693/Qwen2.5-1.5B-apeach] - pytest -v -s tests/models/embedding/language/test_embedding.py::test_models[half-BAAI/bge-base-en-v1.5] - pytest -v -s tests/models/encoder_decoder/language -m cpu_model" + pytest -v -s tests/models/language/generation/test_bart.py -m cpu_model + pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-openai-community/gpt2] + pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-facebook/opt-125m] + pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-google/gemma-1.1-2b-it] + pytest -v -s tests/models/language/pooling/test_classification.py::test_models[float-jason9693/Qwen2.5-1.5B-apeach] + pytest -v -s tests/models/language/pooling/test_embedding.py::test_models[half-BAAI/bge-base-en-v1.5]" } # All of CPU tests are expected to be finished less than 40 mins. 
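A minimal, self-contained sketch of the non-streaming extraction path of the DeepSeek-V3 tool parser introduced in PATCH 09/24 above may help when reading that diff: given a finished model output that uses the <|tool▁calls▁begin|> / <|tool▁call▁begin|> markers emitted by the new chat template, recover the leading content plus one (type, name, arguments) triple per call. The marker strings and the overall format come from the patch; the regex, its group names, the JSON decoding of the arguments and the sample output below are illustrative only, not the exact code in deepseekv3_tool_parser.py.

import json
import re

# Special strings emitted by the DeepSeek-V3 chat template (see the jinja file above).
TOOL_CALLS_BEGIN = "<|tool▁calls▁begin|>"
TOOL_CALL_BEGIN = "<|tool▁call▁begin|>"
TOOL_CALL_END = "<|tool▁call▁end|>"
TOOL_SEP = "<|tool▁sep|>"

# One call looks like:
#   <|tool▁call▁begin|>function<|tool▁sep|>NAME\n```json\nARGS\n```<|tool▁call▁end|>
tool_call_re = re.compile(
    re.escape(TOOL_CALL_BEGIN) + r"(?P<type>.*?)" + re.escape(TOOL_SEP) +
    r"(?P<name>.*?)\n```json\n(?P<arguments>.*?)\n```" + re.escape(TOOL_CALL_END),
    re.DOTALL,
)

def extract_tool_calls(model_output: str) -> tuple[str, list[dict]]:
    """Return (content before the tool calls, parsed tool calls)."""
    if TOOL_CALLS_BEGIN not in model_output:
        return model_output, []
    content = model_output[:model_output.find(TOOL_CALLS_BEGIN)]
    calls = [{
        "type": m["type"],
        "name": m["name"],
        "arguments": json.loads(m["arguments"]),
    } for m in tool_call_re.finditer(model_output)]
    return content, calls

sample = (
    "Let me look that up."
    + TOOL_CALLS_BEGIN
    + TOOL_CALL_BEGIN + "function" + TOOL_SEP + "get_weather"
    + "\n```json\n{\"city\": \"Berlin\"}\n```" + TOOL_CALL_END
    + "<|tool▁calls▁end|>"
)
print(extract_tool_calls(sample))
# ('Let me look that up.', [{'type': 'function', 'name': 'get_weather',
#                            'arguments': {'city': 'Berlin'}}])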
From a658de3c9f371cd696b747a8fc0706ab7b0f026a Mon Sep 17 00:00:00 2001 From: nicklucche Date: Tue, 6 May 2025 08:08:38 +0000 Subject: [PATCH 11/24] tp_size in metadata and handshake with rank0 first 2-handshake model with vertical kv cache split Signed-off-by: nicklucche --- vllm/config.py | 1 + .../kv_connector/v1/nixl_connector.py | 196 +++++++++++++----- 2 files changed, 142 insertions(+), 55 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index c6b97bbdcd66..4fe287fa320a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3472,6 +3472,7 @@ class KVTransferConfig: kv_connector_extra_config: dict[str, Any] = field(default_factory=dict) """any extra config that the connector may need.""" + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index d26184982270..595bfd96dc37 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -14,7 +14,7 @@ from typing_extensions import Optional from vllm import envs -from vllm.config import VllmConfig +from vllm.config import VllmConfig, KVTransferConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, KVTransferParams) from vllm.distributed.parallel_state import ( @@ -103,6 +103,8 @@ class NixlAgentMetadata( agent_metadata: bytes kv_caches_base_addr: list[int] num_blocks: int + tp_size: int + block_len: int @dataclass @@ -153,7 +155,7 @@ def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): self.connector_worker: Optional[NixlConnectorWorker] = None elif role == KVConnectorRole.WORKER: self.connector_scheduler = None - self.connector_worker = NixlConnectorWorker(str(self.engine_id)) + self.connector_worker = NixlConnectorWorker(str(self.engine_id), vllm_config.kv_transfer_config) ############################################################ # Scheduler Side Methods @@ -347,7 +349,7 @@ def request_finished( class NixlConnectorWorker: """Implementation of Worker side methods""" - def __init__(self, engine_id: str): + def __init__(self, engine_id: str, kv_config: KVTransferConfig): if NixlWrapper is None: logger.error("NIXL is not available") raise RuntimeError("NIXL is not available") @@ -356,8 +358,8 @@ def __init__(self, engine_id: str): # Agent. self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) - # Map of engine_id -> agent_name. - self._remote_agents: dict[str, str] = {} + # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. + self._remote_agents: dict[str, dict[int, str]] = defaultdict(dict) # Metadata. self.engine_id = engine_id @@ -365,23 +367,27 @@ def __init__(self, engine_id: str): self.world_size = get_tensor_model_parallel_world_size() self.tp_group = get_tp_group() + # Remote tracking ds only contain one entry for own tp group: engine_id-self.rank # KV Caches and nixl tracking data. self.kv_caches: dict[str, torch.Tensor] = {} - # Map of engine_id -> kv_caches_base_addr - self.kv_caches_base_addr: dict[str, list[int]] = {} + # Map of engine_id -> kv_caches_base_addr. For TP case, each local + # rank will still only pull from a single remote TP worker. + self.kv_caches_base_addr: dict[str, list[int]] = dict() # Number of NIXL regions. 
Currently one region per cache # (so 1 per layer for MLA, otherwise 2 per layer) self.num_regions = 0 - # nixl_prepped_dlist_handle (int). - self.src_xfer_side_handle: int = 0 + # nixl_prepped_dlist_handle. Different dst TP sizes require preparing + # xfer layout differently. + self.src_xfer_side_handle: int = dict() # Map of engine_id -> nixl_prepped_dlist_handle (int)]. - self.dst_xfer_side_handles: dict[str, int] = {} + self.dst_xfer_side_handles: dict[str, int] = dict() - # Map of engine_id -> num_blocks. - self.dst_num_blocks: dict[str, int] = {} + # Map of engine_id -> num_blocks. Remote TP ranks will have the same + # number of blocks. + self.dst_num_blocks: dict[str, int] = dict() self._registered_descs: list[Any] = [] # In progress transfers. @@ -399,6 +405,8 @@ def __init__(self, engine_id: str): # Background thread for establishing new connections. self._nixl_handshake_listener_t: Optional[threading.Thread] = None + + self._tp_size = {self.engine_id: self.world_size} @staticmethod def _nixl_handshake_listener(metadata: NixlAgentMetadata, @@ -411,6 +419,7 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata, # move this into the scheduler rather than worker, since # each rank needs the metadata of all other ranks (whereas # in this setup, each rank only gets one other rank's meta. + # TODO iterate over all ranks to handshake with M. Can we get M from config? encoder = msgspec.msgpack.Encoder() encoded_data = encoder.encode(metadata) @@ -423,6 +432,7 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata, # NOTE(rob): we need each rank to have a unique port. This # hack to keeps us moving. We will switch when moving to etcd # or where we have a single ZMQ socket in the scheduler. + # TODO get rank port util port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + rank path = f"tcp://{host}:{port}" logger.debug("Starting listening on path: %s", path) @@ -442,9 +452,8 @@ def _nixl_handshake(self, host: str, port: int): # NOTE(rob): we need each rank to have a unique port. This is # a hack to keep us moving. We will switch when moving to etcd # or where we have a single ZMQ socket in the scheduler. - path = f"tcp://{host}:{port + self.rank}" - logger.debug("Querying metadata on path: %s", path) - with zmq_ctx(zmq.REQ, path) as sock: + + def handshake(sock, rank: int)->NixlAgentMetadata: # Send query for the request. sock.send(GET_META_MSG) metadata_bytes = sock.recv() @@ -453,13 +462,33 @@ def _nixl_handshake(self, host: str, port: int): got_metadata_time = time.perf_counter() # Register Remote agent. - self.add_remote_agent(metadata) + self.add_remote_agent(metadata, rank) setup_agent_time = time.perf_counter() logger.debug("NIXL handshake: get metadata took: %s", - got_metadata_time - start_time) + got_metadata_time - start_time) logger.debug("NIXL handshake: add agent took: %s", - setup_agent_time - got_metadata_time) + setup_agent_time - got_metadata_time) + return metadata + + # Handshake with remote agent-rank0 first to get the tp_size of remote + path = f"tcp://{host}:{port}" + logger.debug("Querying master rank metadata on path: %s", path) + with zmq_ctx(zmq.REQ, path) as sock: + metadata = handshake(sock, 0) + + # TODO should we skip this if remote world_size == world_size (homogeneous)? + + # Handshake only with the other TP remote the current local rank will + # pull from. With homogeneous TP it happens to be the same rank_i. 
+ p_remote_rank = self.rank % metadata.tp_size + if p_remote_rank > 0: + path = f"tcp://{host}:{port + p_remote_rank}" + logger.debug("Querying metadata on path: %s at remote rank %s", path, p_remote_rank) + with zmq_ctx(zmq.REQ, path) as sock: + metadata = handshake(sock, p_remote_rank) + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" @@ -473,6 +502,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # MLA case. self.num_blocks = first_kv_cache.shape[0] block_rank = 2 # [block_size, latent_dim] + # TODO does this include tp dependent size? block_shape = first_kv_cache.shape[-block_rank:] else: # [2 (k and v), num_blocks, ...] @@ -483,6 +513,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # TODO(tms): self.block_len needs to be per-layer for sliding window, # hybrid attn, etc self.block_len = kv_elem_size * math.prod(block_shape) + print(f"\n\n{self.block_len=}\n\n") logger.debug("Registering KV_Caches. use_mla: %s, shape %s", use_mla, first_kv_cache.shape) @@ -510,6 +541,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): kv_caches_base_addr.append(base_addr) self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr self.num_regions = len(caches_data) + print("************************BLOCKS SETUP") + print(f"Number of blocks {len(kv_caches_base_addr)=}\n") + print(f"{self.num_blocks=}, {self.block_len=}, {self.num_regions=}\n") descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM") logger.debug("Registering descs: %s", caches_data) @@ -524,6 +558,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): agent_metadata=self.nixl_wrapper.get_agent_metadata(), kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], num_blocks=self.num_blocks, + tp_size=self.world_size, + block_len=self.block_len ) ready_event = threading.Event() self._nixl_handshake_listener_t = threading.Thread( @@ -534,49 +570,93 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self._nixl_handshake_listener_t.start() ready_event.wait() - def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata): + def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int=0): + # FIXME one other approach I tried is loading half of every remote block instead of half the blocks. Doesnt seem to make much difference engine_id = nixl_agent_meta.engine_id - if engine_id in self._remote_agents: + # TODO re-evaluate refreshing for scaling/recovery + if engine_id in self._remote_agents and remote_rank in self._remote_agents[engine_id]: return - self._remote_agents[engine_id] = self.nixl_wrapper.add_remote_agent( + if engine_id in self._tp_size: + assert self._tp_size[engine_id] == nixl_agent_meta.tp_size + self._tp_size[engine_id] = nixl_agent_meta.tp_size + self._remote_agents[engine_id][remote_rank] = self.nixl_wrapper.add_remote_agent( nixl_agent_meta.agent_metadata) - self.kv_caches_base_addr[ - engine_id] = nixl_agent_meta.kv_caches_base_addr + + # TODO enforce tp sizes are exact multiples + d_workers_per_p_worker = self._tp_size[self.engine_id] // self._tp_size[engine_id] + assert d_workers_per_p_worker > 0, "Decode TP cannot be smaller than prefill TP" + dst_num_blocks_per_local_rank = nixl_agent_meta.num_blocks // d_workers_per_p_worker # Create src descs and xfer side handles. 
+ if d_workers_per_p_worker not in self.src_xfer_side_handle: + blocks_data = [] + for base_addr in self.kv_caches_base_addr[self.engine_id]: + for block_id in range(dst_num_blocks_per_local_rank): + block_offset = block_id * nixl_agent_meta.block_len + # (addr, len, device id) + # use the block size of the dst/P node to make sure regions match + blocks_data.append( + (base_addr + block_offset, nixl_agent_meta.block_len, self.rank)) + logger.debug("Created %s blocks for src engine %s and rank %s", + len(blocks_data), self.engine_id, self.rank) + + # Register with NIXL. + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + # NIXL_INIT_AGENT to be used for preparations of local descs. + self.src_xfer_side_handle[d_workers_per_p_worker] = self.nixl_wrapper.prep_xfer_dlist( + "NIXL_INIT_AGENT", descs) + + # Create dst descs and xfer side handles. TP workers have same #blocks + # if engine_id in self.dst_num_blocks: + # assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks + + # self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks + # When D_TP>P_TP, P blocks are split between D workers. Hence we may + # record a fraction of the total num_blocks in P. + self.dst_num_blocks[engine_id] = dst_num_blocks_per_local_rank + blocks_data = [] - for base_addr in self.kv_caches_base_addr[self.engine_id]: - for block_id in range(self.num_blocks): - block_offset = block_id * self.block_len - # (addr, len, device id) - blocks_data.append( - (base_addr + block_offset, self.block_len, self.rank)) - logger.debug("Created %s blocks for src engine %s and rank %s", - len(blocks_data), self.engine_id, self.rank) - - # Register with NIXL. - descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") - self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( - "NIXL_INIT_AGENT", descs) - - # Create dst descs and xfer side handles. - self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks - blocks_data = [] - for base_addr in self.kv_caches_base_addr[engine_id]: - for block_id in range(nixl_agent_meta.num_blocks): - block_offset = block_id * self.block_len - # (addr, len, device id) - blocks_data.append( - (base_addr + block_offset, self.block_len, self.rank)) - logger.debug("Created %s blocks for dst engine %s and rank %s", - len(blocks_data), engine_id, self.rank) - - # Register with NIXL. - descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") - self.dst_xfer_side_handles[ - engine_id] = self.nixl_wrapper.prep_xfer_dlist( - self._remote_agents[engine_id], descs) + # With heterogenous TP, prepare the descriptors by splitting the P KV + # cache into chunks of D worker's size (D>P). + # Eg. PTP1 DTP2 => P0 KV:[KV_0 | KV_1] (contiguous view). + p_remote_rank = self.rank % nixl_agent_meta.tp_size + # Only register the remote's descriptor if current rank pulls from it + if p_remote_rank == remote_rank: + self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr + + # TODO in case sizes aren't exactly divisible, we may want to swap + # self.block_len with meta.block_len // d_workers_per_p_worker + # (eg when dividing by 3) and handle final block. src_xfer too. + # assert nixl_agent_meta.block_len % self.block_len == 0 + + # Split the kv memory inside a nixl region to guarantee each local + # rank is pulling the kv cache of all layers of a remote worker. + # TODO what if the region_len of P and D don't match in size due to some TP overhead??Also this would assume the mem utilized is the same.. 
+ rank_offset = self.rank // nixl_agent_meta.tp_size * nixl_agent_meta.block_len * dst_num_blocks_per_local_rank + print(f"Local Rank {self.rank} remote {remote_rank}: {rank_offset=}/ Remote region_len {nixl_agent_meta.num_blocks*nixl_agent_meta.block_len}\n\n") + print(f"{nixl_agent_meta.num_blocks=}, {dst_num_blocks_per_local_rank=}") + # DECODE TP2 || self.num_blocks=33769, self.block_len=16384, self.num_regions=56 + # PREFILL TP1 || self.num_blocks=17371, self.block_len=32768, self.num_regions=56 + # FIXME assume num_blocks and block_len are actually divisible and all is nice. This needs to be enforced (eg diff mem usage might break) + for base_addr in nixl_agent_meta.kv_caches_base_addr: + base_addr += rank_offset + # for block_id in range(self.num_blocks): + for block_id in range(dst_num_blocks_per_local_rank): + # block_offset = block_id * self.block_len + block_offset = block_id * nixl_agent_meta.block_len + # (addr, len, device id) + blocks_data.append( + (base_addr + block_offset, nixl_agent_meta.block_len, self.rank)) + # blocks_data.append( + # (base_addr + block_offset, self.block_len, self.rank)) + logger.debug("Created %s blocks for dst engine %s with remote rank %s and local rank %s", + len(blocks_data), engine_id, remote_rank, self.rank) + + # Register with NIXL. + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + self.dst_xfer_side_handles[engine_id] = self.nixl_wrapper.prep_xfer_dlist( + self._remote_agents[engine_id][remote_rank], descs) def get_finished(self) -> tuple[set[str], set[str]]: """ @@ -712,6 +792,7 @@ def _read_blocks( request_id: str, ): # NOTE(rob): this takes ~2s. We need to get this off the hotpath. + # TODO check remote_rank in here too? if dst_engine_id not in self._remote_agents: self._nixl_handshake(remote_host, remote_port) @@ -738,9 +819,14 @@ def _read_blocks( assert num_local_blocks <= num_remote_blocks if num_local_blocks < num_remote_blocks: remote_block_ids = remote_block_ids[-num_local_blocks:] + + # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from + # corresponding rank. With heterogenous TP, fixing D>P, the D tp + # workers will issue xfers to parts of the P worker remote kv caches. # Get side handles. - local_xfer_side_handle = self.src_xfer_side_handle + d_workers_per_p_worker = self._tp_size[self.engine_id] // self._tp_size[dst_engine_id] + local_xfer_side_handle = self.src_xfer_side_handle[d_workers_per_p_worker] remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id] # Get descs ids. 
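The indexing that PATCH 11/24 above sets up for heterogeneous tensor parallelism can be summarized in plain arithmetic: with the decode TP size a multiple of the prefill TP size, each decode rank handshakes with one remote rank (rank % remote_tp_size) and pulls one contiguous fraction of that worker's KV blocks (dst_num_blocks_per_local_rank of them, at an offset of rank // remote_tp_size chunks). The sketch below only reproduces that mapping with made-up TP sizes and block counts; block_len byte offsets, NIXL descriptor registration and the actual transfers are left out.

def heterogeneous_tp_plan(decode_tp: int, prefill_tp: int, remote_num_blocks: int):
    """Which prefill rank and block range each decode rank would pull from."""
    assert decode_tp % prefill_tp == 0, "decode TP must be a multiple of prefill TP"
    d_workers_per_p_worker = decode_tp // prefill_tp
    blocks_per_d_rank = remote_num_blocks // d_workers_per_p_worker
    plan = {}
    for rank in range(decode_tp):
        p_remote_rank = rank % prefill_tp   # remote agent this decode rank reads from
        chunk = rank // prefill_tp          # which slice of that agent's blocks it owns
        first = chunk * blocks_per_d_rank
        plan[rank] = (p_remote_rank, range(first, first + blocks_per_d_rank))
    return plan

# Example: prefill at TP=2 with 16 KV blocks per worker, decode at TP=4.
for rank, (remote, blocks) in heterogeneous_tp_plan(4, 2, 16).items():
    print(f"decode rank {rank} -> prefill rank {remote}, "
          f"blocks {blocks.start}..{blocks.stop - 1}")
# decode rank 0 -> prefill rank 0, blocks 0..7
# decode rank 1 -> prefill rank 1, blocks 0..7
# decode rank 2 -> prefill rank 0, blocks 8..15
# decode rank 3 -> prefill rank 1, blocks 8..15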
From 8db36057d2b172a81ead638e3d177a904f57bbef Mon Sep 17 00:00:00 2001 From: nicklucche Date: Sat, 10 May 2025 09:54:07 +0000 Subject: [PATCH 12/24] split kv_cache along head dim fix descr indexing change remote worker selection indexing; test ptp2-dtp4 Signed-off-by: nicklucche --- vllm/config.py | 1 - .../kv_connector/v1/nixl_connector.py | 197 +++++++++--------- 2 files changed, 101 insertions(+), 97 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 4fe287fa320a..c6b97bbdcd66 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3472,7 +3472,6 @@ class KVTransferConfig: kv_connector_extra_config: dict[str, Any] = field(default_factory=dict) """any extra config that the connector may need.""" - def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 595bfd96dc37..58e17c83461b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -14,7 +14,7 @@ from typing_extensions import Optional from vllm import envs -from vllm.config import VllmConfig, KVTransferConfig +from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, KVTransferParams) from vllm.distributed.parallel_state import ( @@ -155,7 +155,7 @@ def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): self.connector_worker: Optional[NixlConnectorWorker] = None elif role == KVConnectorRole.WORKER: self.connector_scheduler = None - self.connector_worker = NixlConnectorWorker(str(self.engine_id), vllm_config.kv_transfer_config) + self.connector_worker = NixlConnectorWorker(str(self.engine_id)) ############################################################ # Scheduler Side Methods @@ -349,7 +349,7 @@ def request_finished( class NixlConnectorWorker: """Implementation of Worker side methods""" - def __init__(self, engine_id: str, kv_config: KVTransferConfig): + def __init__(self, engine_id: str): if NixlWrapper is None: logger.error("NIXL is not available") raise RuntimeError("NIXL is not available") @@ -367,12 +367,11 @@ def __init__(self, engine_id: str, kv_config: KVTransferConfig): self.world_size = get_tensor_model_parallel_world_size() self.tp_group = get_tp_group() - # Remote tracking ds only contain one entry for own tp group: engine_id-self.rank # KV Caches and nixl tracking data. self.kv_caches: dict[str, torch.Tensor] = {} - # Map of engine_id -> kv_caches_base_addr. For TP case, each local - # rank will still only pull from a single remote TP worker. + # Map of engine_id -> kv_caches_base_addr. For TP case, each local + # rank will still only pull from a single remote TP worker. self.kv_caches_base_addr: dict[str, list[int]] = dict() # Number of NIXL regions. Currently one region per cache @@ -380,8 +379,8 @@ def __init__(self, engine_id: str, kv_config: KVTransferConfig): self.num_regions = 0 # nixl_prepped_dlist_handle. Different dst TP sizes require preparing - # xfer layout differently. - self.src_xfer_side_handle: int = dict() + # xfer layout differently. + self.src_xfer_side_handle: dict[int, int] = dict() # Map of engine_id -> nixl_prepped_dlist_handle (int)]. 
self.dst_xfer_side_handles: dict[str, int] = dict() @@ -405,7 +404,7 @@ def __init__(self, engine_id: str, kv_config: KVTransferConfig): # Background thread for establishing new connections. self._nixl_handshake_listener_t: Optional[threading.Thread] = None - + self._tp_size = {self.engine_id: self.world_size} @staticmethod @@ -419,7 +418,6 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata, # move this into the scheduler rather than worker, since # each rank needs the metadata of all other ranks (whereas # in this setup, each rank only gets one other rank's meta. - # TODO iterate over all ranks to handshake with M. Can we get M from config? encoder = msgspec.msgpack.Encoder() encoded_data = encoder.encode(metadata) @@ -449,11 +447,12 @@ def _nixl_handshake(self, host: str, port: int): """Do a NIXL handshake with a remote instance.""" start_time = time.perf_counter() + # NOTE(rob): we need each rank to have a unique port. This is # a hack to keep us moving. We will switch when moving to etcd # or where we have a single ZMQ socket in the scheduler. - def handshake(sock, rank: int)->NixlAgentMetadata: + def handshake(sock, rank: int) -> NixlAgentMetadata: # Send query for the request. sock.send(GET_META_MSG) metadata_bytes = sock.recv() @@ -466,29 +465,28 @@ def handshake(sock, rank: int)->NixlAgentMetadata: setup_agent_time = time.perf_counter() logger.debug("NIXL handshake: get metadata took: %s", - got_metadata_time - start_time) + got_metadata_time - start_time) logger.debug("NIXL handshake: add agent took: %s", - setup_agent_time - got_metadata_time) + setup_agent_time - got_metadata_time) return metadata - # Handshake with remote agent-rank0 first to get the tp_size of remote + # Handshake with remote agent-rank0 first to get the tp_size of remote path = f"tcp://{host}:{port}" logger.debug("Querying master rank metadata on path: %s", path) with zmq_ctx(zmq.REQ, path) as sock: - metadata = handshake(sock, 0) - - # TODO should we skip this if remote world_size == world_size (homogeneous)? + metadata = handshake(sock, 0) - # Handshake only with the other TP remote the current local rank will - # pull from. With homogeneous TP it happens to be the same rank_i. - p_remote_rank = self.rank % metadata.tp_size + # Handshake only with the other TP remote the current local rank will + # pull from. With homogeneous TP it happens to be the same rank_i. + d_workers_per_p_worker = self._tp_size[ + self.engine_id] // metadata.tp_size + p_remote_rank = self.rank // d_workers_per_p_worker if p_remote_rank > 0: path = f"tcp://{host}:{port + p_remote_rank}" - logger.debug("Querying metadata on path: %s at remote rank %s", path, p_remote_rank) + logger.debug("Querying metadata on path: %s at remote rank %s", + path, p_remote_rank) with zmq_ctx(zmq.REQ, path) as sock: metadata = handshake(sock, p_remote_rank) - - def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" @@ -502,18 +500,22 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # MLA case. self.num_blocks = first_kv_cache.shape[0] block_rank = 2 # [block_size, latent_dim] - # TODO does this include tp dependent size? block_shape = first_kv_cache.shape[-block_rank:] + self.block_size, kv_latent_dim = block_shape + self.kv_dim = kv_elem_size * kv_latent_dim else: - # [2 (k and v), num_blocks, ...] 
+ # [2 (k and v), num_blocks, block_size, kv_heads, head_dim] self.num_blocks = first_kv_cache.shape[1] block_rank = 3 # [block_size, kv_heads, head_dim] block_shape = first_kv_cache.shape[-block_rank:] + self.block_size, n_kv_heads, head_dim = block_shape + # head size in bytes. + self.kv_dim = kv_elem_size * n_kv_heads * head_dim # TODO(tms): self.block_len needs to be per-layer for sliding window, # hybrid attn, etc + # block size in bytes self.block_len = kv_elem_size * math.prod(block_shape) - print(f"\n\n{self.block_len=}\n\n") logger.debug("Registering KV_Caches. use_mla: %s, shape %s", use_mla, first_kv_cache.shape) @@ -541,15 +543,11 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): kv_caches_base_addr.append(base_addr) self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr self.num_regions = len(caches_data) - print("************************BLOCKS SETUP") - print(f"Number of blocks {len(kv_caches_base_addr)=}\n") - print(f"{self.num_blocks=}, {self.block_len=}, {self.num_regions=}\n") descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM") logger.debug("Registering descs: %s", caches_data) self.nixl_wrapper.register_memory(descs) logger.debug("Done registering descs") - self._registered_descs.append(descs) # After KV Caches registered, listen for new connections. @@ -559,8 +557,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], num_blocks=self.num_blocks, tp_size=self.world_size, - block_len=self.block_len - ) + block_len=self.block_len) ready_event = threading.Event() self._nixl_handshake_listener_t = threading.Thread( target=self._nixl_handshake_listener, @@ -570,92 +567,96 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self._nixl_handshake_listener_t.start() ready_event.wait() - def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_rank: int=0): - # FIXME one other approach I tried is loading half of every remote block instead of half the blocks. 
Doesnt seem to make much difference + def add_remote_agent(self, + nixl_agent_meta: NixlAgentMetadata, + remote_rank: int = 0): engine_id = nixl_agent_meta.engine_id # TODO re-evaluate refreshing for scaling/recovery - if engine_id in self._remote_agents and remote_rank in self._remote_agents[engine_id]: + if (engine_id in self._remote_agents and \ + remote_rank in self._remote_agents[engine_id]): return if engine_id in self._tp_size: assert self._tp_size[engine_id] == nixl_agent_meta.tp_size self._tp_size[engine_id] = nixl_agent_meta.tp_size - self._remote_agents[engine_id][remote_rank] = self.nixl_wrapper.add_remote_agent( - nixl_agent_meta.agent_metadata) - - # TODO enforce tp sizes are exact multiples - d_workers_per_p_worker = self._tp_size[self.engine_id] // self._tp_size[engine_id] - assert d_workers_per_p_worker > 0, "Decode TP cannot be smaller than prefill TP" - dst_num_blocks_per_local_rank = nixl_agent_meta.num_blocks // d_workers_per_p_worker + self._remote_agents[engine_id][ + remote_rank] = self.nixl_wrapper.add_remote_agent( + nixl_agent_meta.agent_metadata) + + d_workers_per_p_worker = self._tp_size[ + self.engine_id] // self._tp_size[engine_id] + assert d_workers_per_p_worker > 0, "Decode TP cannot be smaller than" + " prefill TP" + + # TODO we should also check hidden_dim and kv precision, they must match + remote_block_size = nixl_agent_meta.block_len / ( + self.kv_dim * d_workers_per_p_worker) + assert self.block_size == remote_block_size, "Remote P worker with " + "different block size is not supported" # Create src descs and xfer side handles. if d_workers_per_p_worker not in self.src_xfer_side_handle: blocks_data = [] for base_addr in self.kv_caches_base_addr[self.engine_id]: - for block_id in range(dst_num_blocks_per_local_rank): - block_offset = block_id * nixl_agent_meta.block_len - # (addr, len, device id) - # use the block size of the dst/P node to make sure regions match - blocks_data.append( - (base_addr + block_offset, nixl_agent_meta.block_len, self.rank)) + # NOTE With heter-TP, more blocks are prepared than what are + # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We + # could create fewer, but then _get_block_descs_ids needs to + # select agent_meta.num_blocks instead of self.num_blocks for + # local descr, and that makes handling regular flow less clean. + for block_id in range(self.num_blocks): + block_offset = block_id * self.block_len + for b in range(self.block_size): + head_offset = b * self.kv_dim + addr = base_addr + block_offset + head_offset + # (addr, len, device id) + blocks_data.append((addr, self.kv_dim, self.rank)) logger.debug("Created %s blocks for src engine %s and rank %s", - len(blocks_data), self.engine_id, self.rank) + len(blocks_data), self.engine_id, self.rank) # Register with NIXL. descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") # NIXL_INIT_AGENT to be used for preparations of local descs. - self.src_xfer_side_handle[d_workers_per_p_worker] = self.nixl_wrapper.prep_xfer_dlist( - "NIXL_INIT_AGENT", descs) + self.src_xfer_side_handle[ + d_workers_per_p_worker] = self.nixl_wrapper.prep_xfer_dlist( + "NIXL_INIT_AGENT", descs) - # Create dst descs and xfer side handles. TP workers have same #blocks - # if engine_id in self.dst_num_blocks: - # assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks + # Create dst descs and xfer side handles. TP workers have same #blocks. 
+ if engine_id in self.dst_num_blocks: + assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks - # self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks - # When D_TP>P_TP, P blocks are split between D workers. Hence we may - # record a fraction of the total num_blocks in P. - self.dst_num_blocks[engine_id] = dst_num_blocks_per_local_rank + self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks blocks_data = [] - # With heterogenous TP, prepare the descriptors by splitting the P KV - # cache into chunks of D worker's size (D>P). - # Eg. PTP1 DTP2 => P0 KV:[KV_0 | KV_1] (contiguous view). - p_remote_rank = self.rank % nixl_agent_meta.tp_size - # Only register the remote's descriptor if current rank pulls from it + # With homogeneous TP, D pulls the whole kv cache from corresponding + # rank. With heterogeneous TP, prepare the descriptors by splitting the + # P KV cache along kv_head dim, of D worker's kv_head size (D>P). + # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. + p_remote_rank = self.rank // d_workers_per_p_worker + # Only register the remote's descriptors if current rank pulls from it. if p_remote_rank == remote_rank: - self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr - - # TODO in case sizes aren't exactly divisible, we may want to swap - # self.block_len with meta.block_len // d_workers_per_p_worker - # (eg when dividing by 3) and handle final block. src_xfer too. - # assert nixl_agent_meta.block_len % self.block_len == 0 - - # Split the kv memory inside a nixl region to guarantee each local - # rank is pulling the kv cache of all layers of a remote worker. - # TODO what if the region_len of P and D don't match in size due to some TP overhead??Also this would assume the mem utilized is the same.. - rank_offset = self.rank // nixl_agent_meta.tp_size * nixl_agent_meta.block_len * dst_num_blocks_per_local_rank - print(f"Local Rank {self.rank} remote {remote_rank}: {rank_offset=}/ Remote region_len {nixl_agent_meta.num_blocks*nixl_agent_meta.block_len}\n\n") - print(f"{nixl_agent_meta.num_blocks=}, {dst_num_blocks_per_local_rank=}") - # DECODE TP2 || self.num_blocks=33769, self.block_len=16384, self.num_regions=56 - # PREFILL TP1 || self.num_blocks=17371, self.block_len=32768, self.num_regions=56 - # FIXME assume num_blocks and block_len are actually divisible and all is nice. This needs to be enforced (eg diff mem usage might break) + self.kv_caches_base_addr[ + engine_id] = nixl_agent_meta.kv_caches_base_addr + rank_offset = self.rank % d_workers_per_p_worker * self.kv_dim + # Register all remote blocks, but only the corresponding kv heads. 
for base_addr in nixl_agent_meta.kv_caches_base_addr: - base_addr += rank_offset - # for block_id in range(self.num_blocks): - for block_id in range(dst_num_blocks_per_local_rank): - # block_offset = block_id * self.block_len + for block_id in range(nixl_agent_meta.num_blocks): block_offset = block_id * nixl_agent_meta.block_len - # (addr, len, device id) - blocks_data.append( - (base_addr + block_offset, nixl_agent_meta.block_len, self.rank)) - # blocks_data.append( - # (base_addr + block_offset, self.block_len, self.rank)) - logger.debug("Created %s blocks for dst engine %s with remote rank %s and local rank %s", - len(blocks_data), engine_id, remote_rank, self.rank) + for b in range(self.block_size): + # Remote kv_dim = local kv_dim * d_workers_per_p_worker + head_offset = b * self.kv_dim * d_workers_per_p_worker + addr = base_addr + block_offset + head_offset + # (addr, len, device id) + blocks_data.append( + (addr + rank_offset, self.kv_dim, remote_rank)) + logger.debug( + "Created %s blocks for dst engine %s with remote rank %s and " \ + "local rank %s", + len(blocks_data), engine_id, remote_rank, self.rank) # Register with NIXL. descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") - self.dst_xfer_side_handles[engine_id] = self.nixl_wrapper.prep_xfer_dlist( + self.dst_xfer_side_handles[ + engine_id] = self.nixl_wrapper.prep_xfer_dlist( self._remote_agents[engine_id][remote_rank], descs) def get_finished(self) -> tuple[set[str], set[str]]: @@ -820,13 +821,15 @@ def _read_blocks( if num_local_blocks < num_remote_blocks: remote_block_ids = remote_block_ids[-num_local_blocks:] - # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from - # corresponding rank. With heterogenous TP, fixing D>P, the D tp + # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from + # corresponding rank. With heterogeneous TP, fixing D>P, the D tp # workers will issue xfers to parts of the P worker remote kv caches. # Get side handles. - d_workers_per_p_worker = self._tp_size[self.engine_id] // self._tp_size[dst_engine_id] - local_xfer_side_handle = self.src_xfer_side_handle[d_workers_per_p_worker] + d_workers_per_p_worker = self._tp_size[ + self.engine_id] // self._tp_size[dst_engine_id] + local_xfer_side_handle = self.src_xfer_side_handle[ + d_workers_per_p_worker] remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id] # Get descs ids. 
@@ -864,7 +867,9 @@ def _get_block_descs_ids(self, engine_id: str, descs_ids: list[int] = [] for reg_id in region_ids: for block_id in block_ids: - descs_ids.append(reg_id * num_blocks + block_id) + for kv_block in range(self.block_size): + descs_ids.append(reg_id * num_blocks * self.block_size + + block_id * self.block_size + kv_block) return descs_ids From 7ea6cb28b260c4b8aeeaf103a47efc7fd5f97982 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Mon, 12 May 2025 18:46:45 +0800 Subject: [PATCH 13/24] [Misc] Improve modelscope import error (#17983) Signed-off-by: Jee Jee Li --- vllm/transformers_utils/__init__.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/vllm/transformers_utils/__init__.py b/vllm/transformers_utils/__init__.py index 01d5bb4b5748..b556976a51ba 100644 --- a/vllm/transformers_utils/__init__.py +++ b/vllm/transformers_utils/__init__.py @@ -3,17 +3,21 @@ from vllm.envs import VLLM_USE_MODELSCOPE if VLLM_USE_MODELSCOPE: - # Patch here, before each import happens - import modelscope - from packaging import version + try: + # Patch here, before each import happens + import modelscope + from packaging import version - # patch_hub begins from modelscope>=1.18.1 - if version.parse(modelscope.__version__) <= version.parse('1.18.0'): - raise ImportError( - 'Using vLLM with ModelScope needs modelscope>=1.18.1, please ' - 'install by `pip install modelscope -U`') - - from modelscope.utils.hf_util import patch_hub + # patch_hub begins from modelscope>=1.18.1 + if version.parse(modelscope.__version__) <= version.parse('1.18.0'): + raise ImportError( + 'Using vLLM with ModelScope needs modelscope>=1.18.1, please ' + 'install by `pip install modelscope -U`') + from modelscope.utils.hf_util import patch_hub - # Patch hub to download models from modelscope to speed up. - patch_hub() + # Patch hub to download models from modelscope to speed up. 
+ patch_hub() + except ImportError as err: + raise ImportError( + "Please install modelscope>=1.18.1 via " + "`pip install modelscope>=1.18.1` to use ModelScope.") from err From 05a4324f8e3932c25554791ff248e3e0200eef92 Mon Sep 17 00:00:00 2001 From: Maximilien de Bayser Date: Mon, 12 May 2025 10:28:58 -0300 Subject: [PATCH 14/24] Initialize the delta tool call fields explicitly (#17340) Signed-off-by: Max de Bayser Co-authored-by: igmainc --- .../entrypoints/openai/tool_parsers/utils.py | 2 +- vllm/entrypoints/chat_utils.py | 4 ++ vllm/entrypoints/openai/protocol.py | 9 +++-- vllm/entrypoints/openai/serving_chat.py | 39 ++++++++++++------- .../granite_20b_fc_tool_parser.py | 4 +- .../tool_parsers/granite_tool_parser.py | 4 +- .../openai/tool_parsers/hermes_tool_parser.py | 4 +- .../tool_parsers/internlm2_tool_parser.py | 4 +- .../openai/tool_parsers/jamba_tool_parser.py | 4 +- .../openai/tool_parsers/llama_tool_parser.py | 4 +- .../tool_parsers/phi4mini_tool_parser.py | 4 +- .../tool_parsers/pythonic_tool_parser.py | 3 +- 12 files changed, 51 insertions(+), 34 deletions(-) diff --git a/tests/entrypoints/openai/tool_parsers/utils.py b/tests/entrypoints/openai/tool_parsers/utils.py index 6ad5aa26ffa1..ab8f4bd678fd 100644 --- a/tests/entrypoints/openai/tool_parsers/utils.py +++ b/tests/entrypoints/openai/tool_parsers/utils.py @@ -32,7 +32,7 @@ def append_delta(self, delta: DeltaMessage): assert len(delta.tool_calls) < 2, ( "Streaming should include only one tool call per update.") for call_delta in delta.tool_calls: - assert call_delta.type == "function", ( + assert call_delta.type is None or call_delta.type == "function", ( "Streaming tool calls should only emit function calls. Got " f"{call_delta.type}") current_tool_call = self.tool_calls[ diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index db43b2dd295d..4ff8821fca54 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -44,6 +44,7 @@ # yapf: enable from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.utils import random_uuid logger = init_logger(__name__) @@ -1272,3 +1273,6 @@ def apply_mistral_chat_template( "An error occurred in `mistral_common` while applying chat " "template") raise ValueError from e + +def random_tool_call_id() -> str: + return f"chatcmpl-tool-{random_uuid()}" \ No newline at end of file diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 4e09240f23af..19c426b19fe2 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -15,7 +15,8 @@ from typing_extensions import TypeAlias from vllm import envs -from vllm.entrypoints.chat_utils import ChatCompletionMessageParam +from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, + random_tool_call_id) from vllm.logger import init_logger from vllm.pooling_params import PoolingParams from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, @@ -1339,7 +1340,7 @@ class FunctionCall(OpenAIBaseModel): class ToolCall(OpenAIBaseModel): - id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}") + id: str = Field(default_factory=random_tool_call_id) type: Literal["function"] = "function" function: FunctionCall @@ -1351,8 +1352,8 @@ class DeltaFunctionCall(BaseModel): # a tool call delta where everything is optional class DeltaToolCall(OpenAIBaseModel): - id: str = Field(default_factory=lambda: 
f"chatcmpl-tool-{random_uuid()}") - type: Literal["function"] = "function" + id: Optional[str] = None + type: Optional[Literal["function"]] = None index: int function: Optional[DeltaFunctionCall] = None diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 5c11836fbff4..30f8aade086d 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -16,7 +16,8 @@ from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, - ConversationMessage) + ConversationMessage, + random_tool_call_id) from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( ChatCompletionLogProb, ChatCompletionLogProbs, @@ -363,9 +364,10 @@ def extract_tool_call_required_streaming( function_name_returned = True delta_message = DeltaMessage(tool_calls=[ - DeltaToolCall(function=DeltaFunctionCall( - name=current_tool_call["name"], - arguments=arguments), + DeltaToolCall(id=random_tool_call_id(), + function=DeltaFunctionCall( + name=current_tool_call["name"], + arguments=arguments), index=len(obj) - 1, type="function") ]) @@ -382,8 +384,7 @@ def extract_tool_call_required_streaming( # instead of name every time name=None, arguments=delta_text), - index=len(obj) - 1, - type="function") + index=len(obj) - 1) ]) else: delta_message = None @@ -422,7 +423,7 @@ async def chat_completion_stream_generator( and self._should_stream_with_auto_tool_parsing(request)) all_previous_token_ids: Optional[list[list[int]]] - function_name_returned: Optional[list[bool]] = None + function_name_returned = [False] * num_choices # Only one of these will be used, thus previous_texts and # all_previous_token_ids will not be used twice in the same iteration. 
@@ -435,7 +436,6 @@ async def chat_completion_stream_generator( reasoning_end_arr = [False] * num_choices elif request.tool_choice == "required": previous_texts = [""] * num_choices - function_name_returned = [False] * num_choices all_previous_token_ids = None else: previous_texts, all_previous_token_ids = None, None @@ -623,16 +623,27 @@ async def chat_completion_stream_generator( delta_text = previous_text + delta_text current_text = "" + if function_name_returned[i]: + delta_tool_call = DeltaToolCall( + function=DeltaFunctionCall( + arguments=delta_text), + index=i) + else: + delta_tool_call = DeltaToolCall( + id=random_tool_call_id(), + type="function", + function=DeltaFunctionCall( + name=tool_choice_function_name, + arguments=delta_text), + index=i) + function_name_returned[i] = True + delta_message = DeltaMessage(tool_calls=[ - DeltaToolCall(function=DeltaFunctionCall( - name=tool_choice_function_name, - arguments=delta_text), - index=i) + delta_tool_call, ]) elif request.tool_choice == "required": assert previous_texts is not None - assert function_name_returned is not None previous_text = previous_texts[i] current_text = previous_text + delta_text fn_name_returned = function_name_returned[i] @@ -835,7 +846,7 @@ async def chat_completion_stream_generator( total_tokens=num_prompt_tokens + completion_tokens, ) - data = chunk.model_dump_json(exclude_unset=True) + data = chunk.model_dump_json(exclude_none=True) yield f"data: {data}\n\n" # once the final token is handled, if stream_options.include_usage diff --git a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py index 76da63c58008..b93de6b41817 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py @@ -9,6 +9,7 @@ import partial_json_parser from partial_json_parser.core.options import Allow +from vllm.entrypoints.chat_utils import random_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -22,7 +23,6 @@ partial_json_loads) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import random_uuid logger = init_logger(__name__) @@ -200,7 +200,7 @@ def extract_tool_calls_streaming( delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=f"chatcmpl-tool-{random_uuid()}", + id=random_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py index 91afc88ef3dd..6710e7938c43 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py @@ -7,6 +7,7 @@ import partial_json_parser from partial_json_parser.core.options import Allow +from vllm.entrypoints.chat_utils import random_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -20,7 +21,6 @@ partial_json_loads) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import random_uuid logger = init_logger(__name__) @@ -182,7 +182,7 @@ def extract_tool_calls_streaming( delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, 
type="function", - id=f"chatcmpl-tool-{random_uuid()}", + id=random_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index 4c39e9b0c61f..e56a8ef7193c 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -8,6 +8,7 @@ import partial_json_parser from partial_json_parser.core.options import Allow +from vllm.entrypoints.chat_utils import random_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -17,7 +18,6 @@ ToolParser, ToolParserManager) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import random_uuid logger = init_logger(__name__) @@ -259,7 +259,7 @@ def extract_tool_calls_streaming( return DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=f"chatcmpl-tool-{random_uuid()}", + id=random_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py index 57d7c77c64f7..5abd553d884d 100644 --- a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -7,6 +7,7 @@ import partial_json_parser from partial_json_parser.core.options import Allow +from vllm.entrypoints.chat_utils import random_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -18,7 +19,6 @@ extract_intermediate_diff) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import random_uuid logger = init_logger(__name__) @@ -106,7 +106,7 @@ def extract_tool_calls_streaming( delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=f"chatcmpl-tool-{random_uuid()}", + id=random_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py index 8df106bf2718..6cac6f8163bf 100644 --- a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py @@ -8,6 +8,7 @@ import partial_json_parser from partial_json_parser.core.options import Allow +from vllm.entrypoints.chat_utils import random_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -19,7 +20,6 @@ from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizers import MistralTokenizer -from vllm.utils import random_uuid logger = init_logger(__name__) @@ -220,7 +220,7 @@ def extract_tool_calls_streaming( delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=f"chatcmpl-tool-{random_uuid()}", + id=random_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py 
b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py index 5c181616aa01..9307034f40d6 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py @@ -10,6 +10,7 @@ from partial_json_parser.core.options import Allow from transformers import PreTrainedTokenizerBase +from vllm.entrypoints.chat_utils import random_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -21,7 +22,6 @@ is_complete_json, partial_json_loads) from vllm.logger import init_logger -from vllm.utils import random_uuid logger = init_logger(__name__) @@ -208,7 +208,7 @@ def extract_tool_calls_streaming( delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=f"chatcmpl-tool-{random_uuid()}", + id=random_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py index 668776a832e2..abf70a5e85c4 100644 --- a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py @@ -7,6 +7,7 @@ from transformers import PreTrainedTokenizerBase +from vllm.entrypoints.chat_utils import random_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaMessage, ExtractedToolCallInformation, @@ -14,7 +15,6 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, ToolParserManager) from vllm.logger import init_logger -from vllm.utils import random_uuid logger = init_logger(__name__) @@ -73,7 +73,7 @@ def extract_tool_calls( tool_calls: list[ToolCall] = [ ToolCall( - id=f"chatcmpl-tool-{random_uuid()}", + id=random_tool_call_id(), type="function", function=FunctionCall( name=raw_function_call["name"], diff --git a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py index 9f141d6b334b..bb91a35af3be 100644 --- a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py @@ -280,6 +280,7 @@ def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall, new_call_args = new_call_args[:-len(withheld_suffix)] if not previously_sent_args: return DeltaToolCall(id=new_call.id, + type="function", index=index, function=DeltaFunctionCall( name=new_call.function.name, @@ -288,5 +289,5 @@ def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall, arg_diff = new_call_args[len(previously_sent_args):] return DeltaToolCall( - id="", index=index, function=DeltaFunctionCall( + id=None, index=index, function=DeltaFunctionCall( arguments=arg_diff)) if arg_diff else None From d19110204c03e9b77ed957fc70c1262ff370f5e2 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Date: Mon, 12 May 2025 12:46:16 -0400 Subject: [PATCH 15/24] [P/D] NIXL Integration (#17751) Signed-off-by: ApostaC Signed-off-by: Tyler Michael Smith Signed-off-by: rshaw@neuralmagic.com Signed-off-by: Robert Shaw Signed-off-by: mgoin Signed-off-by: Nick Hill Signed-off-by: Brent Salisbury Co-authored-by: Tyler Michael Smith Co-authored-by: ApostaC Co-authored-by: Robert Shaw Co-authored-by: mgoin Co-authored-by: Nick Hill Co-authored-by: Tyler Michael Smith Co-authored-by: Brent Salisbury --- 
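
Note (illustrative, not part of the diff): the scheduler-side connector hook
get_num_new_matched_tokens now returns a tuple rather than a bare token count,
which is why the mocks in tests/v1/core/test_scheduler.py below switch their
return_value to (NUM_MATCHED_NEW_TOKENS, False). A minimal sketch of the new
calling convention follows; the interpretation of the boolean (whether the
matched tokens will be loaded asynchronously) is an assumption here and is not
spelled out by this patch.

    from unittest.mock import Mock

    BLOCK_SIZE = 16
    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2

    # Stand-in for scheduler.connector in the updated unit tests.
    connector = Mock(name="kv_connector")
    # New convention: (num_external_tokens, load_kv_async) instead of an int.
    connector.get_num_new_matched_tokens.return_value = (
        NUM_MATCHED_NEW_TOKENS, False)

    num_external_tokens, load_kv_async = connector.get_num_new_matched_tokens(
        request=None, num_computed_tokens=0)
    assert num_external_tokens == 2 * BLOCK_SIZE
    assert load_kv_async is False
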
.buildkite/test-pipeline.yaml | 1 + tests/v1/core/test_scheduler.py | 6 +- .../nixl_integration/run_accuracy_test.sh | 171 ++++ .../nixl_integration/run_edge_case_test.sh | 123 +++ .../nixl_integration/test_accuracy.py | 60 ++ .../nixl_integration/test_edge_cases.py | 77 ++ .../nixl_integration/toy_proxy_server.py | 260 ++++++ tests/v1/kv_connector/unit/__init__.py | 0 .../kv_connector/unit/test_nixl_connector.py | 73 ++ .../unit/test_remote_decode_lifecycle.py | 181 ++++ .../unit/test_remote_prefill_lifecycle.py | 342 ++++++++ tests/v1/kv_connector/unit/utils.py | 190 +++++ vllm/config.py | 6 +- .../kv_transfer/kv_connector/factory.py | 5 + .../kv_transfer/kv_connector/v1/__init__.py | 7 +- .../kv_transfer/kv_connector/v1/base.py | 89 +- .../kv_connector/v1/lmcache_connector.py | 6 +- .../kv_connector/v1/nixl_connector.py | 805 ++++++++++++++++++ .../v1/shared_storage_connector.py | 12 +- vllm/entrypoints/openai/protocol.py | 19 +- vllm/entrypoints/openai/serving_chat.py | 1 + vllm/entrypoints/openai/serving_completion.py | 2 +- vllm/envs.py | 10 + vllm/forward_context.py | 21 - vllm/outputs.py | 6 +- vllm/v1/core/kv_cache_manager.py | 33 +- vllm/v1/core/sched/interface.py | 4 + vllm/v1/core/sched/scheduler.py | 188 +++- vllm/v1/engine/__init__.py | 1 + vllm/v1/engine/core.py | 9 + vllm/v1/engine/output_processor.py | 12 +- vllm/v1/outputs.py | 22 +- vllm/v1/request.py | 13 +- vllm/v1/worker/gpu_model_runner.py | 78 +- 34 files changed, 2724 insertions(+), 109 deletions(-) create mode 100755 tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh create mode 100644 tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh create mode 100644 tests/v1/kv_connector/nixl_integration/test_accuracy.py create mode 100644 tests/v1/kv_connector/nixl_integration/test_edge_cases.py create mode 100644 tests/v1/kv_connector/nixl_integration/toy_proxy_server.py create mode 100644 tests/v1/kv_connector/unit/__init__.py create mode 100644 tests/v1/kv_connector/unit/test_nixl_connector.py create mode 100644 tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py create mode 100644 tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py create mode 100644 tests/v1/kv_connector/unit/utils.py create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index f7e4af4f2af4..027cb218df5e 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -214,6 +214,7 @@ steps: - pytest -v -s v1/worker - pytest -v -s v1/structured_output - pytest -v -s v1/spec_decode + - pytest -v -s v1/kv_connector/unit - pytest -v -s v1/test_serial_utils.py - pytest -v -s v1/test_stats.py - pytest -v -s v1/test_utils.py diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 0ca2ced89148..f40d477a0036 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -870,7 +870,7 @@ def test_kv_connector_basic(): NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 scheduler.connector.get_num_new_matched_tokens = Mock(name="method") scheduler.connector.get_num_new_matched_tokens.return_value = ( - NUM_MATCHED_NEW_TOKENS) + NUM_MATCHED_NEW_TOKENS, False) ###################################################### # FIRST SET OF REQUESTS - External Hit Only @@ -981,7 +981,7 @@ def test_kv_connector_unable_to_allocate(): NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 scheduler.connector.get_num_new_matched_tokens = Mock(name="method") 
scheduler.connector.get_num_new_matched_tokens.return_value = ( - NUM_MATCHED_NEW_TOKENS) + NUM_MATCHED_NEW_TOKENS, False) # Create two requests. The second request will not be able to # allocate slots because it will not have enough blocks. @@ -1060,7 +1060,7 @@ def test_kv_connector_handles_preemption(): NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE scheduler.connector.get_num_new_matched_tokens = Mock(name="method") scheduler.connector.get_num_new_matched_tokens.return_value = ( - NUM_MATCHED_NEW_TOKENS) + NUM_MATCHED_NEW_TOKENS, False) # Create two requests. # Both can be scheduled at first, but the second request diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh new file mode 100755 index 000000000000..e90b72a7cf24 --- /dev/null +++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh @@ -0,0 +1,171 @@ +#!/bin/bash +set -xe + +# Models to run +MODELS=( + "Qwen/Qwen3-0.6B" +) + +# Number of prefill and decode instances to create +NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1} # Default to 1 +NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-2} # Default to 2 + +# Find the git repository root directory +GIT_ROOT=$(git rev-parse --show-toplevel) + +# Trap the SIGINT signal (triggered by Ctrl+C) +trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT + +# Waits for vLLM to start. +wait_for_server() { + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + +# Function to clean up previous instances +cleanup_instances() { + echo "Cleaning up any running vLLM instances..." + pkill -f "vllm serve" || true + sleep 2 +} + +# Handle to get model-specific arguments for deepseek +get_model_args() { + local model_name=$1 + local extra_args="" + + if [[ "$model_name" == "deepseek-ai/deepseek-vl2-tiny" ]]; then + extra_args="--hf_overrides '{\"architectures\": [\"DeepseekVLV2ForCausalLM\"]}' --trust-remote-code" + fi + + echo "$extra_args" +} + + +# Function to run tests for a specific model +run_tests_for_model() { + local model_name=$1 + echo "================================" + echo "Testing model: $model_name" + echo "================================" + + # Get model-specific arguments + local model_args=$(get_model_args "$model_name") + + # Arrays to store all hosts and ports + PREFILL_HOSTS=() + PREFILL_PORTS=() + DECODE_HOSTS=() + DECODE_PORTS=() + + # Start prefill instances + for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do + # Calculate GPU ID - we'll distribute across available GPUs + GPU_ID=$((i % $(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l))) + # Calculate port number (base port + instance number) + PORT=$((8100 + i)) + # Calculate side channel port + SIDE_CHANNEL_PORT=$((5559 + i)) + + echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT" + + # Build the command with or without model-specific args + BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \ + --port $PORT \ + --enforce-eager \ + --disable-log-requests \ + --gpu-memory-utilization 0.2 \ + --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'" + + if [ -n "$model_args" ]; then + FULL_CMD="$BASE_CMD $model_args" + else + FULL_CMD="$BASE_CMD" + fi + + eval "$FULL_CMD &" + + # Store host and port for proxy configuration + PREFILL_HOSTS+=("localhost") + PREFILL_PORTS+=($PORT) + done + + # Start decode instances + for i in $(seq 0 
$((NUM_DECODE_INSTANCES-1))); do + # Calculate GPU ID - we'll distribute across available GPUs, starting from after prefill GPUs + GPU_ID=$(((i + NUM_PREFILL_INSTANCES) % $(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l))) + # Calculate port number (base port + instance number) + PORT=$((8200 + i)) + # Calculate side channel port + SIDE_CHANNEL_PORT=$((5659 + i)) + + echo "Starting decode instance $i on GPU $GPU_ID, port $PORT" + + # Build the command with or without model-specific args + BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \ + --port $PORT \ + --enforce-eager \ + --disable-log-requests \ + --gpu-memory-utilization 0.2 \ + --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'" + + if [ -n "$model_args" ]; then + FULL_CMD="$BASE_CMD $model_args" + else + FULL_CMD="$BASE_CMD" + fi + + eval "$FULL_CMD &" + + # Store host and port for proxy configuration + DECODE_HOSTS+=("localhost") + DECODE_PORTS+=($PORT) + done + + # Wait for all instances to start + for PORT in "${PREFILL_PORTS[@]}"; do + echo "Waiting for prefill instance on port $PORT to start..." + wait_for_server $PORT + done + + for PORT in "${DECODE_PORTS[@]}"; do + echo "Waiting for decode instance on port $PORT to start..." + wait_for_server $PORT + done + + # Build the command for the proxy server with all the hosts and ports + PROXY_CMD="python ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port 8192" + + # Add all prefill hosts and ports + PROXY_CMD+=" --prefiller-hosts ${PREFILL_HOSTS[@]}" + PROXY_CMD+=" --prefiller-ports ${PREFILL_PORTS[@]}" + + # Add all decode hosts and ports + PROXY_CMD+=" --decoder-hosts ${DECODE_HOSTS[@]}" + PROXY_CMD+=" --decoder-ports ${DECODE_PORTS[@]}" + + # Start the proxy server + echo "Starting proxy server with command: $PROXY_CMD" + $PROXY_CMD & + + # Wait for the proxy to start + sleep 5 + + # Run lm eval for this model + echo "Running tests for $model_name" + TEST_MODEL=$model_name python -m pytest -s -x ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_accuracy.py + + # Clean up before running next model + cleanup_instances + sleep 3 +} + +# Run tests for each model +for model in "${MODELS[@]}"; do + run_tests_for_model "$model" +done + +echo "All tests completed!" diff --git a/tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh b/tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh new file mode 100644 index 000000000000..98903a176e28 --- /dev/null +++ b/tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh @@ -0,0 +1,123 @@ +#!/bin/bash +set -xe + +# Models to run +MODELS=( + "Qwen/Qwen3-0.6B" +) + +# Find the git repository root directory +GIT_ROOT=$(git rev-parse --show-toplevel) + +# Trap the SIGINT signal (triggered by Ctrl+C) +trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT + +# Waits for vLLM to start. +wait_for_server() { + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + +# Function to clean up previous instances +cleanup_instances() { + echo "Cleaning up any running vLLM instances..." 
+ pkill -f "vllm serve" || true + sleep 2 +} + +# Handle to get model-specific arguments for deepseek +get_model_args() { + local model_name=$1 + local extra_args="" + + if [[ "$model_name" == "deepseek-ai/deepseek-vl2-tiny" ]]; then + extra_args="--hf_overrides '{\"architectures\": [\"DeepseekVLV2ForCausalLM\"]}' --trust-remote-code" + fi + + echo "$extra_args" +} + + +# Function to run tests for a specific model +run_tests_for_model() { + local model_name=$1 + echo "================================" + echo "Testing model: $model_name" + echo "================================" + + # Get model-specific arguments + local model_args=$(get_model_args "$model_name") + + # Start prefill instance + PREFILL_PORT=8001 + + BASE_CMD="CUDA_VISIBLE_DEVICES=0 VLLM_NIXL_SIDE_CHANNEL_PORT=5559 vllm serve $model_name \ + --port $PREFILL_PORT \ + --enforce-eager \ + --disable-log-requests \ + --gpu-memory-utilization 0.2 \ + --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'" + + if [ -n "$model_args" ]; then + FULL_CMD="$BASE_CMD $model_args" + else + FULL_CMD="$BASE_CMD" + fi + + eval "$FULL_CMD &" + + # Start decode instance + DECODE_PORT=8002 + + # Build the command with or without model-specific args + BASE_CMD="CUDA_VISIBLE_DEVICES=1 VLLM_NIXL_SIDE_CHANNEL_PORT=6000 vllm serve $model_name \ + --port $DECODE_PORT \ + --enforce-eager \ + --disable-log-requests \ + --gpu-memory-utilization 0.2 \ + --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'" + + if [ -n "$model_args" ]; then + FULL_CMD="$BASE_CMD $model_args" + else + FULL_CMD="$BASE_CMD" + fi + + eval "$FULL_CMD &" + + # Wait for all instances to start + echo "Waiting for prefill instance on port $PORT to start..." + wait_for_server $PREFILL_PORT + echo "Waiting for decode instance on port $PORT to start..." + wait_for_server $DECODE_PORT + + # Build the command for the proxy server with all the hosts and ports + PROXY_PORT=8192 + PROXY_CMD="python ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port $PROXY_PORT" + PROXY_CMD+=" --prefiller-ports ${PREFILL_PORT}" + PROXY_CMD+=" --decoder-ports ${DECODE_PORT}" + # Start the proxy server + echo "Starting proxy server with command: $PROXY_CMD" + $PROXY_CMD & + + # Wait for the proxy to start + sleep 5 + + # Run lm eval for this model + echo "Running tests for $model_name" + PREFILL_PORT=$PREFILL_PORT DECODE_PORT=$DECODE_PORT PROXY_PORT=$PROXY_PORT python -m pytest -s -v ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_edge_cases.py + + # Clean up before running next model + cleanup_instances + sleep 3 +} + +# Run tests for each model +for model in "${MODELS[@]}"; do + run_tests_for_model "$model" +done + +echo "All tests completed!" 
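
Note (illustrative, not part of the diff): once either script above has the
toy proxy listening on port 8192, the prefill/decode pair can be exercised with
any OpenAI-compatible client; the proxy forwards the prompt to a prefill
instance (forcing max_tokens=1), copies the returned kv_transfer_params into
the request, and streams the actual completion from a decode instance. This is
a usage sketch only; the model name and prompt are placeholders matching the
defaults in the scripts.

    import openai

    PROXY_URL = "http://localhost:8192/v1"  # default proxy port in both scripts

    client = openai.OpenAI(api_key="EMPTY", base_url=PROXY_URL)
    completion = client.completions.create(
        model="Qwen/Qwen3-0.6B",  # default model served by the scripts
        prompt="Red Hat is ",
        temperature=0,
    )
    print(completion.choices[0].text)
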
diff --git a/tests/v1/kv_connector/nixl_integration/test_accuracy.py b/tests/v1/kv_connector/nixl_integration/test_accuracy.py new file mode 100644 index 000000000000..be2d84f3bb17 --- /dev/null +++ b/tests/v1/kv_connector/nixl_integration/test_accuracy.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: Apache-2.0 +import os + +import lm_eval +import openai + +BASE_URL = "http://localhost:8192/v1" +NUM_CONCURRENT = 100 +TASK = "gsm8k" +FILTER = "exact_match,strict-match" +RTOL = 0.03 + +# Model-specific expected values +EXPECTED_VALUES = { + "Qwen/Qwen3-0.6B": 0.41, +} + +SIMPLE_PROMPT = "The best part about working on vLLM is that I got to meet so many people across various different organizations like UCB, Google, and Meta which means", # noqa: E501 + +# Get model name from environment variable +MODEL_NAME = os.environ.get("TEST_MODEL", "Qwen/Qwen3-0.6B") + + +def run_simple_prompt(): + client = openai.OpenAI(api_key="EMPTY", base_url=BASE_URL) + completion = client.completions.create(model=MODEL_NAME, + prompt=SIMPLE_PROMPT) + + print("-" * 50) + print(f"Completion results for {MODEL_NAME}:") + print(completion) + print("-" * 50) + + +def test_accuracy(): + """Run the end to end accuracy test.""" + run_simple_prompt() + + model_args = (f"model={MODEL_NAME}," + f"base_url={BASE_URL}/completions," + f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") + + results = lm_eval.simple_evaluate( + model="local-completions", + model_args=model_args, + tasks=TASK, + ) + + measured_value = results["results"][TASK][FILTER] + expected_value = EXPECTED_VALUES.get(MODEL_NAME) + + if expected_value is None: + print(f"Warning: No expected value found for {MODEL_NAME}. " + "Skipping accuracy check.") + print(f"Measured value: {measured_value}") + return + + assert (measured_value - RTOL < expected_value + and measured_value + RTOL > expected_value + ), f"Expected: {expected_value} | Measured: {measured_value}" diff --git a/tests/v1/kv_connector/nixl_integration/test_edge_cases.py b/tests/v1/kv_connector/nixl_integration/test_edge_cases.py new file mode 100644 index 000000000000..5363fbde0096 --- /dev/null +++ b/tests/v1/kv_connector/nixl_integration/test_edge_cases.py @@ -0,0 +1,77 @@ +# SPDX-License-Identifier: Apache-2.0 +import os + +import openai + +PREFILL_PORT = os.getenv("PREFILL_PORT", None) +DECODE_PORT = os.getenv("DECODE_PORT", None) +PROXY_PORT = os.getenv("PROXY_PORT", None) + +if PREFILL_PORT is None or DECODE_PORT is None or PROXY_PORT is None: + raise ValueError( + "Please set the PREFILL_PORT, DECODE_PORT, and PROXY_PORT.") + +LONG_PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result, when working on projects like vLLM we are able to meet many amazing people from various organizations like AMD, Google, NVIDIA, " # noqa: E501 +PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. 
As a result," # noqa: E501 +SHORT_PROMPT = "Red Hat is " + + +def test_edge_cases(): + # Set the OpenAI API key and base URL + decode_client = openai.OpenAI( + api_key="MY_KEY", + base_url=f"http://localhost:{DECODE_PORT}/v1", + ) + prefill_client = openai.OpenAI( + api_key="MY_KEY", + base_url=f"http://localhost:{PREFILL_PORT}/v1", + ) + proxy_client = openai.OpenAI( + api_key="MY_KEY", + base_url=f"http://localhost:{PROXY_PORT}/v1", + ) + + # Get the list of models + models = decode_client.models.list() + MODEL = models.data[0].id + + # (1) Check that we can handle a very short prompt, + # less than the length of the block size. + completion = proxy_client.completions.create(model=MODEL, + prompt=SHORT_PROMPT, + temperature=0) + proxy_response = completion.choices[0].text + completion = prefill_client.completions.create(model=MODEL, + prompt=SHORT_PROMPT, + temperature=0) + prefill_response = completion.choices[0].text + print(f"SMALL PROMPT: {proxy_response=}") + assert proxy_response == prefill_response + + # (2) Check that we can handle a full prefix cache + # hit on the D worker but not on the P worker. + # (2a): prime the D worker. + completion = decode_client.completions.create(model=MODEL, + prompt=PROMPT, + temperature=0) + decode_response = completion.choices[0].text + # (2b): send via the P/D setup + completion = proxy_client.completions.create(model=MODEL, + prompt=PROMPT, + temperature=0) + proxy_response = completion.choices[0].text + print(f"FULL CACHE HIT: {proxy_response=}") + assert proxy_response == decode_response + + # (3) Check that we can handle a partial prefix cache + # hit on the D worker. + completion = proxy_client.completions.create(model=MODEL, + prompt=LONG_PROMPT, + temperature=0) + proxy_response = completion.choices[0].text + completion = prefill_client.completions.create(model=MODEL, + prompt=LONG_PROMPT, + temperature=0) + prefill_response = completion.choices[0].text + print(f"PARTIAL CACHE HIT: {proxy_response=}") + assert proxy_response == prefill_response diff --git a/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py new file mode 100644 index 000000000000..13071f581375 --- /dev/null +++ b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py @@ -0,0 +1,260 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import itertools +import os +import uuid +from contextlib import asynccontextmanager + +import httpx +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + Lifespan context manager to handle startup and shutdown events. 
+ """ + # Startup: Initialize client pools for prefiller and decoder services + app.state.prefill_clients = [] + app.state.decode_clients = [] + + # Create prefill clients + for i, (host, port) in enumerate(global_args.prefiller_instances): + prefiller_base_url = f'http://{host}:{port}/v1' + app.state.prefill_clients.append({ + 'client': + httpx.AsyncClient(timeout=None, base_url=prefiller_base_url), + 'host': + host, + 'port': + port, + 'id': + i + }) + + # Create decode clients + for i, (host, port) in enumerate(global_args.decoder_instances): + decoder_base_url = f'http://{host}:{port}/v1' + app.state.decode_clients.append({ + 'client': + httpx.AsyncClient(timeout=None, base_url=decoder_base_url), + 'host': + host, + 'port': + port, + 'id': + i + }) + + # Initialize round-robin iterators + app.state.prefill_iterator = itertools.cycle( + range(len(app.state.prefill_clients))) + app.state.decode_iterator = itertools.cycle( + range(len(app.state.decode_clients))) + + print(f"Initialized {len(app.state.prefill_clients)} prefill clients " + f"and {len(app.state.decode_clients)} decode clients.") + + yield + + # Shutdown: Close all clients + for client_info in app.state.prefill_clients: + await client_info['client'].aclose() + + for client_info in app.state.decode_clients: + await client_info['client'].aclose() + + +# Update FastAPI app initialization to use lifespan +app = FastAPI(lifespan=lifespan) + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--host", type=str, default="localhost") + + # For prefiller instances + parser.add_argument("--prefiller-hosts", + "--prefiller-host", + type=str, + nargs="+", + default=["localhost"]) + parser.add_argument("--prefiller-ports", + "--prefiller-port", + type=int, + nargs="+", + default=[8100]) + + # For decoder instances + parser.add_argument("--decoder-hosts", + "--decoder-host", + type=str, + nargs="+", + default=["localhost"]) + parser.add_argument("--decoder-ports", + "--decoder-port", + type=int, + nargs="+", + default=[8200]) + + args = parser.parse_args() + + # Validate and pair hosts with ports + if len(args.prefiller_hosts) != len(args.prefiller_ports): + raise ValueError( + "Number of prefiller hosts must match number of prefiller ports") + + if len(args.decoder_hosts) != len(args.decoder_ports): + raise ValueError( + "Number of decoder hosts must match number of decoder ports") + + # Create tuples of (host, port) for each service type + args.prefiller_instances = list( + zip(args.prefiller_hosts, args.prefiller_ports)) + args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports)) + + return args + + +def get_next_client(app, service_type: str): + """ + Get the next client in round-robin fashion. + + Args: + app: The FastAPI app instance + service_type: Either 'prefill' or 'decode' + + Returns: + The next client to use + """ + if service_type == 'prefill': + client_idx = next(app.state.prefill_iterator) + return app.state.prefill_clients[client_idx] + elif service_type == 'decode': + client_idx = next(app.state.decode_iterator) + return app.state.decode_clients[client_idx] + else: + raise ValueError(f"Unknown service type: {service_type}") + + +async def send_request_to_service(client_info: dict, endpoint: str, + req_data: dict, request_id: str): + """ + Send a request to a service using a client from the pool. 
+ """ + req_data = req_data.copy() + req_data['kv_transfer_params'] = { + "do_remote_decode": True, + "do_remote_prefill": False, + "remote_engine_id": None, + "remote_block_ids": None, + "remote_host": None, + "remote_port": None + } + req_data["stream"] = False + req_data["max_tokens"] = 1 + if "stream_options" in req_data: + del req_data["stream_options"] + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id + } + + response = await client_info['client'].post(endpoint, + json=req_data, + headers=headers) + response.raise_for_status() + + return response + + +async def stream_service_response(client_info: dict, endpoint: str, + req_data: dict, request_id: str): + """ + Asynchronously stream response from a service using a client from the pool. + """ + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id + } + + async with client_info['client'].stream("POST", + endpoint, + json=req_data, + headers=headers) as response: + response.raise_for_status() + async for chunk in response.aiter_bytes(): + yield chunk + + +@app.post("/v1/completions") +async def handle_completions(request: Request): + try: + req_data = await request.json() + request_id = str(uuid.uuid4()) + + # Get the next prefill client in round-robin fashion + prefill_client_info = get_next_client(request.app, 'prefill') + + # Send request to prefill service + response = await send_request_to_service(prefill_client_info, + "/completions", req_data, + request_id) + + # Extract the needed fields + response_json = response.json() + kv_transfer_params = response_json.get('kv_transfer_params', {}) + if kv_transfer_params: + req_data["kv_transfer_params"] = kv_transfer_params + + # Get the next decode client in round-robin fashion + decode_client_info = get_next_client(request.app, 'decode') + + logger.debug("Using %s %s", prefill_client_info, decode_client_info) + + # Stream response from decode service + async def generate_stream(): + async for chunk in stream_service_response(decode_client_info, + "/completions", + req_data, + request_id=request_id): + yield chunk + + return StreamingResponse(generate_stream(), + media_type="application/json") + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server" + " - completions endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + raise + + +@app.get("/healthcheck") +async def healthcheck(): + """Simple endpoint to check if the server is running.""" + return { + "status": "ok", + "prefill_instances": len(app.state.prefill_clients), + "decode_instances": len(app.state.decode_clients) + } + + +if __name__ == '__main__': + global global_args + global_args = parse_args() + + import uvicorn + uvicorn.run(app, host=global_args.host, port=global_args.port) diff --git a/tests/v1/kv_connector/unit/__init__.py b/tests/v1/kv_connector/unit/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py new file mode 100644 index 000000000000..9b2a720c11c4 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( + NixlConnectorMetadata) + +from .utils import create_request, create_scheduler, create_vllm_config + + +def 
test_basic_inferface(): + """Unit test for basic NixlConnector interface functionality.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True) + request_id = request.request_id + + scheduler.add_request(request) + + # Remote Prefill, triggers NixlConnectorMetdata. + scheduler_output = scheduler.schedule() + kv_connector_metadata = scheduler_output.kv_connector_metadata + assert kv_connector_metadata is not None + assert isinstance(kv_connector_metadata, NixlConnectorMetadata) + + assert len(kv_connector_metadata.requests) == 1 + assert request_id in kv_connector_metadata.requests + req_meta = kv_connector_metadata.requests[request_id] + + for block_id, block in zip( + req_meta.local_block_ids, scheduler.kv_cache_manager. + single_type_manager.req_to_blocks[request_id]): + assert block_id == block.block_id + + +def test_prompt_less_than_block_size(): + """ + Test that we can handle case where prompt is < block. + + In this case, the P worker will send empty remote_block_ids. + The D worker should not schedule an async read in this case, + since there is nothing to pull. + """ + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # Half of a block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_TOKENS = int(BLOCK_SIZE * 0.5) + + # Request will have 0 remote blocks. + request = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True, + num_remote_blocks=0) + scheduler.add_request(request) + scheduler_output = scheduler.schedule() + + # This request should not have to read async. + kv_connector_metadata = scheduler_output.kv_connector_metadata + assert kv_connector_metadata is not None + assert isinstance(kv_connector_metadata, NixlConnectorMetadata) + assert len(kv_connector_metadata.requests) == 0 + + # This request should be scheduled regularly. + assert len(scheduler_output.scheduled_new_reqs) == 1 diff --git a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py new file mode 100644 index 000000000000..77098140343a --- /dev/null +++ b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py @@ -0,0 +1,181 @@ +# SPDX-License-Identifier: Apache-2.0 +import copy + +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT +from vllm.v1.request import FinishReason, RequestStatus + +from .utils import (assert_scheduler_empty, create_model_runner_output, + create_request, create_scheduler, create_vllm_config) + + +def test_basic_lifecycle(): + """Test lifecycle of a Remote Decode request.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request = create_request(request_id=1, + max_tokens=1, + num_tokens=NUM_TOKENS, + do_remote_decode=True) + + scheduler.add_request(request) + request_id = request.request_id + + # STEP (1): Prefill. 
+ # (1a): schedule() + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 1 + + # (1b): execute_model() + model_runner_output = create_model_runner_output(reqs=[request]) + + # (1c): update_from_output() + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) + + # Ensure the request is finished after 1 tokens. + assert request.is_finished() + assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED + output = engine_core_outputs.outputs[0] + assert output.finish_reason == FinishReason.LENGTH + assert output.kv_transfer_params is not None + + # Request freed in Scheduler and in Persistent Batch ... + assert request_id in scheduler.finished_req_ids + assert len(scheduler.running) == 0 + assert len(scheduler.waiting) == 0 + + # ... but blocks should not be freed. + blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[ + request_id] + for block in blocks: + assert block.ref_cnt == 1 + + # STEP (2): Send Finished to PB. + # (2a): schedule() - pass finished request to PB. + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 0 + assert len(scheduler_output.finished_req_ids) == 1 + assert request_id in scheduler_output.finished_req_ids + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 0 + assert len(scheduler.finished_req_ids) == 0 + + # (2b): execute_model() + model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT + + # (2c): update_from_output() + scheduler.update_from_output(scheduler_output, model_runner_output) + + # STEP (3): Finished sending. + # (3a): schedule() - pass finished request to PB. + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 0 + assert len(scheduler_output.finished_req_ids) == 0 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 0 + assert len(scheduler.finished_req_ids) == 0 + + # (3b): execute_model() + model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + model_runner_output.finished_sending = [request_id] + + # (3c): update_from_output() + scheduler.update_from_output(scheduler_output, model_runner_output) + + # Confirm we do not have any memory leaks after req lifecycle. + assert_scheduler_empty(scheduler) + + +def test_short_prompt_lifecycle(): + """Test lifecycle of a Remote Decode request with short prompt.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # Not enough tokens for full block. + NUM_TOKENS = vllm_config.cache_config.block_size // 2 + request = create_request(request_id=1, + max_tokens=1, + num_tokens=NUM_TOKENS, + do_remote_decode=True) + + scheduler.add_request(request) + + # STEP (1): Prefill. + # (1a): schedule() + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 1 + + # (1b): execute_model() + model_runner_output = create_model_runner_output(reqs=[request]) + + # (1c): update_from_output() + # Since tokens < block_size, there will be no kv xfer. + # So this should be cleaned up immediately. + _ = scheduler.update_from_output(scheduler_output, model_runner_output) + + # Confirm we do not have any memory leaks after req lifecycle. + # We need one more call to schedule() to clear data for persistent batch. 
+ _ = scheduler.schedule() + assert_scheduler_empty(scheduler) + + +def test_prefix_cache_lifecycle(): + """Test that remote decode params still works with a prefix cache hit.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # Prime the KVCache. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 3 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS) + + scheduler.add_request(request_normal) + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_normal], + use_eos=True) + scheduler.update_from_output(scheduler_output, model_runner_output) + scheduler.schedule() + scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT) + + ##################### + # Actual Test: confirm we send all blocks. + + # Step (1): Send the KV Transfer. + NUM_EXTERNAL_FULL_BLOCKS -= 1 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request_remote = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_decode=True) + + scheduler.add_request(request_remote) + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_remote]) + eco = scheduler.update_from_output(scheduler_output, model_runner_output) + kv_transfer_params = eco.outputs[0].kv_transfer_params + + # Ensure we send all block ids, even if there is a cache hit. + assert (len( + kv_transfer_params["remote_block_ids"]) == NUM_EXTERNAL_FULL_BLOCKS) + + # STEP (2): Ensure it is freed. + scheduler_output = scheduler.schedule() + scheduler.schedule() + model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + model_runner_output.finished_sending = [request_remote.request_id] + scheduler.update_from_output(scheduler_output, model_runner_output) + _ = scheduler.schedule() + assert_scheduler_empty(scheduler) diff --git a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py new file mode 100644 index 000000000000..fc4928f9ebd1 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py @@ -0,0 +1,342 @@ +# SPDX-License-Identifier: Apache-2.0 +import copy + +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT +from vllm.v1.request import FinishReason, RequestStatus + +from .utils import (assert_scheduler_empty, create_model_runner_output, + create_request, create_scheduler, create_vllm_config) + + +def test_basic_lifecycle(): + """Test lifecycle of a remote prefill.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + START_FREE_BLOCK_QUEUE_SIZE = ( + scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) + + request = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True) + + scheduler.add_request(request) + request_id = request.request_id + + # STEP (1): + # (1a): schedule() + scheduler_output = scheduler.schedule() + + # Nothing running and empty scheduler output. 
+ assert len(scheduler.running) == 0 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 0 + assert len(scheduler_output.num_scheduled_tokens) == 0 + assert scheduler_output.total_num_scheduled_tokens == 0 + + # Req waiting for KVs with no computed/scheduled toks ... + assert len(scheduler.waiting) == 1 + assert request in scheduler.waiting + assert (request.status == RequestStatus.WAITING_FOR_REMOTE_KVS) + assert (request.num_computed_tokens == 0) + + # ... but should have (uncached) blocks allocated to it. + block_pool = scheduler.kv_cache_manager.block_pool + assert (block_pool.free_block_queue.num_free_blocks + < START_FREE_BLOCK_QUEUE_SIZE) + assert len(block_pool.cached_block_hash_to_block) == 0 + blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[ + request_id] + for block in blocks: + assert block._block_hash is None + + # (1b): forward() + model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT + + # (1c): update_from_output() + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) + assert len(engine_core_outputs.outputs) == 0 + + # STEP (2): + # (2a): schedule(): nothing happens! + scheduler_output = scheduler.schedule() + assert len(scheduler.waiting) == 1 + assert len(scheduler.running) == 0 + + # (2b): forward(): request finishes recv. + model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + model_runner_output.finished_recving = [request_id] + + # (2c): update_from_output(): + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) + assert len(scheduler.waiting) == 1 + assert (request_id in scheduler.finished_recving_kv_req_ids) + + # STEP (3): + # (3a): schedule(): this should actually schedule. + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + + # Confirm the block are actually allocated. + num_hashed_blocks = 0 + blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[ + request_id] + for block in blocks: + assert block.ref_cnt == 1 + num_hashed_blocks += (1 if block._block_hash is not None else 0) + assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS + + # Confirm the rest of the prompt is scheduled in this step. + scheduled_req = scheduler_output.scheduled_new_reqs[0] + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[request_id] + num_computed_tokens = scheduled_req.num_computed_tokens + total_prompt_tokens = len(scheduled_req.prompt_token_ids) + assert (num_scheduled_tokens == total_prompt_tokens - num_computed_tokens) + + # (3b): execute_model() + model_runner_output = create_model_runner_output([request]) + # (3c): update_from_output() + scheduler.update_from_output(scheduler_output, model_runner_output) + + # Step (4): Hit EOS. + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output([request], use_eos=True) + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) + scheduler.schedule() + + outputs = engine_core_outputs.outputs + assert len(outputs) == 1 + output = outputs[0] + assert output.finish_reason == FinishReason.STOP + assert_scheduler_empty(scheduler) + + +def test_interleaved_lifecycle(): + """Test Remote Prefills Work Well With Other Requests.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. 
+ BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request_remote = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True) + request_local_a = create_request( + request_id=2, + num_tokens=NUM_TOKENS, + ) + request_local_b = create_request( + request_id=3, + num_tokens=NUM_TOKENS, + ) + + # STEP 1: Regular request is running. + scheduler.add_request(request_local_a) + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + + model_runner_output = create_model_runner_output([request_local_a]) + scheduler.update_from_output(scheduler_output, model_runner_output) + + # STEP 2: Add a local and remote request. + scheduler.add_request(request_local_b) + scheduler.add_request(request_remote) + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 2 + assert len(scheduler.waiting) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 1 + assert len(scheduler_output.scheduled_cached_reqs) == 1 + + model_runner_output = create_model_runner_output( + [request_local_a, request_local_b]) + scheduler.update_from_output(scheduler_output, model_runner_output) + + # STEP 3: continue running, KVs not arrived yet. + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 2 + assert len(scheduler.waiting) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 2 + + model_runner_output = create_model_runner_output( + reqs=[request_local_a, request_local_b]) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 2 + assert len(scheduler.waiting) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 2 + + # STEP 4: KVs arrive. + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 2 + assert len(scheduler.waiting) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 2 + + model_runner_output = create_model_runner_output( + [request_local_a, request_local_b], + finished_recving=[request_remote.request_id]) + scheduler.update_from_output(scheduler_output, model_runner_output) + + # STEP 5: RECVed KVs are sent to ModelRunner. + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 3 + assert len(scheduler.waiting) == 0 + assert len(scheduler_output.scheduled_new_reqs) == 1 + assert len(scheduler_output.scheduled_cached_reqs) == 2 + + model_runner_output = create_model_runner_output( + [request_local_a, request_local_b, request_remote]) + scheduler.update_from_output(scheduler_output, model_runner_output) + + # STEP 6: Hit EOS and free. + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output( + [request_local_a, request_local_b, request_remote], + use_eos=True, + ) + scheduler.update_from_output(scheduler_output, model_runner_output) + scheduler.schedule() + assert_scheduler_empty(scheduler) + + +def test_no_spurious_prefix_caching(): + """ + With P/D, blocks can be allocated but uncomputed for + multiple engine steps. This test confirms that we do + not accidentally have cache hits against uncomputed + blocks. 
+ """ + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 and a half full external blocks. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + # Both of these requests have prompts like [1,1,1,1,1, ...] + request_remote = create_request( + request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True, + use_all_1s_for_prompt_tokens=True, + ) + + request_local = create_request( + request_id=2, + num_tokens=NUM_TOKENS, + do_remote_prefill=False, + use_all_1s_for_prompt_tokens=True, + ) + + # Schedule the remote prefill request. This should not + # cause any blocks to be cached. + scheduler.add_request(request_remote) + scheduler_output = scheduler.schedule() + scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT) + assert len(scheduler.waiting) == 1 + + # Schedule the local prefill request. This should + # cause blocks to be cached, but separately from + scheduler.add_request(request_local) + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 1 + + local_blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[ + request_local.request_id] + remote_blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[ # noqa: E501 + request_remote.request_id] + + # Local should have cached blocks (but not all due to preallocate). + num_hashed_blocks = 0 + for block in local_blocks: + assert block.ref_cnt == 1 + num_hashed_blocks += (1 if block._block_hash is not None else 0) + assert num_hashed_blocks > 0 + + # Remote blocks should not be cached. + for block in remote_blocks: + assert block.ref_cnt == 1 + assert block._block_hash is None + + +def test_full_block_prompt(): + """Test that we handle a prompt that is the full block size.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS) + + request = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True) + + scheduler.add_request(request) + request_id = request.request_id + + # STEP (1): Initialize a recv. + scheduler_output = scheduler.schedule() + # All blocks should be allocated. + num_blocks = len(scheduler.kv_cache_manager.single_type_manager. + req_to_blocks[request_id]) + assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS + model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT + scheduler.update_from_output(scheduler_output, model_runner_output) + + # # STEP (2): Recv. + scheduler_output = scheduler.schedule() + model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + model_runner_output.finished_recving = [request_id] + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.waiting) == 1 + assert (request_id in scheduler.finished_recving_kv_req_ids) + + # # STEP (3): Run as usual. + scheduler_output = scheduler.schedule() + + # We need to recompute the final token of the prompt to generate + # the first new token, so we should not have a new block. + num_blocks = len(scheduler.kv_cache_manager.single_type_manager. 
+ req_to_blocks[request_id]) + assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS + assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens == + NUM_TOKENS - 1) + assert (scheduler_output.num_scheduled_tokens[request_id] == 1) + + model_runner_output = create_model_runner_output([request]) + scheduler.update_from_output(scheduler_output, model_runner_output) + + # # Step (4): Hit EOS. + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output([request], use_eos=True) + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) + scheduler.schedule() + + outputs = engine_core_outputs.outputs + assert len(outputs) == 1 + output = outputs[0] + assert output.finish_reason == FinishReason.STOP + assert_scheduler_empty(scheduler) diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py new file mode 100644 index 000000000000..8a7d7bdd83da --- /dev/null +++ b/tests/v1/kv_connector/unit/utils.py @@ -0,0 +1,190 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import torch + +from vllm import SamplingParams +from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig, + ModelConfig, SchedulerConfig, VllmConfig) +from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( + NixlKVTransferParams) +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheGroupSpec) +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.request import Request +from vllm.v1.structured_output import StructuredOutputManager + +EOS_TOKEN_ID = 50256 + + +def assert_scheduler_empty(scheduler: Scheduler): + """Confirm the scheduler is "empty" - i.e. no leaks.""" + # Scheduler Metadata. + assert len(scheduler.requests) == 0 + assert len(scheduler.waiting) == 0 + assert len(scheduler.running) == 0 + assert len(scheduler.finished_req_ids) == 0 + assert len(scheduler.finished_recving_kv_req_ids) == 0 + assert len(scheduler._cached_reqs_data) == 0 + + # EncoderCacheManager. + assert len(scheduler.encoder_cache_manager.freed) == 0 + assert len(scheduler.encoder_cache_manager.cached) == 0 + + # KVCache Manager. + assert len( + scheduler.kv_cache_manager.single_type_manager.req_to_blocks) == 0 + assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0 + assert len( + scheduler.kv_cache_manager.single_type_manager.num_cached_block) == 0 + num_free_blocks = ( + scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) + assert num_free_blocks == ( + scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1) + + # NOTE(rob): just the ref count on blocks will be 0. The hash + # value, etc will remain since we lazily evict for prefix cache. 
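+    # A leaked request would show up here as a block with ref_cnt > 0
+    # even though every request has finished, so this loop is the main
+    # leak check.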
+ for block in scheduler.kv_cache_manager.block_pool.blocks: + assert block.ref_cnt == 0 + + +def create_vllm_config( + model: str = "facebook/opt-125m", + max_num_seqs: int = 16, + max_num_batched_tokens: int = 64, + block_size: int = 16, +) -> VllmConfig: + """Initialize VllmConfig For Testing.""" + scheduler_config = SchedulerConfig( + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + max_model_len=max_num_batched_tokens, + ) + model_config = ModelConfig( + model=model, + task="auto", + tokenizer=model, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="float16", + seed=42, + ) + # Cache config, optionally force APC + cache_config = CacheConfig( + block_size=block_size, + gpu_memory_utilization=0.9, + swap_space=0, + cache_dtype="auto", + enable_prefix_caching=True, + ) + kv_transfer_config = KVTransferConfig( + kv_connector="NixlConnector", + kv_role="kv_both", + ) + return VllmConfig(scheduler_config=scheduler_config, + model_config=model_config, + cache_config=cache_config, + kv_transfer_config=kv_transfer_config, + device_config=DeviceConfig("cpu")) + + +def create_scheduler( + vllm_config: VllmConfig, + num_blocks: int = 10000, +) -> Scheduler: + """Initialize Scheduler For Testing.""" + block_size = vllm_config.cache_config.block_size + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, # A large number of blocks to hold all requests + tensors={}, + kv_cache_groups=[ + KVCacheGroupSpec(['layer'], + FullAttentionSpec(block_size, 1, 1, torch.float32, + False)) + ], + ) + vllm_config.cache_config.num_gpu_blocks = num_blocks + return Scheduler( + vllm_config=vllm_config, + kv_cache_config=kv_cache_config, + log_stats=True, + structured_output_manager=StructuredOutputManager(vllm_config), + ) + + +def create_request( + request_id: int, + num_tokens: int = 10, + max_tokens: int = 16, + do_remote_decode: bool = False, + do_remote_prefill: bool = False, + use_all_1s_for_prompt_tokens: bool = False, + num_remote_blocks: int = 3, +) -> Request: + """Make dummy request for testing.""" + + if do_remote_decode: + assert not do_remote_prefill + kv_transfer_params = NixlKVTransferParams(do_remote_prefill=False, + do_remote_decode=True) + elif do_remote_prefill: + kv_transfer_params = NixlKVTransferParams( + do_remote_prefill=True, + do_remote_decode=False, + remote_engine_id="my-engine-id", + remote_block_ids=list(range(num_remote_blocks)), + remote_host="my-host", + remote_port=1234) + else: + kv_transfer_params = None + + max_tokens = 1 if do_remote_decode else max_tokens + sampling_params = SamplingParams(max_tokens=max_tokens) + + if use_all_1s_for_prompt_tokens: + prompt_token_ids = [1] * num_tokens + else: + prompt_token_ids = [i * request_id for i in range(num_tokens)] + + req = Request( + request_id=f"id-{request_id}", + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params, + multi_modal_inputs=None, + multi_modal_placeholders=None, + multi_modal_hashes=None, + eos_token_id=EOS_TOKEN_ID, + arrival_time=0, + ) + req.kv_transfer_params = kv_transfer_params + return req + + +def create_model_runner_output( + reqs: list[Request], + finished_sending: Optional[list[str]] = None, + finished_recving: Optional[list[str]] = None, + use_eos: bool = False, +) -> ModelRunnerOutput: + """Make dummy model runner output for testing.""" + + # Make request data. + req_ids = [req.request_id for req in reqs] + req_id_to_index = {req_id: idx for idx, req_id in enumerate(req_ids)} + + # Make sampled tokens. 
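+    # Each request "generates" exactly one token for this step: token 0
+    # to keep the request running, or EOS_TOKEN_ID so the scheduler
+    # finishes it when use_eos=True.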
+ sampled_token = EOS_TOKEN_ID if use_eos else 0 + sampled_token_ids = [[sampled_token] for _ in req_ids] + + # Make output data structure. + return ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_id_to_index, + sampled_token_ids=sampled_token_ids, + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + finished_sending=finished_sending, + finished_recving=finished_recving, + ) diff --git a/vllm/config.py b/vllm/config.py index 4a503665503a..c6b97bbdcd66 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -8,6 +8,7 @@ import json import re import textwrap +import uuid import warnings from collections import Counter from contextlib import contextmanager @@ -3438,6 +3439,9 @@ class KVTransferConfig: """The KV connector for vLLM to transmit KV caches between vLLM instances. """ + engine_id: str = str(uuid.uuid4()) + """The engine id for KV transfers.""" + kv_buffer_device: Optional[str] = "cuda" """The device used by kv connector to buffer the KV cache. Currently only support 'cuda'.""" @@ -3448,7 +3452,7 @@ class KVTransferConfig: kv_role: Optional[KVRole] = None """Whether this vLLM instance produces, consumes KV cache, or both. Choices - are 'kv_producer', 'kv_consumer', and 'both'.""" + are 'kv_producer', 'kv_consumer', and 'kv_both'.""" kv_rank: Optional[int] = None """The rank of this vLLM instance in the KV cache transfer. Typical value: diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 6532c101a4f6..54cb1871db3c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -105,3 +105,8 @@ def create_connector_v1( "LMCacheConnectorV1", "vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector", "LMCacheConnectorV1") + +KVConnectorFactory.register_connector( + "NixlConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector", + "NixlConnector") diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py b/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py index a017b140e090..43181ab79afc 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py @@ -1,8 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorRole) + KVConnectorBase_V1, KVConnectorRole, KVTransferParams) -__all__ = [ - "KVConnectorRole", - "KVConnectorBase_V1", -] +__all__ = ["KVConnectorRole", "KVConnectorBase_V1", "KVTransferParams"] diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 95967d2ca919..2ff61e8a400f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -23,7 +23,7 @@ import enum from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Optional import torch @@ -34,6 +34,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import VllmConfig from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request logger = init_logger(__name__) @@ -47,12 +48,34 @@ class KVConnectorRole(enum.Enum): WORKER = 1 +class KVTransferParams: + """ + Abstract KVTransferParams used to send KVTransfer + parameters between instances 
of vLLM. + + Specific instances of KVConnector customize this + method for serializing / deserializing msgs sent + via the HTTP protocol. + """ + + @staticmethod + def from_raw_dict( + raw_dict: Optional[dict[str, + Any]]) -> Optional["KVTransferParams"]: + return None + + @dataclass class KVConnectorMetadata: + """ + Abstract Metadata used to communicate between the + Scheduler KVConnector and Worker KVConnector. + """ pass class KVConnectorBase_V1(ABC): + _KVTransferParams = KVTransferParams def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): logger.warning( @@ -66,6 +89,10 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): def role(self) -> KVConnectorRole: return self._role + # ============================== + # Worker-side methods + # ============================== + def bind_connector_metadata( self, connector_metadata: KVConnectorMetadata) -> None: """Set the connector metadata from the scheduler. @@ -97,9 +124,15 @@ def _get_connector_metadata(self) -> KVConnectorMetadata: """ return self._connector_metadata - # ============================== - # Worker-side methods - # ============================== + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """ + Initialize with the KV caches. Useful for pre-registering the + KV Caches in the KVConnector (e.g. for NIXL). + + Args: kv_caches: + dictionary of layer names, kv cache + """ + return @abstractmethod def start_load_kv(self, forward_context: "ForwardContext", @@ -162,15 +195,37 @@ def wait_for_save(self): """ pass + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + """ + Notifies worker-side connector ids of requests that have + finished generating tokens. + + Returns: + ids of requests that have finished asynchronous (recving, sending). + The finished saves/sends req ids must belong to a set provided in a + call to this method (this call or a prior one). + """ + return None, None + # ============================== # Scheduler-side methods # ============================== + + def set_kv_transfer_params(self, request: "Request"): + """Parse raw KV Transfer params.""" + assert request.kv_transfer_params is None + kv_transfer_params = self._KVTransferParams.from_raw_dict( + request.raw_kv_transfer_params) + request.kv_transfer_params = kv_transfer_params + @abstractmethod def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int, - ) -> int: + ) -> tuple[int, bool]: """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. @@ -181,13 +236,16 @@ def get_num_new_matched_tokens( computed tokens for this request Returns: - the number of tokens that can be loaded from the - external KV cache beyond what is already computed. + * the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + * true if external KV cache tokens will be loaded + asynchronously (between scheduler steps). """ pass @abstractmethod def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", num_external_tokens: int): """ Update KVConnector state after block allocation. @@ -207,3 +265,20 @@ def build_connector_meta( scheduler_output (SchedulerOutput): the scheduler output object. """ pass + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Called when a request has finished, before its blocks are freed. 
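+
+        For example, a P/D prefill-side connector can return
+        (True, kv_transfer_params) for a request whose blocks a decode
+        instance still needs to pull, so the blocks stay allocated until
+        the transfer is confirmed via get_finished(); connectors that do
+        not send KV this way return (False, None) and the blocks are
+        freed immediately.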
+ + Returns: + True if the request is being saved/sent asynchronously and blocks + should not be freed until the request_id is returned from + get_finished(). + Optional KVTransferParams to be included in the request outputs + returned by the engine. + """ + return False, None diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index e07f185f0dd8..2cb68dc1ff67 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -13,6 +13,7 @@ if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request logger = init_logger(__name__) @@ -92,7 +93,7 @@ def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int, - ) -> int: + ) -> tuple[int, bool]: """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. @@ -107,9 +108,10 @@ def get_num_new_matched_tokens( external KV cache beyond what is already computed. """ return self._lmcache_engine.get_num_new_matched_tokens( - request, num_computed_tokens) + request, num_computed_tokens), False def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", num_external_tokens: int): """ Update KVConnector state after block allocation. diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py new file mode 100644 index 000000000000..d26184982270 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -0,0 +1,805 @@ +# SPDX-License-Identifier: Apache-2.0 +import contextlib +import math +import threading +import time +import uuid +from collections import defaultdict +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Iterator + +import msgspec +import torch +import zmq +from typing_extensions import Optional + +from vllm import envs +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, KVTransferParams) +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, + get_tp_group) +from vllm.logger import init_logger +from vllm.utils import round_down +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.request import RequestStatus + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.request import Request + +GET_META_MSG = b"get_meta_msg" + +logger = init_logger(__name__) + +# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used +try: + from nixl._api import nixl_agent as NixlWrapper + logger.info("NIXL is available") +except ImportError: + logger.warning("NIXL is not available") + NixlWrapper = None + + +@dataclass +class NixlKVTransferParams(KVTransferParams): + + def __init__( + self, + do_remote_prefill: bool, + do_remote_decode: bool, + remote_block_ids: Optional[list[int]] = None, + remote_host: Optional[str] = None, + remote_port: Optional[int] = None, + remote_engine_id: Optional[str] = None, + ): + self.do_remote_prefill = 
do_remote_prefill + self.do_remote_decode = do_remote_decode + self.remote_block_ids = remote_block_ids + self.remote_host = remote_host + self.remote_port = remote_port + self.remote_engine_id = remote_engine_id + + @staticmethod + def from_raw_dict( + raw_dict: Optional[dict[str, + Any]]) -> Optional["NixlKVTransferParams"]: + + # If no raw transfer params passed, return None. + if raw_dict is None: + return None + + # Validate the request is formatted properly. + if (("do_remote_prefill" not in raw_dict) + or ("do_remote_decode" not in raw_dict) + or ("remote_block_ids" not in raw_dict) + or ("remote_host" not in raw_dict) + or ("remote_port" not in raw_dict) + or ("remote_engine_id" not in raw_dict)): + logger.warning( + "Got invalid KVTransferParams: %s. This " + "request will not utilize KVTransfer", raw_dict) + return None + + return NixlKVTransferParams( + do_remote_prefill=raw_dict["do_remote_prefill"], + do_remote_decode=raw_dict["do_remote_decode"], + remote_block_ids=raw_dict["remote_block_ids"], + remote_host=raw_dict["remote_host"], + remote_port=raw_dict["remote_port"], + remote_engine_id=raw_dict["remote_engine_id"], + ) + + +class NixlAgentMetadata( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. + dict=True): + engine_id: str + agent_metadata: bytes + kv_caches_base_addr: list[int] + num_blocks: int + + +@dataclass +class ReqMeta: + local_block_ids: list[int] + remote_block_ids: list[int] + remote_host: str + remote_port: int + remote_engine_id: str + + +class NixlConnectorMetadata(KVConnectorMetadata): + + def __init__(self): + self.requests: dict[str, ReqMeta] = {} + + def add_new_req( + self, + request_id: str, + local_block_ids: list[int], + kv_transfer_params: NixlKVTransferParams, + ): + assert request_id not in self.requests + assert kv_transfer_params.remote_block_ids is not None + assert kv_transfer_params.remote_engine_id is not None + assert kv_transfer_params.remote_host is not None + assert kv_transfer_params.remote_port is not None + + self.requests[request_id] = ReqMeta( + local_block_ids=local_block_ids, + remote_block_ids=kv_transfer_params.remote_block_ids, + remote_engine_id=kv_transfer_params.remote_engine_id, + remote_host=kv_transfer_params.remote_host, + remote_port=kv_transfer_params.remote_port, + ) + + +class NixlConnector(KVConnectorBase_V1): + _KVTransferParams: type[NixlKVTransferParams] = NixlKVTransferParams + + def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + assert vllm_config.kv_transfer_config is not None + self.engine_id = vllm_config.kv_transfer_config.engine_id + + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler : Optional[NixlConnectorScheduler] = \ + NixlConnectorScheduler(vllm_config, str(self.engine_id)) + self.connector_worker: Optional[NixlConnectorWorker] = None + elif role == KVConnectorRole.WORKER: + self.connector_scheduler = None + self.connector_worker = NixlConnectorWorker(str(self.engine_id)) + + ############################################################ + # Scheduler Side Methods + ############################################################ + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + assert 
self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, blocks, num_external_tokens) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, block_ids) + + ############################################################ + # Worker Side Methods + ############################################################ + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def get_finished(self, + finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + """Get the finished recving and sending requests.""" + assert self.connector_worker is not None + return self.connector_worker.get_finished() + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, NixlConnectorMetadata) + self.connector_worker.start_load_kv(self._connector_metadata) + + def wait_for_layer_load(self, layer_name: str) -> None: + """NixlConnector does not do layerwise saving.""" + pass + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """NixlConnector does not save explicitly.""" + pass + + def wait_for_save(self): + """NixlConnector does not save explicitly.""" + pass + + +class NixlConnectorScheduler: + """Implementation of Scheduler side methods""" + + def __init__(self, vllm_config: VllmConfig, engine_id: str): + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + self.engine_id = engine_id + logger.info("Initializing NIXL Scheduler %s", engine_id) + + # Requests that need to start recv. + # New requests are added by update_state_after_alloc in + # the scheduler. Used to make metadata passed to Worker. + self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {} + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + """ + For remote prefill, pull all prompt blocks from remote + asynchronously relative to engine execution. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + Returns: + * the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + * true if the external KV cache tokens will be loaded + asynchronously (between scheduler steps). + """ + + # No KVTransfer for this request. + if request.kv_transfer_params is None: + return 0, False + assert isinstance(request.kv_transfer_params, NixlKVTransferParams) + + # Remote prefill: get all prompt blocks from remote. 
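+        # Only whole blocks are pulled from the remote side; any trailing
+        # partial block of the prompt is recomputed locally. E.g. with
+        # block_size=16, a 40-token prompt and no local hit,
+        # round_down(40, 16) == 32, so 32 tokens are loaded asynchronously
+        # and the remaining 8 are prefilled here.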
+ if request.kv_transfer_params.do_remote_prefill: + assert num_computed_tokens % self.block_size == 0 + rounded_num_prompt_tokens = round_down( + len(request.prompt_token_ids), self.block_size) + count = max(rounded_num_prompt_tokens - num_computed_tokens, 0) + return count, count > 0 + + return 0, False + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + if request.kv_transfer_params is None: + return + + assert isinstance(request.kv_transfer_params, NixlKVTransferParams) + if request.kv_transfer_params.do_remote_prefill: + # NOTE(rob): if prompt < block_size, no remote blocks + # since the remote only sends fully computed blocks, so + # skip recving for this request. num_external_tokens + # should be 0 if there are no remote blocks. + if request.kv_transfer_params.remote_block_ids: + # Get unhashed blocks to pull from remote. + self._reqs_need_recv[request.request_id] = ( + request, blocks.get_unhashed_block_ids()) + else: + assert num_external_tokens == 0 + # Only trigger 1 KV transfer per request. + request.kv_transfer_params.do_remote_prefill = False + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + meta = NixlConnectorMetadata() + + # Loop through scheduled reqs and convert to ReqMeta. + for req_id, (req, block_ids) in self._reqs_need_recv.items(): + assert isinstance(req.kv_transfer_params, NixlKVTransferParams) + meta.add_new_req( + request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + ) + + # Clear the list once workers start the transfers + self._reqs_need_recv.clear() + + return meta + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Once a request is finished, determine whether request blocks + should be freed now or will be sent asynchronously and freed later. + """ + + if request.kv_transfer_params is None: + return False, None + assert isinstance(request.kv_transfer_params, NixlKVTransferParams) + + if ((not request.kv_transfer_params.do_remote_decode) + or (request.status != RequestStatus.FINISHED_LENGTH_CAPPED)): + return False, None + + # Get computed blocks. + all_full = request.num_computed_tokens % self.block_size == 0 + computed_block_ids = (block_ids if all_full else block_ids[:-1]) + + # If prompt < block_size, no xfer so free blocks immediately. + delay_free_blocks = len(computed_block_ids) > 0 + + return delay_free_blocks, NixlKVTransferParams( + do_remote_prefill=True, + do_remote_decode=False, + remote_block_ids=computed_block_ids, + remote_engine_id=self.engine_id, + remote_host=envs.VLLM_NIXL_SIDE_CHANNEL_HOST, + remote_port=envs.VLLM_NIXL_SIDE_CHANNEL_PORT, + ).__dict__ + + +class NixlConnectorWorker: + """Implementation of Worker side methods""" + + def __init__(self, engine_id: str): + if NixlWrapper is None: + logger.error("NIXL is not available") + raise RuntimeError("NIXL is not available") + logger.info("Initializing NIXL wrapper") + logger.info("Initializing NIXL worker %s", engine_id) + + # Agent. + self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) + # Map of engine_id -> agent_name. + self._remote_agents: dict[str, str] = {} + + # Metadata. + self.engine_id = engine_id + self.rank = get_tensor_model_parallel_rank() + self.world_size = get_tensor_model_parallel_world_size() + self.tp_group = get_tp_group() + + # KV Caches and nixl tracking data. 
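+        # Addressing model used throughout this worker: one registered
+        # region per K/V cache tensor, with block i of a region living at
+        # base_addr + i * block_len, so a single transfer descriptor is
+        # just (addr, block_len, rank).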
+ self.kv_caches: dict[str, torch.Tensor] = {} + + # Map of engine_id -> kv_caches_base_addr + self.kv_caches_base_addr: dict[str, list[int]] = {} + + # Number of NIXL regions. Currently one region per cache + # (so 1 per layer for MLA, otherwise 2 per layer) + self.num_regions = 0 + + # nixl_prepped_dlist_handle (int). + self.src_xfer_side_handle: int = 0 + # Map of engine_id -> nixl_prepped_dlist_handle (int)]. + self.dst_xfer_side_handles: dict[str, int] = {} + + # Map of engine_id -> num_blocks. + self.dst_num_blocks: dict[str, int] = {} + self._registered_descs: list[Any] = [] + + # In progress transfers. + # [req_id -> list[handle]] + self._recving_transfers: defaultdict[str, list[Any]] = defaultdict( + list[Any]) + + # Complete transfer tracker. Used by the rank 0 to track finished + # transactions on ranks 1 to N-1. + # [req_id -> count] + self._done_recving_count: defaultdict[str, + int] = defaultdict(lambda: 0) + self._done_sending_count: defaultdict[str, + int] = defaultdict(lambda: 0) + + # Background thread for establishing new connections. + self._nixl_handshake_listener_t: Optional[threading.Thread] = None + + @staticmethod + def _nixl_handshake_listener(metadata: NixlAgentMetadata, + ready_event: threading.Event, rank: int): + """Background thread for getting new NIXL handshakes.""" + # NOTE(rob): this is a simple implementation. We will move + # to a better approach like an ETCD server in the future. + + # NOTE(rob): to support heterogeneous TP, we will have to + # move this into the scheduler rather than worker, since + # each rank needs the metadata of all other ranks (whereas + # in this setup, each rank only gets one other rank's meta. + + encoder = msgspec.msgpack.Encoder() + encoded_data = encoder.encode(metadata) + size_in_bytes = len(encoded_data) + logger.debug("Size of encoded NixlAgentMetadata: %s bytes", + str(size_in_bytes)) + + # Listen for new requests for metadata. + host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST + # NOTE(rob): we need each rank to have a unique port. This + # hack to keeps us moving. We will switch when moving to etcd + # or where we have a single ZMQ socket in the scheduler. + port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + rank + path = f"tcp://{host}:{port}" + logger.debug("Starting listening on path: %s", path) + with zmq_ctx(zmq.ROUTER, path) as sock: + ready_event.set() + while True: + identity, _, msg = sock.recv_multipart() + if msg != GET_META_MSG: + logger.warning( + "Connection listener got unexpected message %s", msg) + sock.send_multipart((identity, b"", encoded_data)) + + def _nixl_handshake(self, host: str, port: int): + """Do a NIXL handshake with a remote instance.""" + + start_time = time.perf_counter() + # NOTE(rob): we need each rank to have a unique port. This is + # a hack to keep us moving. We will switch when moving to etcd + # or where we have a single ZMQ socket in the scheduler. + path = f"tcp://{host}:{port + self.rank}" + logger.debug("Querying metadata on path: %s", path) + with zmq_ctx(zmq.REQ, path) as sock: + # Send query for the request. + sock.send(GET_META_MSG) + metadata_bytes = sock.recv() + decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) + metadata = decoder.decode(metadata_bytes) + got_metadata_time = time.perf_counter() + + # Register Remote agent. 
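+            # add_remote_agent() both registers the peer agent and
+            # pre-builds the transfer descriptors for its KV regions,
+            # so later _read_blocks() calls can issue a READ immediately.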
+ self.add_remote_agent(metadata) + setup_agent_time = time.perf_counter() + + logger.debug("NIXL handshake: get metadata took: %s", + got_metadata_time - start_time) + logger.debug("NIXL handshake: add agent took: %s", + setup_agent_time - got_metadata_time) + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """Register the KV Cache data in nixl.""" + + _, first_kv_cache = next(iter(kv_caches.items())) + kv_elem_size = first_kv_cache.element_size() + + # TODO(tms): Find a more robust way to detect and handle MLA + use_mla = len(first_kv_cache.shape) == 3 + if use_mla: + # MLA case. + self.num_blocks = first_kv_cache.shape[0] + block_rank = 2 # [block_size, latent_dim] + block_shape = first_kv_cache.shape[-block_rank:] + else: + # [2 (k and v), num_blocks, ...] + self.num_blocks = first_kv_cache.shape[1] + block_rank = 3 # [block_size, kv_heads, head_dim] + block_shape = first_kv_cache.shape[-block_rank:] + + # TODO(tms): self.block_len needs to be per-layer for sliding window, + # hybrid attn, etc + self.block_len = kv_elem_size * math.prod(block_shape) + + logger.debug("Registering KV_Caches. use_mla: %s, shape %s", use_mla, + first_kv_cache.shape) + logger.debug("num_blocks: %s, block_shape: %s", self.num_blocks, + block_shape) + logger.debug("Per layer kv cache size: %s", first_kv_cache.shape) + self.dst_num_blocks[self.engine_id] = self.num_blocks + self.kv_caches = kv_caches + kv_caches_base_addr = [] + caches_data = [] + + # Note(tms): I modified this from the original region setup code. + # K and V are now in different regions. Advantage is that we can + # elegantly support MLA and any cases where the K and V tensors + # are non-contiguous (it's not locally guaranteed that they will be) + # Disadvantage is that the encoded NixlAgentMetadata is now larger + # (roughly 8KB vs 5KB). + for cache_or_caches in kv_caches.values(): + # Normalize to always be a list of caches + cache_list = [cache_or_caches] if use_mla else cache_or_caches + for cache in cache_list: + base_addr = cache.data_ptr() + region_len = self.num_blocks * self.block_len + caches_data.append((base_addr, region_len, self.rank, "")) + kv_caches_base_addr.append(base_addr) + self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr + self.num_regions = len(caches_data) + + descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM") + logger.debug("Registering descs: %s", caches_data) + self.nixl_wrapper.register_memory(descs) + logger.debug("Done registering descs") + + self._registered_descs.append(descs) + + # After KV Caches registered, listen for new connections. + metadata = NixlAgentMetadata( + engine_id=self.engine_id, + agent_metadata=self.nixl_wrapper.get_agent_metadata(), + kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], + num_blocks=self.num_blocks, + ) + ready_event = threading.Event() + self._nixl_handshake_listener_t = threading.Thread( + target=self._nixl_handshake_listener, + args=(metadata, ready_event, self.rank), + daemon=True, + name="nixl_handshake_listener") + self._nixl_handshake_listener_t.start() + ready_event.wait() + + def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata): + engine_id = nixl_agent_meta.engine_id + if engine_id in self._remote_agents: + return + + self._remote_agents[engine_id] = self.nixl_wrapper.add_remote_agent( + nixl_agent_meta.agent_metadata) + self.kv_caches_base_addr[ + engine_id] = nixl_agent_meta.kv_caches_base_addr + + # Create src descs and xfer side handles. 
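+        # One descriptor per (region, block) pair, laid out region-major;
+        # _get_block_descs_ids() relies on this when it computes
+        # desc_id = reg_id * num_blocks + block_id.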
+ blocks_data = [] + for base_addr in self.kv_caches_base_addr[self.engine_id]: + for block_id in range(self.num_blocks): + block_offset = block_id * self.block_len + # (addr, len, device id) + blocks_data.append( + (base_addr + block_offset, self.block_len, self.rank)) + logger.debug("Created %s blocks for src engine %s and rank %s", + len(blocks_data), self.engine_id, self.rank) + + # Register with NIXL. + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( + "NIXL_INIT_AGENT", descs) + + # Create dst descs and xfer side handles. + self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks + blocks_data = [] + for base_addr in self.kv_caches_base_addr[engine_id]: + for block_id in range(nixl_agent_meta.num_blocks): + block_offset = block_id * self.block_len + # (addr, len, device id) + blocks_data.append( + (base_addr + block_offset, self.block_len, self.rank)) + logger.debug("Created %s blocks for dst engine %s and rank %s", + len(blocks_data), engine_id, self.rank) + + # Register with NIXL. + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + self.dst_xfer_side_handles[ + engine_id] = self.nixl_wrapper.prep_xfer_dlist( + self._remote_agents[engine_id], descs) + + def get_finished(self) -> tuple[set[str], set[str]]: + """ + Get requests that are done sending or recving. + + In TP>1 setup, each rank exchanges KVs with its counterpart + ranks independently. get_finished() runs in a worker creates + the done_sending and done_recving sets that are sent to the + scheduler via ModelRunnerOutput by Rank 0. To ensure trnxs + are done before adding to finished, Ranks 1 to N-1 communicate + to Rank 0 once their transaction is done + Rank 0 returns + finished sets to Scheduler only once all ranks are done. + """ + done_sending = self._get_new_notifs() + done_recving = self._pop_done_transfers(self._recving_transfers) + if len(done_sending) > 0 or len(done_recving) > 0: + logger.debug( + "Rank %s, get_finished: %s requests done sending " + "and %s requests done recving", self.rank, len(done_sending), + len(done_recving)) + + if self.world_size == 1: + return done_sending, done_recving + + # Rank 0: get finished from all other ranks. + if self.rank == 0: + for req_id in done_sending: + self._done_sending_count[req_id] += 1 + for req_id in done_recving: + self._done_recving_count[req_id] += 1 + + # Keep track of how many other ranks have finished. + other_ranks_finished_ids: list[str] = [] + for i in range(1, self.world_size): + other_ranks_finished_ids.extend( + self.tp_group.recv_object(src=i)) + for req_id in other_ranks_finished_ids: + if (req_id in self._done_recving_count + or req_id in self._recving_transfers): + self._done_recving_count[req_id] += 1 + else: + self._done_sending_count[req_id] += 1 + + # Return ids that finished on all ranks to the scheduler. + all_done_recving: set[str] = set() + for req_id in list(self._done_recving_count.keys()): + if self._done_recving_count[req_id] == self.world_size: + del self._done_recving_count[req_id] + all_done_recving.add(req_id) + + all_done_sending: set[str] = set() + for req_id in list(self._done_sending_count.keys()): + if self._done_sending_count[req_id] == self.world_size: + del self._done_sending_count[req_id] + all_done_sending.add(req_id) + + return all_done_sending, all_done_recving + + # Ranks 1 to N-1: send finished ids to Rank 0. 
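+        # (Rank 0 only reports a request once it has seen world_size
+        # completions for it: its own plus one from every other rank.)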
+ else: + finished_req_ids = list(done_recving.union(done_sending)) + self.tp_group.send_object(finished_req_ids, dst=0) + + # Unused as only Rank 0 results are sent to scheduler. + return done_sending, done_recving + + def _get_new_notifs(self) -> set[str]: + """Get req_ids which got a remote xfer message.""" + + notified_req_ids: set[str] = set() + for req_ids in self.nixl_wrapper.get_new_notifs().values(): + for req_id in req_ids: + assert req_id not in notified_req_ids + notified_req_ids.add(req_id.decode("utf-8")) + return notified_req_ids + + def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]: + """ + Pop completed xfers by checking for DONE state. + Args: + transfers: dict of req_id -> list[running_xfer] + Returns: + set of req_ids that have all done xfers + """ + done_req_ids: set[str] = set() + for req_id, handles in list(transfers.items()): + running_reqs = [] + for handle in handles: + xfer_state = self.nixl_wrapper.check_xfer_state(handle) + if xfer_state == "DONE": + # TODO ptarasiewicz: why abort is throwing errors? + # self.nixl_wrapper.release_xfer_handle(handle) + continue + if xfer_state == "PROC": + running_reqs.append(handle) + else: + raise RuntimeError("Transfer failed with state %s", + xfer_state) + if len(running_reqs) == 0: + done_req_ids.add(req_id) + del transfers[req_id] + else: + transfers[req_id] = running_reqs + return done_req_ids + + def start_load_kv(self, metadata: NixlConnectorMetadata): + """ + Start loading by triggering non-blocking nixl_xfer. + We check for these trnxs to complete in each step(). + """ + for req_id, meta in metadata.requests.items(): + logger.debug( + "start_load_kv for request %s from remote engine %s. " + "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id, + meta.remote_engine_id, len(meta.local_block_ids), + len(meta.remote_block_ids)) + self._read_blocks( + request_id=req_id, + dst_engine_id=meta.remote_engine_id, + local_block_ids=meta.local_block_ids, + remote_block_ids=meta.remote_block_ids, + remote_host=meta.remote_host, + remote_port=meta.remote_port, + ) + + def _read_blocks( + self, + local_block_ids: list[int], + remote_block_ids: list[int], + remote_host: str, + remote_port: int, + dst_engine_id: str, + request_id: str, + ): + # NOTE(rob): this takes ~2s. We need to get this off the hotpath. + if dst_engine_id not in self._remote_agents: + self._nixl_handshake(remote_host, remote_port) + + # NOTE(rob): having the staging blocks be on the READER side is + # not going to work well (since we will have to call rearrange tensors). + # after we detect the txn is complete (which means we cannot make the + # read trxn async easily). If we want to make "READ" happen cleanly, + # then we will need to have the staging blocks on the remote side. + + # NOTE(rob): according to nvidia the staging blocks are used to + # saturate IB with heterogeneous TP sizes. We should remove the staging + # blocks until we are ready. + + # Full prefix cache hit: do not need to read remote blocks, + # just notify P worker that we have the blocks we need. + num_local_blocks = len(local_block_ids) + if num_local_blocks == 0: + self.nixl_wrapper.send_notif(dst_engine_id, + notif_msg=request_id.encode("utf-8")) + return + + # Partial prefix cache hit: just read uncomputed blocks. + num_remote_blocks = len(remote_block_ids) + assert num_local_blocks <= num_remote_blocks + if num_local_blocks < num_remote_blocks: + remote_block_ids = remote_block_ids[-num_local_blocks:] + + # Get side handles. 
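+        # The transfer below is a NIXL "READ": this worker pulls the
+        # remote blocks directly into its own KV blocks. notif_msg carries
+        # the request_id so the remote worker's _get_new_notifs() can
+        # report the request as done sending and let its scheduler free
+        # the blocks.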
+ local_xfer_side_handle = self.src_xfer_side_handle + remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id] + + # Get descs ids. + remote_block_descs_ids = self._get_block_descs_ids( + dst_engine_id, remote_block_ids) + local_block_descs_ids = self._get_block_descs_ids( + self.engine_id, local_block_ids) + assert len(local_block_descs_ids) == len(remote_block_descs_ids) + + # Prepare transfer with Nixl. + handle = self.nixl_wrapper.make_prepped_xfer( + "READ", + local_xfer_side_handle, + local_block_descs_ids, + remote_xfer_side_handle, + remote_block_descs_ids, + notif_msg=request_id.encode("utf-8"), + ) + + # Begin async xfer. + self.nixl_wrapper.transfer(handle) + + # Use handle to check completion in future step(). + self._recving_transfers[request_id].append(handle) + + def _get_block_descs_ids(self, engine_id: str, + block_ids: list[int]) -> list[int]: + """Get the descs ids for a set of block ids.""" + + # range(1) for MLA, range(2) otherwise. + region_ids = range(self.num_regions) + num_blocks = self.dst_num_blocks[engine_id] + + # Compute the desc ids for each block. + descs_ids: list[int] = [] + for reg_id in region_ids: + for block_id in block_ids: + descs_ids.append(reg_id * num_blocks + block_id) + return descs_ids + + +@contextlib.contextmanager +def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: + """Context manager for a ZMQ socket""" + + ctx: Optional[zmq.Context] = None + try: + ctx = zmq.Context() # type: ignore[attr-defined] + + if socket_type == zmq.ROUTER: + socket = ctx.socket(zmq.ROUTER) + socket.bind(addr) + elif socket_type == zmq.REQ: + socket = ctx.socket(zmq.REQ) + socket.connect(addr) + else: + raise ValueError(f"Unexpected socket type: {socket_type}") + + yield socket + finally: + if ctx is not None: + ctx.destroy(linger=0) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index f91ffbc720e7..0fedb6fd5ed9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -17,6 +17,7 @@ if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request logger = init_logger(__name__) @@ -132,8 +133,7 @@ def inject_kv_into_layer( dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) # Get the metadata - metadata: KVConnectorMetadata = \ - self._get_connector_metadata() + metadata: KVConnectorMetadata = self._get_connector_metadata() assert isinstance(metadata, SharedStorageConnectorMetadata) if metadata is None: @@ -225,7 +225,7 @@ def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int, - ) -> int: + ) -> tuple[int, bool]: """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. @@ -239,7 +239,6 @@ def get_num_new_matched_tokens( the number of tokens that can be loaded from the external KV cache beyond what is already computed. """ - # NOTE: in this debug implementation, we assume that the prompt is # cached_prompt + newly_generated_single_token # Therefore, we use prompt_token_ids[:-1] to determine the folder name @@ -248,7 +247,7 @@ def get_num_new_matched_tokens( # with the block granularity. 
And it expects the returned blocks and # num_computed_tokens to also be aligned with the block granularity. if not self._found_match_for_request(request): - return 0 + return 0, False logger.info("External Cache Hit!") @@ -257,9 +256,10 @@ def get_num_new_matched_tokens( num_tokens_to_check = align_to_block_size( len(request.prompt_token_ids) - 1, self._block_size) - return num_tokens_to_check - num_computed_tokens + return num_tokens_to_check - num_computed_tokens, False def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", num_external_tokens: int): """ Update KVConnector state after block allocation. diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 19c426b19fe2..8ac6534875dd 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -403,6 +403,9 @@ class ChatCompletionRequest(OpenAIBaseModel): "access by 3rd parties, and long enough to be " "unpredictable (e.g., 43 characters base64-encoded, corresponding " "to 256 bit). Not supported by vLLM engine V0.")) + kv_transfer_params: Optional[dict[str, Any]] = Field( + default=None, + description="KVTransfer parameters used for disaggregated serving.") # doc: end-chat-completion-extra-params @@ -540,7 +543,9 @@ def to_sampling_params( output_kind=RequestOutputKind.DELTA if self.stream \ else RequestOutputKind.FINAL_ONLY, guided_decoding=guided_decoding, - logit_bias=self.logit_bias) + logit_bias=self.logit_bias, + extra_args=({"kv_transfer_params": self.kv_transfer_params} + if self.kv_transfer_params else None)) def _get_guided_json_from_tool( self) -> Optional[Union[str, dict, BaseModel]]: @@ -848,6 +853,10 @@ class CompletionRequest(OpenAIBaseModel): " as strings of the form 'token_id:{token_id}' so that tokens " "that are not JSON-encodable can be identified.")) + kv_transfer_params: Optional[dict[str, Any]] = Field( + default=None, + description="KVTransfer parameters used for disaggregated serving.") + # doc: end-completion-extra-params # Default sampling parameters for completion requests @@ -973,7 +982,9 @@ def to_sampling_params( else RequestOutputKind.FINAL_ONLY, guided_decoding=guided_decoding, logit_bias=self.logit_bias, - allowed_token_ids=self.allowed_token_ids) + allowed_token_ids=self.allowed_token_ids, + extra_args=({"kv_transfer_params": self.kv_transfer_params} + if self.kv_transfer_params else None)) @model_validator(mode="before") @classmethod @@ -1223,6 +1234,8 @@ class CompletionResponse(OpenAIBaseModel): model: str choices: list[CompletionResponseChoice] usage: UsageInfo + kv_transfer_params: Optional[dict[str, Any]] = Field( + default=None, description="KVTransfer parameters.") class CompletionResponseStreamChoice(OpenAIBaseModel): @@ -1412,6 +1425,8 @@ class ChatCompletionResponse(OpenAIBaseModel): choices: list[ChatCompletionResponseChoice] usage: UsageInfo prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None + kv_transfer_params: Optional[dict[str, Any]] = Field( + default=None, description="KVTransfer parameters.") class DeltaMessage(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 30f8aade086d..a9ba0e4d68ce 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1086,6 +1086,7 @@ async def chat_completion_full_generator( choices=choices, usage=usage, prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs), + kv_transfer_params=final_res.kv_transfer_params, ) return 
response diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 1067f35ce240..0b3bdf7d4821 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -482,7 +482,7 @@ def request_output_to_completion_response( model=model_name, choices=choices, usage=usage, - ) + kv_transfer_params=final_res_batch[0].kv_transfer_params) def _create_completion_logprobs( self, diff --git a/vllm/envs.py b/vllm/envs.py index d7f332cb0a73..b3faad03d345 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -112,6 +112,8 @@ VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False + VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost" + VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557 def get_default_cache_root(): @@ -747,6 +749,14 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: # insecure method and it is needed for some reason. "VLLM_ALLOW_INSECURE_SERIALIZATION": lambda: bool(int(os.getenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0"))), + + # IP address used for NIXL handshake between remote agents. + "VLLM_NIXL_SIDE_CHANNEL_HOST": + lambda: os.getenv("VLLM_NIXL_SIDE_CHANNEL_HOST", "localhost"), + + # Port used for NIXL handshake between remote agents. + "VLLM_NIXL_SIDE_CHANNEL_PORT": + lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")), } # end-env-vars-definition diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 9ddc3d1f2c51..eb1e1f5694bb 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -11,10 +11,6 @@ import vllm.envs as envs from vllm.config import VllmConfig -from vllm.distributed.kv_transfer import (get_kv_transfer_group, - has_kv_transfer_group, - is_v1_kv_transfer_group) -from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.logger import init_logger if TYPE_CHECKING: @@ -106,16 +102,6 @@ def set_forward_context(attn_metadata: Any, attn_metadata=attn_metadata, dp_metadata=dp_metadata) - # KVConnector: trigger (possibly async) load before forward. - # Each attn layer will block until the reading is complete. - trigger_kv_transfer = (attn_metadata is not None - and has_kv_transfer_group() - and is_v1_kv_transfer_group()) - if trigger_kv_transfer: - kv_connector = get_kv_transfer_group() - assert isinstance(kv_connector, KVConnectorBase_V1) - kv_connector.start_load_kv(_forward_context) - try: yield finally: @@ -152,11 +138,4 @@ def set_forward_context(attn_metadata: Any, "(batchsize, count, median_time(ms)): %s"), forward_stats) - # KVConnector: each attn layer triggers (possibly async) save. - # Ensure all those operations complete before forward() is done. - if trigger_kv_transfer: - kv_connector = get_kv_transfer_group() - assert isinstance(kv_connector, KVConnectorBase_V1) - kv_connector.wait_for_save() - _forward_context = prev_context diff --git a/vllm/outputs.py b/vllm/outputs.py index 65a6ed01451d..6cd60575b00d 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -4,7 +4,7 @@ from collections.abc import MutableSequence from collections.abc import Sequence as GenericSequence from dataclasses import dataclass -from typing import Generic, Optional, Union +from typing import Any, Generic, Optional, Union import torch from typing_extensions import TypeVar, deprecated @@ -103,6 +103,7 @@ class RequestOutput: encoder_prompt_token_ids: The token IDs of the encoder prompt. None if decoder-only. 
num_cached_tokens: The number of tokens with prefix cache hit. + kv_transfer_params: The params for remote K/V transfer. """ def __init__( @@ -120,6 +121,7 @@ def __init__( num_cached_tokens: Optional[int] = None, *, multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None, + kv_transfer_params: Optional[dict[str, Any]] = None, ) -> None: self.request_id = request_id self.prompt = prompt @@ -133,11 +135,13 @@ def __init__( self.encoder_prompt = encoder_prompt self.encoder_prompt_token_ids = encoder_prompt_token_ids self.num_cached_tokens = num_cached_tokens + self.kv_transfer_params = kv_transfer_params def add(self, next_output: "RequestOutput", aggregate: bool) -> None: """Merge subsequent RequestOutput into this one""" self.finished |= next_output.finished + self.kv_transfer_params = next_output.kv_transfer_params for next_completion in next_output.outputs: for i, completion in enumerate(self.outputs): diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index ad8468a89dc5..27368374ea8d 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -36,6 +36,12 @@ def get_block_ids(self) -> list[int]: """Converts the KVCacheBlocks instance to a list of block IDs.""" return [block.block_id for block in self.blocks] + def get_unhashed_block_ids(self) -> list[int]: + """Get block_ids of unhashed blocks from KVCacheBlocks instance.""" + return [ + block.block_id for block in self.blocks if block.block_hash is None + ] + class KVCacheManager: @@ -116,6 +122,12 @@ def get_computed_blocks(self, - The number of computed tokens. """ + # Request already has blocks from async load via KVConnector. + num_existing_blocks = len( + self.single_type_manager.req_to_blocks[request.request_id]) + if num_existing_blocks > 0: + return KVCacheBlocks.create_empty(), request.num_computed_tokens + # Prefix caching is disabled or # When the request requires prompt logprobs, we skip prefix caching. if (not self.enable_caching @@ -173,6 +185,7 @@ def allocate_slots( num_new_tokens: int, new_computed_blocks: Optional[KVCacheBlocks] = None, num_lookahead_tokens: int = 0, + delay_cache_blocks: bool = False, ) -> Optional[KVCacheBlocks]: """Add slots for a request with new tokens to append. @@ -186,6 +199,9 @@ def allocate_slots( num_lookahead_tokens: The number of speculative tokens to allocate. This is used by spec decode proposers with kv-cache such as eagle. + delay_cache_blocks: Whether to skip caching the blocks. This is + used by P/D when allocating blocks used in a KV transfer + which will complete in a future step. Blocks layout: ``` @@ -255,7 +271,9 @@ def allocate_slots( new_blocks = self.single_type_manager.allocate_new_blocks( request.request_id, num_tokens_need_slot) - if not self.enable_caching: + # P/D: delay caching blocks if we have to recv from + # remote. Update state for locally cached blocks. + if not self.enable_caching or delay_cache_blocks: return KVCacheBlocks(new_blocks) # Speculated tokens might be rejected in the future, so we does @@ -350,3 +368,16 @@ def take_events(self) -> list[KVCacheEvent]: A list of KV cache events. 
""" return self.block_pool.take_events() + + def get_block_ids(self, request_id: str) -> list[int]: + """Get the block ids of a request.""" + assert request_id in self.single_type_manager.req_to_blocks + return [ + block.block_id + for block in self.single_type_manager.req_to_blocks[request_id] + ] + + def get_num_blocks(self, request_id: str): + """Get the number of blocks.""" + assert request_id in self.single_type_manager.req_to_blocks + return len(self.single_type_manager.req_to_blocks[request_id]) diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index 0b328f510903..c17f80b6ae78 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Optional, Union if TYPE_CHECKING: + from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.engine import EngineCoreOutputs from vllm.v1.metrics.stats import SchedulerStats @@ -137,3 +138,6 @@ def make_stats(self) -> Optional["SchedulerStats"]: def shutdown(self) -> None: """Shutdown the scheduler.""" raise NotImplementedError + + def get_kv_connector(self) -> Optional["KVConnectorBase_V1"]: + return None diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 258e0d570e3e..7773853b096a 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -5,13 +5,15 @@ import time from collections import defaultdict, deque from collections.abc import Iterable -from typing import Optional, Union +from typing import Any, Optional, Union from vllm.config import VllmConfig from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch from vllm.distributed.kv_transfer.kv_connector.factory import ( KVConnectorFactory) -from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole +from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, + KVConnectorRole, + KVTransferParams) from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, @@ -96,6 +98,9 @@ def __init__( # This is flushed at the end of each scheduling step. self.finished_req_ids: set[str] = set() + # P/D: requests in process of recving KV transfers + self.finished_recving_kv_req_ids: set[str] = set() + # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # them at each scheduling step. # Request id -> deque of CachedRequestData @@ -307,6 +312,16 @@ def schedule(self) -> SchedulerOutput: request = self.waiting[0] + # P/D: skip request if still waiting for remote kvs. + if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + is_ready = self._update_waiting_for_remote_kv(request) + if is_ready: + request.status = RequestStatus.WAITING + else: + self.waiting.popleft() + skipped_waiting_requests.appendleft(request) + continue + # Skip request if the structured output request is still waiting # for FSM compilation. if request.status == RequestStatus.WAITING_FOR_FSM: @@ -330,49 +345,55 @@ def schedule(self) -> SchedulerOutput: continue # Get already-cached tokens. - computed_blocks, num_computed_tokens = \ + new_computed_blocks, num_computed_tokens = \ self.kv_cache_manager.get_computed_blocks( request) # Get externally-cached tokens if using a KVConnector. 
- num_external_tokens = ( - 0 if self.connector is None else + num_external_tokens, load_kv_async = ( + (0, False) if self.connector is None else self.connector.get_num_new_matched_tokens( request, num_computed_tokens)) # Total computed tokens (local + external). num_computed_tokens += num_external_tokens + encoder_inputs_to_schedule = None + new_encoder_budget = encoder_budget + + # P/D: loading remote KV, do not allocate for new work. + if load_kv_async: + num_new_tokens = 0 # Number of tokens to be scheduled. - # We use `request.num_tokens` instead of - # `request.num_prompt_tokens` to consider the resumed requests, - # which have output tokens. - num_new_tokens = request.num_tokens - num_computed_tokens - if (0 < self.scheduler_config.long_prefill_token_threshold < - num_new_tokens): - num_new_tokens = ( - self.scheduler_config.long_prefill_token_threshold) - num_new_tokens = min(num_new_tokens, token_budget) - assert num_new_tokens > 0 - - # Schedule encoder inputs. - if request.has_encoder_inputs: - (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_budget) = self._try_schedule_encoder_inputs( - request, num_computed_tokens, num_new_tokens, - encoder_budget) - if num_new_tokens == 0: - # The request cannot be scheduled. - break else: - encoder_inputs_to_schedule = None - new_encoder_budget = encoder_budget + # We use `request.num_tokens` instead of + # `request.num_prompt_tokens` to consider the resumed + # requests, which have output tokens. + num_new_tokens = request.num_tokens - num_computed_tokens + if (0 < self.scheduler_config.long_prefill_token_threshold + < num_new_tokens): + num_new_tokens = ( + self.scheduler_config.long_prefill_token_threshold) + num_new_tokens = min(num_new_tokens, token_budget) + assert num_new_tokens > 0 + + # Schedule encoder inputs. + if request.has_encoder_inputs: + (encoder_inputs_to_schedule, num_new_tokens, + new_encoder_budget + ) = self._try_schedule_encoder_inputs( + request, num_computed_tokens, num_new_tokens, + encoder_budget) + if num_new_tokens == 0: + # The request cannot be scheduled. + break new_blocks = self.kv_cache_manager.allocate_slots( request, num_new_tokens + num_external_tokens, - computed_blocks, + new_computed_blocks, num_lookahead_tokens=self.num_lookahead_tokens, + delay_cache_blocks=load_kv_async, ) if new_blocks is None: # The request cannot be scheduled. @@ -384,10 +405,18 @@ def schedule(self) -> SchedulerOutput: if self.connector is not None: self.connector.update_state_after_alloc( request, + new_computed_blocks + new_blocks, num_external_tokens, ) self.waiting.popleft() + if load_kv_async: + # If loading async, allocate memory and put request + # into the WAITING_FOR_REMOTE_KV state. + skipped_waiting_requests.appendleft(request) + request.status = RequestStatus.WAITING_FOR_REMOTE_KVS + continue + if request.use_structured_output: structured_output_request_ids[ request.request_id] = req_index @@ -407,7 +436,7 @@ def schedule(self) -> SchedulerOutput: if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) req_to_new_block_ids[request.request_id] = ( - computed_blocks + new_blocks).get_block_ids() + self.kv_cache_manager.get_block_ids(request.request_id)) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING @@ -698,6 +727,7 @@ def update_from_output( stopped = False new_logprobs = None new_token_ids = generated_token_ids + kv_transfer_params = None # Append generated tokens and check for stop. 
Note that if # a request is still being prefilled, we expect the model runner @@ -709,7 +739,7 @@ def update_from_output( # This must be called before we make the EngineCoreOutput. stopped = check_stop(request, self.max_model_len) if stopped: - self._free_request(request) + kv_transfer_params = self._free_request(request) del new_token_ids[num_new:] # Trim new tokens if needed. break @@ -739,7 +769,8 @@ def update_from_output( # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) - if new_token_ids: + if new_token_ids or kv_transfer_params: + # Add EngineCoreOutput for this Request. outputs.append( EngineCoreOutput( @@ -749,7 +780,10 @@ def update_from_output( new_logprobs=new_logprobs, new_prompt_logprobs_tensors=prompt_logprobs_tensors, stop_reason=request.stop_reason, - events=request.take_events())) + events=request.take_events(), + kv_transfer_params=kv_transfer_params, + )) + else: # Invariant: EngineCore returns no partial prefill outputs. assert not prompt_logprobs_tensors @@ -757,6 +791,9 @@ def update_from_output( if not stopped: new_running.append(request) + # P/D: update state for finished KV Transfers. + self._update_from_kv_xfer_finished(model_runner_output) + # Return the cached request data to the queue so they can be reused. for req_data in scheduler_output.scheduled_cached_reqs: # NOTE(rob): since we free stopped reqs above, adding stopped reqs @@ -811,15 +848,27 @@ def finish_requests( request.status = finished_status self._free_request(request) - def _free_request(self, request: Request) -> None: + def _free_request(self, request: Request) -> Optional[dict[str, Any]]: + assert request.is_finished() - self.kv_cache_manager.free(request) - self.kv_cache_manager.free_block_hashes(request) + + delay_free_blocks, kv_xfer_params = self._connector_finished(request) self.encoder_cache_manager.free(request) self._cached_reqs_data.pop(request.request_id, None) - del self.requests[request.request_id] self.finished_req_ids.add(request.request_id) + if not delay_free_blocks: + self._free_blocks(request) + + return kv_xfer_params + + def _free_blocks(self, request: Request): + assert request.is_finished() + assert request.request_id not in self._cached_reqs_data + self.kv_cache_manager.free(request) + self.kv_cache_manager.free_block_hashes(request) + del self.requests[request.request_id] + def get_num_unfinished_requests(self) -> int: return len(self.waiting) + len(self.running) @@ -863,3 +912,70 @@ def make_spec_decoding_stats( def shutdown(self) -> None: if self.kv_event_publisher: self.kv_event_publisher.shutdown() + + ######################################################################## + # P/D Related Methods + ######################################################################## + + def get_kv_connector(self) -> Optional[KVConnectorBase_V1]: + return self.connector + + def _connector_finished( + self, request: Request) -> tuple[bool, Optional[KVTransferParams]]: + """Invoke the KV connector request_finished() method if applicable.""" + if self.connector is None: + return False, None + block_ids = self.kv_cache_manager.get_block_ids(request.request_id) + return self.connector.request_finished(request, block_ids) + + def _update_waiting_for_remote_kv(self, request: Request) -> bool: + """ + P/D: check if the request_id is finished_recving. + + The finished_recving_kv_req_ids list is populated + on the previous steps()'s update_from_output based + on the worker side connector. 
+ + When the kv transfer is ready, we cache the blocks + and the request state will be moved back to WAITING from + WAITING_FOR_REMOTE_KV. + """ + if request.request_id not in self.finished_recving_kv_req_ids: + return False + + # Now that the blocks are ready, actually cache them. + block_ids = self.kv_cache_manager.get_block_ids(request.request_id) + num_computed_tokens = len(block_ids) * self.block_size + if num_computed_tokens == request.num_tokens: + num_computed_tokens -= 1 + self.kv_cache_manager.single_type_manager.cache_blocks( + request, + self.kv_cache_manager.req_to_block_hashes[request.request_id], + num_computed_tokens, + ) + + # Update the request state for scheduling. + request.num_computed_tokens = num_computed_tokens + + # Return that we are ready. + self.finished_recving_kv_req_ids.remove(request.request_id) + return True + + def _update_from_kv_xfer_finished(self, + model_runner_output: ModelRunnerOutput): + """ + P/D: update the scheduler state based on the output. + + The Worker side connectors add finished_recving and + finished_sending reqs to the output. + * if finished_sending: free the blocks + # if finished_recving: add to state so we can + scheduler the request during the next step. + """ + # P/D: update recv and send status from last step. + for req_id in (model_runner_output.finished_recving or ()): + logger.debug("Finished recving KV transfer for request %s", req_id) + self.finished_recving_kv_req_ids.add(req_id) + for req_id in (model_runner_output.finished_sending or ()): + logger.debug("Finished sending KV transfer for request %s", req_id) + self._free_blocks(self.requests[req_id]) diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index e33d1a1e5dcd..122a5a72cc36 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -105,6 +105,7 @@ class EngineCoreOutput( finish_reason: Optional[FinishReason] = None stop_reason: Union[int, str, None] = None events: Optional[list[EngineCoreEvent]] = None + kv_transfer_params: Optional[dict[str, Any]] = None @property def finished(self) -> bool: diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index d9dd4957cff2..c1aa0ce27d3f 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -182,6 +182,15 @@ def add_request(self, request: EngineCoreRequest): # Start grammar compilation asynchronously self.structured_output_manager.grammar_init(req) + if req.raw_kv_transfer_params is not None: + if (kv_connector := self.scheduler.get_kv_connector()): + # Parse raw KV transfer params via connector. + kv_connector.set_kv_transfer_params(req) + else: + logger.warning( + "Got KVTransferParams, but no KVConnector found. 
" + "Disabling KVTransfer for this request.") + self.scheduler.add_request(req) def abort_requests(self, request_ids: list[str]): diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 5f5ffe6e09db..a7a9b0e4a161 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -3,7 +3,7 @@ import asyncio from collections.abc import Iterable from dataclasses import dataclass -from typing import Optional, Union +from typing import Any, Optional, Union from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import RequestOutputKind @@ -146,6 +146,7 @@ def make_request_output( new_token_ids: list[int], finish_reason: Optional[FinishReason], stop_reason: Union[int, str, None], + kv_transfer_params: Optional[dict[str, Any]] = None, ) -> Optional[RequestOutput]: finished = finish_reason is not None @@ -167,13 +168,15 @@ def make_request_output( if not outputs: return None - return self._new_request_output(request_id, outputs, finished) + return self._new_request_output(request_id, outputs, finished, + kv_transfer_params) def _new_request_output( self, request_id: str, outputs: list[CompletionOutput], finished: bool, + kv_transfer_params: Optional[dict[str, Any]] = None, ) -> RequestOutput: if self.output_kind == RequestOutputKind.DELTA: @@ -189,6 +192,7 @@ def _new_request_output( prompt_logprobs=prompt_logprobs, outputs=outputs, finished=finished, + kv_transfer_params=kv_transfer_params, ) def _new_completion_output( @@ -335,6 +339,7 @@ def process_outputs( new_token_ids = engine_core_output.new_token_ids finish_reason = engine_core_output.finish_reason stop_reason = engine_core_output.stop_reason + kv_transfer_params = engine_core_output.kv_transfer_params req_state.is_prefilling = False @@ -350,7 +355,8 @@ def process_outputs( # 4) Create and handle RequestOutput objects. if request_output := req_state.make_request_output( - new_token_ids, finish_reason, stop_reason): + new_token_ids, finish_reason, stop_reason, + kv_transfer_params): if req_state.queue is not None: # AsyncLLM: put into queue for handling by generate(). 
req_state.queue.put(request_output) diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 2732b933c28a..e8ce0df5ed8d 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -100,12 +100,16 @@ class ModelRunnerOutput: # [prompt_len] prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] - -EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput( - req_ids=[], - req_id_to_index={}, - sampled_token_ids=[], - spec_token_ids=None, - logprobs=None, - prompt_logprobs_dict={}, -) + # [req_ids] + finished_sending: Optional[set[str]] = None + finished_recving: Optional[set[str]] = None + + +EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[], + req_id_to_index={}, + sampled_token_ids=[], + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + finished_sending=None, + finished_recving=None) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index fde366d61c7d..fc6b738546f4 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 import enum -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union +from vllm.distributed.kv_transfer.kv_connector.v1 import KVTransferParams from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams from vllm.utils import is_list_of @@ -61,6 +62,15 @@ def __init__( self.num_encoder_inputs = len(self.mm_inputs) self.has_encoder_inputs = self.num_encoder_inputs > 0 + # P/D: KV transfer parameters (raw and parsed). + raw_params = (None if sampling_params.extra_args is None + else sampling_params.extra_args.get( + "kv_transfer_params", None)) + self.raw_kv_transfer_params: Optional[dict[str, Any]] = raw_params + # Each connector parses the raw dictionary and sets this + # attr the first time that the request is processed. + self.kv_transfer_params: Optional[KVTransferParams] = None + # Sanity check assert len(self.mm_inputs) == len(self.mm_positions) if self.mm_hashes: @@ -150,6 +160,7 @@ class RequestStatus(enum.IntEnum): """Status of a request.""" WAITING = enum.auto() WAITING_FOR_FSM = enum.auto() + WAITING_FOR_REMOTE_KVS = enum.auto() RUNNING = enum.auto() PREEMPTED = enum.auto() # Note: anything after PREEMPTED will be considered diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index fdb1339cddca..bd833735b695 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import copy import gc import time import weakref @@ -17,8 +18,9 @@ get_layers_from_vllm_config) from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.distributed.parallel_state import get_pp_group, graph_capture -from vllm.forward_context import set_forward_context +from vllm.forward_context import get_forward_context, set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import get_model @@ -1065,15 +1067,14 @@ def execute_model( scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[ModelRunnerOutput, IntermediateTensors]: - # Update KVConnector with the KVConnector metadata forward(). 
- if has_kv_transfer_group(): - get_kv_transfer_group().bind_connector_metadata( - scheduler_output.kv_connector_metadata) self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: - # Return empty ModelRunnerOutput if there's no work to do. - return EMPTY_MODEL_RUNNER_OUTPUT + if not has_kv_transfer_group(): + # Return empty ModelRunnerOutput if there's no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT + + return self.kv_connector_no_forward(scheduler_output) # Prepare the decoder inputs. attn_metadata, logits_indices, spec_decode_metadata = ( @@ -1150,17 +1151,23 @@ def execute_model( with set_forward_context(attn_metadata, self.vllm_config, num_tokens=num_input_tokens): - output = self.model( + self.maybe_setup_kv_connector(scheduler_output) + + model_output = self.model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) + self.maybe_wait_for_kv_save() + finished_sending, finished_recving = ( + self.get_finished_kv_transfers(scheduler_output)) + if self.use_aux_hidden_state_outputs: - hidden_states, aux_hidden_states = output + hidden_states, aux_hidden_states = model_output else: - hidden_states = output + hidden_states = model_output if not get_pp_group().is_last_rank: # For mid-pipeline stages, return the hidden states. @@ -1341,8 +1348,56 @@ def execute_model( spec_token_ids=spec_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, + finished_sending=finished_sending, + finished_recving=finished_recving, ) + def kv_connector_no_forward( + self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput: + # KV send/recv even if no work to do. + with set_forward_context(None, self.vllm_config): + self.maybe_setup_kv_connector(scheduler_output) + finished_sending, finished_recving = ( + self.get_finished_kv_transfers(scheduler_output)) + + if not finished_sending and not finished_recving: + return EMPTY_MODEL_RUNNER_OUTPUT + + output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.finished_sending = finished_sending + output.finished_recving = finished_recving + return output + + @staticmethod + def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"): + # Update KVConnector with the KVConnector metadata forward(). + if has_kv_transfer_group(): + kv_connector = get_kv_transfer_group() + assert isinstance(kv_connector, KVConnectorBase_V1) + assert scheduler_output.kv_connector_metadata is not None + kv_connector.bind_connector_metadata( + scheduler_output.kv_connector_metadata) + + # Background KV cache transfers happen here. + # These transfers are designed to be async and the requests + # involved may be disjoint from the running requests. + # Do this here to save a collective_rpc. 
+ kv_connector.start_load_kv(get_forward_context()) + + @staticmethod + def maybe_wait_for_kv_save() -> None: + if has_kv_transfer_group(): + get_kv_transfer_group().wait_for_save() + + @staticmethod + def get_finished_kv_transfers( + scheduler_output: "SchedulerOutput", + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + if has_kv_transfer_group(): + return get_kv_transfer_group().get_finished( + scheduler_output.finished_req_ids) + return None, None + def generate_draft_token_ids( self, sampled_token_ids: list[list[int]], @@ -1813,6 +1868,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self.vllm_config.compilation_config.static_forward_context, self.kv_caches) + if has_kv_transfer_group(): + get_kv_transfer_group().register_kv_caches(kv_caches) + self.attn_metadata_builder = self.attn_backend.get_builder_cls()( weakref.proxy(self), kv_cache_config.kv_cache_groups[0].kv_cache_spec, From 98ea35601cdb34fdd618f965e7bcc3cb02a677fc Mon Sep 17 00:00:00 2001 From: Jonathan Berkhahn Date: Mon, 12 May 2025 10:39:10 -0700 Subject: [PATCH 16/24] [Lora][Frontend]Add default local directory LoRA resolver plugin. (#16855) Signed-off-by: jberkhahn --- .buildkite/test-pipeline.yaml | 3 +- docs/source/features/lora.md | 7 +- pyproject.toml | 3 + tests/plugins/lora_resolvers/__init__.py | 0 .../test_filesystem_resolver.py | 65 +++++++++++++++++++ vllm/envs.py | 7 ++ vllm/plugins/lora_resolvers/README.md | 15 +++++ vllm/plugins/lora_resolvers/__init__.py | 0 .../lora_resolvers/filesystem_resolver.py | 49 ++++++++++++++ 9 files changed, 146 insertions(+), 3 deletions(-) create mode 100644 tests/plugins/lora_resolvers/__init__.py create mode 100644 tests/plugins/lora_resolvers/test_filesystem_resolver.py create mode 100644 vllm/plugins/lora_resolvers/README.md create mode 100644 vllm/plugins/lora_resolvers/__init__.py create mode 100644 vllm/plugins/lora_resolvers/filesystem_resolver.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 027cb218df5e..9664615be85d 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -628,7 +628,7 @@ steps: - vllm/plugins/ - tests/plugins/ commands: - # begin platform plugin tests, all the code in-between runs on dummy platform + # begin platform plugin and general plugin tests, all the code in-between runs on dummy platform - pip install -e ./plugins/vllm_add_dummy_platform - pytest -v -s plugins_tests/test_platform_plugins.py - pip uninstall vllm_add_dummy_platform -y @@ -639,6 +639,7 @@ steps: - pytest -v -s distributed/test_distributed_oot.py - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process - pytest -v -s models/test_oot_registration.py # it needs a clean process + - pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins - label: Multi-step Tests (4 GPUs) # 36min mirror_hardwares: [amdexperimental] diff --git a/docs/source/features/lora.md b/docs/source/features/lora.md index 85f03ba79087..5a3ce0c01f3f 100644 --- a/docs/source/features/lora.md +++ b/docs/source/features/lora.md @@ -159,9 +159,12 @@ Alternatively, you can use the LoRAResolver plugin to dynamically load LoRA adap You can set up multiple LoRAResolver plugins if you want to load LoRA adapters from different sources. For example, you might have one resolver for local files and another for S3 storage. vLLM will load the first LoRA adapter that it finds. -You can either install existing plugins or implement your own. 
+You can either install existing plugins or implement your own. By default, vLLM comes with a [resolver plugin to load LoRA adapters from a local directory.](https://github.com/vllm-project/vllm/tree/main/vllm/plugins/lora_resolvers) +To enable this resolver, set `VLLM_ALLOW_RUNTIME_LORA_UPDATING` to True, set `VLLM_PLUGINS` to include `lora_filesystem_resolver`, and then set `VLLM_LORA_RESOLVER_CACHE_DIR` to a local directory. When vLLM receives a request using a LoRA adapter `foobar`, +it will first look in the local directory for a directory `foobar`, and attempt to load the contents of that directory as a LoRA adapter. If successful, the request will complete as normal and +that adapter will then be available for normal use on the server. -Steps to implement your own LoRAResolver plugin: +Alternatively, follow these example steps to implement your own plugin: 1. Implement the LoRAResolver interface. Example of a simple S3 LoRAResolver implementation: diff --git a/pyproject.toml b/pyproject.toml index 069e295bfb93..a26917a097c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,9 @@ Slack="http://slack.vllm.ai/" [project.scripts] vllm = "vllm.entrypoints.cli.main:main" +[project.entry-points."vllm.general_plugins"] +lora_filesystem_resolver = "vllm.plugins.lora_resolvers.filesystem_resolver:register_filesystem_resolver" + [tool.setuptools_scm] # no extra settings needed, presence enables setuptools-scm diff --git a/tests/plugins/lora_resolvers/__init__.py b/tests/plugins/lora_resolvers/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/plugins/lora_resolvers/test_filesystem_resolver.py b/tests/plugins/lora_resolvers/test_filesystem_resolver.py new file mode 100644 index 000000000000..cb0f0c3c5fa6 --- /dev/null +++ b/tests/plugins/lora_resolvers/test_filesystem_resolver.py @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: Apache-2.0 +import os +import shutil + +import pytest +from huggingface_hub import snapshot_download + +from vllm.plugins.lora_resolvers.filesystem_resolver import FilesystemResolver + +MODEL_NAME = "mistralai/Mistral-7B-v0.1" +LORA_NAME = "typeof/zephyr-7b-beta-lora" +PA_NAME = "swapnilbp/llama_tweet_ptune" + + +@pytest.fixture(scope='module') +def adapter_cache(request, tmpdir_factory): + # Create dir that mimics the structure of the adapter cache + adapter_cache = tmpdir_factory.mktemp( + request.module.__name__) / "adapter_cache" + return adapter_cache + + +@pytest.fixture(scope="module") +def zephyr_lora_files(): + return snapshot_download(repo_id=LORA_NAME) + + +@pytest.fixture(scope="module") +def pa_files(): + return snapshot_download(repo_id=PA_NAME) + + +@pytest.mark.asyncio +async def test_filesystem_resolver(adapter_cache, zephyr_lora_files): + model_files = adapter_cache / LORA_NAME + shutil.copytree(zephyr_lora_files, model_files) + + fs_resolver = FilesystemResolver(adapter_cache) + assert fs_resolver is not None + + lora_request = await fs_resolver.resolve_lora(MODEL_NAME, LORA_NAME) + assert lora_request is not None + assert lora_request.lora_name == LORA_NAME + assert lora_request.lora_path == os.path.join(adapter_cache, LORA_NAME) + + +@pytest.mark.asyncio +async def test_missing_adapter(adapter_cache): + fs_resolver = FilesystemResolver(adapter_cache) + assert fs_resolver is not None + + missing_lora_request = await fs_resolver.resolve_lora(MODEL_NAME, "foobar") + assert missing_lora_request is None + + +@pytest.mark.asyncio +async def test_nonlora_adapter(adapter_cache, pa_files): + model_files = adapter_cache 
/ PA_NAME + shutil.copytree(pa_files, model_files) + + fs_resolver = FilesystemResolver(adapter_cache) + assert fs_resolver is not None + + pa_request = await fs_resolver.resolve_lora(MODEL_NAME, PA_NAME) + assert pa_request is None diff --git a/vllm/envs.py b/vllm/envs.py index b3faad03d345..0c742bf05623 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -68,6 +68,7 @@ VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False VLLM_RPC_TIMEOUT: int = 10000 # ms VLLM_PLUGINS: Optional[list[str]] = None + VLLM_LORA_RESOLVER_CACHE_DIR: Optional[str] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None VLLM_USE_TRITON_AWQ: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False @@ -503,6 +504,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: lambda: None if "VLLM_PLUGINS" not in os.environ else os.environ[ "VLLM_PLUGINS"].split(","), + # a local directory to look in for unrecognized LoRA adapters. + # only works if plugins are enabled and + # VLLM_ALLOW_RUNTIME_LORA_UPDATING is enabled. + "VLLM_LORA_RESOLVER_CACHE_DIR": + lambda: os.getenv("VLLM_LORA_RESOLVER_CACHE_DIR", None), + # Enables torch profiler if set. Path to the directory where torch profiler # traces are saved. Note that it must be an absolute path. "VLLM_TORCH_PROFILER_DIR": diff --git a/vllm/plugins/lora_resolvers/README.md b/vllm/plugins/lora_resolvers/README.md new file mode 100644 index 000000000000..7e7c55f5c69c --- /dev/null +++ b/vllm/plugins/lora_resolvers/README.md @@ -0,0 +1,15 @@ +# LoRA Resolver Plugins + +This directory contains vLLM general plugins for dynamically discovering and loading LoRA adapters +via the LoRAResolver plugin framework. + +Note that `VLLM_ALLOW_RUNTIME_LORA_UPDATING` must be set to true to allow LoRA resolver plugins +to work, and `VLLM_PLUGINS` must be set to include the desired resolver plugins. + +# lora_filesystem_resolver +This LoRA Resolver is installed with vLLM by default. +To use, set `VLLM_LORA_RESOLVER_CACHE_DIR` to a local directory. When vLLM receives a request +for a LoRA adapter `foobar` it doesn't currently recognize, it will look in that local directory +for a subdirectory `foobar` containing a LoRA adapter. If such an adapter exists, it will +load that adapter, and then service the request as normal. That adapter will then be available +for future requests as normal.
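As a rough illustration of the workflow this README describes, the sketch below drives the new FilesystemResolver directly, much as the accompanying unit test does. This is a minimal sketch, not part of the patch: the cache directory, adapter name, and base model are placeholders, and the environment variables shown are the ones the plugin path reads at startup rather than anything required by the direct call itself.

```python
# Minimal sketch; paths, adapter name, and base model are placeholders.
import asyncio
import os

from vllm.plugins.lora_resolvers.filesystem_resolver import FilesystemResolver

# Assumed layout: <cache_dir>/my-adapter/adapter_config.json, adapter weights, ...
cache_dir = "/path/to/adapter_cache"

# Server-side setup described above; these are consulted when vLLM loads
# general plugins, not by the direct resolver call below.
os.environ["VLLM_ALLOW_RUNTIME_LORA_UPDATING"] = "True"
os.environ["VLLM_PLUGINS"] = "lora_filesystem_resolver"
os.environ["VLLM_LORA_RESOLVER_CACHE_DIR"] = cache_dir

resolver = FilesystemResolver(cache_dir)

# resolve_lora() returns a LoRARequest when <cache_dir>/my-adapter holds a
# LoRA adapter whose base_model_name_or_path matches the base model,
# otherwise None.
lora_request = asyncio.run(
    resolver.resolve_lora("mistralai/Mistral-7B-v0.1", "my-adapter"))
print(lora_request)
```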
diff --git a/vllm/plugins/lora_resolvers/__init__.py b/vllm/plugins/lora_resolvers/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/plugins/lora_resolvers/filesystem_resolver.py b/vllm/plugins/lora_resolvers/filesystem_resolver.py new file mode 100644 index 000000000000..219231f77785 --- /dev/null +++ b/vllm/plugins/lora_resolvers/filesystem_resolver.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: Apache-2.0 +import json +import os +from typing import Optional + +import vllm.envs as envs +from vllm.lora.request import LoRARequest +from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry + + +class FilesystemResolver(LoRAResolver): + + def __init__(self, lora_cache_dir: str): + self.lora_cache_dir = lora_cache_dir + + async def resolve_lora(self, base_model_name: str, + lora_name: str) -> Optional[LoRARequest]: + lora_path = os.path.join(self.lora_cache_dir, lora_name) + if os.path.exists(lora_path): + adapter_config_path = os.path.join(self.lora_cache_dir, lora_name, + "adapter_config.json") + if os.path.exists(adapter_config_path): + with open(adapter_config_path) as file: + adapter_config = json.load(file) + if adapter_config["peft_type"] == "LORA" and adapter_config[ + "base_model_name_or_path"] == base_model_name: + lora_request = LoRARequest(lora_name=lora_name, + lora_int_id=abs( + hash(lora_name)), + lora_path=lora_path) + return lora_request + return None + + +def register_filesystem_resolver(): + """Register the filesystem LoRA Resolver with vLLM""" + + lora_cache_dir = envs.VLLM_LORA_RESOLVER_CACHE_DIR + if lora_cache_dir: + if not os.path.exists(lora_cache_dir) or not os.path.isdir( + lora_cache_dir): + raise ValueError( + "VLLM_LORA_RESOLVER_CACHE_DIR must be set to a valid directory \ + for Filesystem Resolver plugin to function") + fs_resolver = FilesystemResolver(lora_cache_dir) + LoRAResolverRegistry.register_resolver("Filesystem Resolver", + fs_resolver) + + return From 72a3f6b898d8397c406debc49593e15aa7cbb4bc Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 12 May 2025 19:25:33 +0100 Subject: [PATCH 17/24] Construct `KVTransferConfig` properly from Python instead of using JSON blobs without CLI (#17994) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- examples/lmcache/disagg_prefill_lmcache_v0.py | 14 +++++++------ .../lmcache/kv_cache_sharing_lmcache_v1.py | 8 +++---- .../decode_example.py | 21 ++++++++++--------- .../prefill_example.py | 11 +++++----- .../disaggregated_prefill.py | 14 +++++++------ 5 files changed, 37 insertions(+), 31 deletions(-) diff --git a/examples/lmcache/disagg_prefill_lmcache_v0.py b/examples/lmcache/disagg_prefill_lmcache_v0.py index 7da6fb7aaa23..66cc94185230 100644 --- a/examples/lmcache/disagg_prefill_lmcache_v0.py +++ b/examples/lmcache/disagg_prefill_lmcache_v0.py @@ -49,9 +49,10 @@ def run_prefill(prefill_done, prompts): sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) - ktc = KVTransferConfig.from_cli( - '{"kv_connector":"LMCacheConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' - ) + ktc = KVTransferConfig(kv_connector="LMCacheConnector", + kv_role="kv_producer", + kv_rank=0, + kv_parallel_size=2) # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB # memory. Reduce the value if your GPU has less memory. 
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", @@ -78,9 +79,10 @@ def run_decode(prefill_done, prompts, timeout=1): sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) - ktc = KVTransferConfig.from_cli( - '{"kv_connector":"LMCacheConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' - ) + ktc = KVTransferConfig(kv_connector="LMCacheConnector", + kv_role="kv_consumer", + kv_rank=1, + kv_parallel_size=2) # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB # of memory. Reduce the value if your GPU has less memory. llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", diff --git a/examples/lmcache/kv_cache_sharing_lmcache_v1.py b/examples/lmcache/kv_cache_sharing_lmcache_v1.py index af1b4351dd54..7748f8ca6133 100644 --- a/examples/lmcache/kv_cache_sharing_lmcache_v1.py +++ b/examples/lmcache/kv_cache_sharing_lmcache_v1.py @@ -49,8 +49,8 @@ def run_store(store_done, prompts): sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) - ktc = KVTransferConfig.from_cli( - '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}') + ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1", + kv_role="kv_both") # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB # memory. Reduce the value if your GPU has less memory. llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", @@ -76,8 +76,8 @@ def run_retrieve(store_done, prompts, timeout=1): sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) - ktc = KVTransferConfig.from_cli( - '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}') + ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1", + kv_role="kv_both") # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB # of memory. Reduce the value if your GPU has less memory. 
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", diff --git a/examples/offline_inference/disaggregated-prefill-v1/decode_example.py b/examples/offline_inference/disaggregated-prefill-v1/decode_example.py index 66efbc0c9dee..11918f72feec 100644 --- a/examples/offline_inference/disaggregated-prefill-v1/decode_example.py +++ b/examples/offline_inference/disaggregated-prefill-v1/decode_example.py @@ -16,16 +16,17 @@ sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) -llm = LLM( - model="meta-llama/Llama-3.2-1B-Instruct", - enforce_eager=True, - gpu_memory_utilization=0.8, - max_num_batched_tokens=64, - max_num_seqs=16, - kv_transfer_config=KVTransferConfig.from_cli( - '{"kv_connector":"SharedStorageConnector","kv_role":"kv_both",' - '"kv_connector_extra_config": {"shared_storage_path": "local_storage"}}' - )) #, max_model_len=2048, max_num_batched_tokens=2048) +llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", + enforce_eager=True, + gpu_memory_utilization=0.8, + max_num_batched_tokens=64, + max_num_seqs=16, + kv_transfer_config=KVTransferConfig( + kv_connector="SharedStorageConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "shared_storage_path": "local_storage" + })) #, max_model_len=2048, max_num_batched_tokens=2048) # 1ST generation (prefill instance) outputs = llm.generate(prompts, sampling_params) diff --git a/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py b/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py index f7cbf6557d54..798128301e0f 100644 --- a/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py +++ b/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py @@ -17,11 +17,12 @@ llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", enforce_eager=True, gpu_memory_utilization=0.8, - kv_transfer_config=KVTransferConfig.from_cli( - '{"kv_connector":"SharedStorageConnector","kv_role":"kv_both", ' - '"kv_connector_extra_config": ' - '{"shared_storage_path": "local_storage"}}') - ) #, max_model_len=2048, max_num_batched_tokens=2048) + kv_transfer_config=KVTransferConfig( + kv_connector="SharedStorageConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "shared_storage_path": "local_storage" + })) #, max_model_len=2048, max_num_batched_tokens=2048) # 1ST generation (prefill instance) outputs = llm.generate( diff --git a/examples/offline_inference/disaggregated_prefill.py b/examples/offline_inference/disaggregated_prefill.py index d60985146c5c..bb6fdd48f79e 100644 --- a/examples/offline_inference/disaggregated_prefill.py +++ b/examples/offline_inference/disaggregated_prefill.py @@ -32,9 +32,10 @@ def run_prefill(prefill_done): # This instance is the prefill node (kv_producer, rank 0). # The number of parallel instances for KV cache transfer is set to 2, # as required for PyNcclConnector. - ktc = KVTransferConfig.from_cli( - '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' - ) + ktc = KVTransferConfig(kv_connector="PyNcclConnector", + kv_role="kv_producer", + kv_rank=0, + kv_parallel_size=2) # Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB # memory. You may need to adjust the value to fit your GPU. @@ -71,9 +72,10 @@ def run_decode(prefill_done): # This instance is the decode node (kv_consumer, rank 1). # The number of parallel instances for KV cache transfer is set to 2, # as required for PyNcclConnector. 
- ktc = KVTransferConfig.from_cli( - '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' - ) + ktc = KVTransferConfig(kv_connector="PyNcclConnector", + kv_role="kv_consumer", + kv_rank=1, + kv_parallel_size=2) # Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB # memory. You may need to adjust the value to fit your GPU. From b9fd0d7a6984bd1b6090f564660c9d1706490700 Mon Sep 17 00:00:00 2001 From: Carol Zheng Date: Mon, 12 May 2025 12:06:59 -0700 Subject: [PATCH 18/24] [CI/Build] Fix TPU V1 Test mixed use of & and && across tests (#17968) --- .../scripts/hardware_ci/run-tpu-v1-test.sh | 42 +++++++++---------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh index 939daddad92b..2d375d7e9d87 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh @@ -26,27 +26,27 @@ docker run --privileged --net host --shm-size=16G -it \ && tpu-info \ && { \ echo TEST_0: Running test_perf.py; \ - pytest -s -v /workspace/vllm/tests/tpu/test_perf.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_perf.py; \ echo TEST_0_EXIT_CODE: \$?; \ } & \ - && { \ + { \ echo TEST_1: Running test_compilation.py; \ - pytest -s -v /workspace/vllm/tests/tpu/test_compilation.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_compilation.py; \ echo TEST_1_EXIT_CODE: \$?; \ } & \ { \ echo TEST_2: Running test_basic.py; \ - pytest -s -v /workspace/vllm/tests/v1/tpu/test_basic.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_basic.py; \ echo TEST_2_EXIT_CODE: \$?; \ } & \ { \ echo TEST_3: Running test_accuracy.py::test_lm_eval_accuracy_v1_engine; \ - pytest -s -v /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine; \ + python3 -m pytest -s -v /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine; \ echo TEST_3_EXIT_CODE: \$?; \ } & \ { \ echo TEST_4: Running test_quantization_accuracy.py; \ - pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py; \ echo TEST_4_EXIT_CODE: \$?; \ } & \ { \ @@ -56,43 +56,43 @@ docker run --privileged --net host --shm-size=16G -it \ } & \ { \ echo TEST_6: Running test_tpu_model_runner.py; \ - pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py; \ echo TEST_6_EXIT_CODE: \$?; \ } & \ - && { \ + { \ echo TEST_7: Running test_sampler.py; \ - pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py; \ echo TEST_7_EXIT_CODE: \$?; \ } & \ - && { \ + { \ echo TEST_8: Running test_topk_topp_sampler.py; \ - pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py; \ echo TEST_8_EXIT_CODE: \$?; \ } & \ - && { \ + { \ echo TEST_9: Running test_multimodal.py; \ - pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py; \ echo TEST_9_EXIT_CODE: \$?; \ } & \ - && { \ + { \ echo TEST_10: Running test_pallas.py; \ - pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py; \ + python3 -m pytest -s -v 
/workspace/vllm/tests/v1/tpu/test_pallas.py; \ echo TEST_10_EXIT_CODE: \$?; \ } & \ - && { \ + { \ echo TEST_11: Running test_struct_output_generate.py; \ - pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py; \ echo TEST_11_EXIT_CODE: \$?; \ } & \ - && { \ + { \ echo TEST_12: Running test_moe_pallas.py; \ - pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py; \ echo TEST_12_EXIT_CODE: \$?; \ } & \ # Disable the TPU LoRA tests until the feature is activated - # && { \ + # & { \ # echo TEST_13: Running test_moe_pallas.py; \ - # pytest -s -v /workspace/vllm/tests/tpu/lora/; \ + # python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/; \ # echo TEST_13_EXIT_CODE: \$?; \ # } & \ wait \ From 289199feb6616a27d97d15cde470772e1585bb84 Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Tue, 13 May 2025 03:09:16 +0800 Subject: [PATCH 19/24] [Core] Use platform-agnostic device control for DP engine core (#17245) Signed-off-by: Jade Zheng --- vllm/platforms/cuda.py | 26 ++++---------------------- vllm/platforms/interface.py | 19 +++++++++++++++++++ vllm/platforms/rocm.py | 11 +---------- vllm/v1/engine/core.py | 13 ++++++------- 4 files changed, 30 insertions(+), 39 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index f116285870ec..dd3a54f7daf2 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -34,24 +34,6 @@ torch.backends.cuda.enable_cudnn_sdp(False) -def device_id_to_physical_device_id(device_id: int) -> int: - if "CUDA_VISIBLE_DEVICES" in os.environ: - device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",") - if device_ids == [""]: - msg = ( - "CUDA_VISIBLE_DEVICES is set to empty string, which means" - " GPU support is disabled. If you are using ray, please unset" - " the environment variable `CUDA_VISIBLE_DEVICES` inside the" - " worker/actor. 
" - "Check https://github.com/vllm-project/vllm/issues/8402 for" - " more information.") - raise RuntimeError(msg) - physical_device_id = device_ids[device_id] - return int(physical_device_id) - else: - return device_id - - def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]: @wraps(fn) @@ -338,7 +320,7 @@ def get_device_capability(cls, device_id: int = 0 ) -> Optional[DeviceCapability]: try: - physical_device_id = device_id_to_physical_device_id(device_id) + physical_device_id = cls.device_id_to_physical_device_id(device_id) handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id) major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle) return DeviceCapability(major=major, minor=minor) @@ -360,20 +342,20 @@ def has_device_capability( @classmethod @with_nvml_context def get_device_name(cls, device_id: int = 0) -> str: - physical_device_id = device_id_to_physical_device_id(device_id) + physical_device_id = cls.device_id_to_physical_device_id(device_id) return cls._get_physical_device_name(physical_device_id) @classmethod @with_nvml_context def get_device_uuid(cls, device_id: int = 0) -> str: - physical_device_id = device_id_to_physical_device_id(device_id) + physical_device_id = cls.device_id_to_physical_device_id(device_id) handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id) return pynvml.nvmlDeviceGetUUID(handle) @classmethod @with_nvml_context def get_device_total_memory(cls, device_id: int = 0) -> int: - physical_device_id = device_id_to_physical_device_id(device_id) + physical_device_id = cls.device_id_to_physical_device_id(device_id) handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id) return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 68b90796ece2..a0c9e2ae374d 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import enum +import os import platform import random from platform import uname @@ -161,6 +162,24 @@ def is_cuda_alike(self) -> bool: def is_sleep_mode_available(self) -> bool: return self._enum == PlatformEnum.CUDA + @classmethod + def device_id_to_physical_device_id(cls, device_id: int): + if cls.device_control_env_var in os.environ: + device_ids = os.environ[cls.device_control_env_var].split(",") + if device_ids == [""]: + msg = (f"{cls.device_control_env_var} is set to empty string, " + "which means current platform support is disabled. If " + "you are using ray, please unset the environment " + f"variable `{cls.device_control_env_var}` inside the " + "worker/actor. 
Check " + "https://github.com/vllm-project/vllm/issues/8402 for " + "more information.") + raise RuntimeError(msg) + physical_device_id = device_ids[device_id] + return int(physical_device_id) + else: + return device_id + @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index ea028e13fc4d..f3d64f01b0f7 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -95,15 +95,6 @@ def wrapper(*args, **kwargs): return wrapper -def device_id_to_physical_device_id(device_id: int) -> int: - if "CUDA_VISIBLE_DEVICES" in os.environ: - device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",") - physical_device_id = device_ids[device_id] - return int(physical_device_id) - else: - return device_id - - @cache def on_mi250_mi300() -> bool: GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName @@ -238,7 +229,7 @@ def is_fully_connected(physical_device_ids: List[int]) -> bool: @with_amdsmi_context @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: - physical_device_id = device_id_to_physical_device_id(device_id) + physical_device_id = cls.device_id_to_physical_device_id(device_id) handle = amdsmi_get_processor_handles()[physical_device_id] asic_info = amdsmi_get_gpu_asic_info(handle) device_name: str = asic_info["device_id"] diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index c1aa0ce27d3f..fde60bbfa51f 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -622,13 +622,12 @@ def __init__( assert 0 <= local_dp_rank <= dp_rank < dp_size from vllm.platforms import current_platform - if current_platform.is_cuda_alike(): - from vllm.platforms.cuda import device_id_to_physical_device_id - tp_size = vllm_config.parallel_config.tensor_parallel_size - os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( - str(device_id_to_physical_device_id(i)) - for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) * - tp_size)) + device_control_env_var = current_platform.device_control_env_var + tp_size = vllm_config.parallel_config.tensor_parallel_size + os.environ[device_control_env_var] = ",".join( + str(current_platform.device_id_to_physical_device_id(i)) + for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) * + tp_size)) self.local_dp_rank = local_dp_rank self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() From e9c730c9bd0fda1056581bc4cf018871e64fb966 Mon Sep 17 00:00:00 2001 From: Alexei-V-Ivanov-AMD <156011006+Alexei-V-Ivanov-AMD@users.noreply.github.com> Date: Mon, 12 May 2025 15:05:33 -0500 Subject: [PATCH 20/24] Enabling "Weight Loading Multiple GPU Test - Large Models" (#18020) --- .buildkite/test-pipeline.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 9664615be85d..6900efdcf937 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -703,6 +703,7 @@ steps: - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt - label: Weight Loading Multiple GPU Test - Large Models # optional + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 gpu: a100 From 302f3aca7ea3f57842881cb2ae0062c19ad24758 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 13 May 2025 04:46:12 +0800 Subject: [PATCH 21/24] [v1][KVCacheManager] Change prefix caching metric from counting blocks to counting tokens (#18003) Signed-off-by: Chen Zhang --- 
vllm/v1/core/kv_cache_manager.py | 12 ++++++------ vllm/v1/metrics/loggers.py | 4 ++-- vllm/v1/metrics/stats.py | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 27368374ea8d..d0e922363c27 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -161,11 +161,15 @@ def get_computed_blocks(self, computed_blocks = ( self.single_type_manager.find_longest_cache_hit(block_hashes)) + # NOTE(woosuk): Since incomplete blocks are not eligible for + # sharing, `num_computed_tokens` is always a multiple of + # `block_size`. + num_computed_tokens = len(computed_blocks) * self.block_size if self.log_stats: assert self.prefix_cache_stats is not None - self.prefix_cache_stats.queries += len(block_hashes) - self.prefix_cache_stats.hits += len(computed_blocks) + self.prefix_cache_stats.queries += request.num_tokens + self.prefix_cache_stats.hits += num_computed_tokens if last_block_hash is not None: # Add back the last block hash if it was removed. @@ -173,10 +177,6 @@ def get_computed_blocks(self, # we shouldn't modify it directly. block_hashes.append(last_block_hash) - # NOTE(woosuk): Since incomplete blocks are not eligible for - # sharing, `num_computed_tokens` is always a multiple of - # `block_size`. - num_computed_tokens = len(computed_blocks) * self.block_size return KVCacheBlocks(computed_blocks), num_computed_tokens def allocate_slots( diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 7455f1813cd7..6ee40850beb1 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -183,13 +183,13 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): self.counter_gpu_prefix_cache_queries = prometheus_client.Counter( name="vllm:gpu_prefix_cache_queries", documentation= - "GPU prefix cache queries, in terms of number of queried blocks.", + "GPU prefix cache queries, in terms of number of queried tokens.", labelnames=labelnames).labels(*labelvalues) self.counter_gpu_prefix_cache_hits = prometheus_client.Counter( name="vllm:gpu_prefix_cache_hits", documentation= - "GPU prefix cache hits, in terms of number of cached blocks.", + "GPU prefix cache hits, in terms of number of cached tokens.", labelnames=labelnames).labels(*labelvalues) # diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index fd949264885b..8fe1630616a4 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -19,7 +19,7 @@ class PrefixCacheStats: # The number of requests in this update. requests: int = 0 # The number of queries in these requests. Note that "queries" here - # means the number of blocks that were queried from the cache. + # means the number of tokens that were queried from the cache. queries: int = 0 # The number of hits in these requests. 
hits: int = 0 From 195adb47c0f181d32856aed49fad9973c3013217 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Date: Mon, 12 May 2025 16:59:47 -0400 Subject: [PATCH 22/24] [Chore] Remove unused method (#18024) Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/kv_cache_manager.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index d0e922363c27..b34b53155cc3 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -376,8 +376,3 @@ def get_block_ids(self, request_id: str) -> list[int]: block.block_id for block in self.single_type_manager.req_to_blocks[request_id] ] - - def get_num_blocks(self, request_id: str): - """Get the number of blocks.""" - assert request_id in self.single_type_manager.req_to_blocks - return len(self.single_type_manager.req_to_blocks[request_id]) From 2b0db9b0e2378dbb7c44dd17e4066b1f42d42b70 Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Mon, 12 May 2025 14:00:04 -0700 Subject: [PATCH 23/24] Enable standard language model for torhc nightly (#18004) Signed-off-by: Yang Wang --- .buildkite/test-pipeline.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 6900efdcf937..da5db189f70a 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -472,12 +472,14 @@ steps: - label: Language Models Test (Standard) mirror_hardwares: [amdexperimental] + torch_nightly: true source_file_dependencies: - vllm/ - tests/models/language commands: # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile. - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8' + - pip freeze | grep -E 'torch' - pytest -v -s models/language -m core_model - label: Language Models Test (Extended) @@ -493,11 +495,13 @@ steps: - label: Multi-Modal Models Test (Standard) mirror_hardwares: [amdexperimental] + torch_nightly: true source_file_dependencies: - vllm/ - tests/models/multimodal commands: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pip freeze | grep -E 'torch' - pytest -v -s models/multimodal/processing - pytest -v -s --ignore models/multimodal/generation/test_whisper.py models/multimodal -m core_model - cd .. 
&& pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work From ebab1ac37c8efbb29ce052044b1a73ab20b2ea62 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Mon, 12 May 2025 18:31:54 -0400 Subject: [PATCH 24/24] [CI] Make JSON output tests less likely to fail (#17859) Signed-off-by: Russell Bryant --- tests/v1/entrypoints/conftest.py | 12 ++++++--- .../llm/test_struct_output_generate.py | 26 ++++++++++++++----- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/tests/v1/entrypoints/conftest.py b/tests/v1/entrypoints/conftest.py index d84b2b22db12..bdee0bb8da7d 100644 --- a/tests/v1/entrypoints/conftest.py +++ b/tests/v1/entrypoints/conftest.py @@ -72,12 +72,14 @@ def sample_json_schema(): "type": "string" } }, - "required": ["company", "duration", "position"] + "required": ["company", "duration", "position"], + "additionalProperties": False } } }, "required": - ["name", "age", "skills", "grade", "email", "work_history"] + ["name", "age", "skills", "grade", "email", "work_history"], + "additionalProperties": False } @@ -100,7 +102,8 @@ def unsupported_json_schema(): } } }, - "required": ["score", "tags"] + "required": ["score", "tags"], + "additionalProperties": False } @@ -139,7 +142,8 @@ def sample_definition_json_schema(): }, 'required': ['steps', 'final_answer'], 'title': 'MathReasoning', - 'type': 'object' + 'type': 'object', + "additionalProperties": False } diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 81601c87ad8b..5c116598ff3f 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -62,6 +62,16 @@ class CarDescription(BaseModel): car_type: CarType +def _load_json(s: str, backend: str) -> str: + if backend != "xgrammar": + return json.loads(s) + + # xgrammar specific workarounds + # https://github.com/mlc-ai/xgrammar/issues/286 + s = re.sub(r'[\x00-\x1F\x7F-\xFF]', '', s) + return json.loads(s) + + @pytest.mark.skip_global_cleanup @pytest.mark.parametrize( "model_name, guided_decoding_backend, tokenizer_mode, speculative_config", @@ -102,7 +112,7 @@ def test_structured_output( # sampling_params = SamplingParams( temperature=1.0, - max_tokens=1000, + max_tokens=4096, guided_decoding=GuidedDecodingParams(json=sample_json_schema)) outputs = llm.generate(prompts=[ (f"Give an example JSON for an employee profile that fits this " @@ -131,7 +141,7 @@ def test_structured_output( # sampling_params = SamplingParams( temperature=1.0, - max_tokens=100, + max_tokens=4096, n=2, guided_decoding=GuidedDecodingParams(json_object=True)) @@ -161,7 +171,7 @@ def test_structured_output( # sampling_params = SamplingParams( temperature=1.0, - max_tokens=1000, + max_tokens=4096, guided_decoding=GuidedDecodingParams(json=unsupported_json_schema)) if guided_decoding_backend.startswith("xgrammar"): with pytest.raises(ValueError, @@ -376,12 +386,13 @@ def test_structured_output( "minLength": min_length } }, - "required": ["description"] + "required": ["description"], + "additionalProperties": False } sampling_params = SamplingParams( temperature=1.0, - max_tokens=1000, + max_tokens=4096, guided_decoding=GuidedDecodingParams(json=json_schema)) outputs = llm.generate( @@ -417,7 +428,8 @@ def test_structured_output( "city": { "type": "string" } - } + }, + "additionalProperties": False }, "end": "" }], @@ -426,7 +438,7 @@ def test_structured_output( sampling_params = 
SamplingParams( temperature=0.0, - max_tokens=100, + max_tokens=4096, guided_decoding=GuidedDecodingParams( structural_tag=json.dumps(structural_tag_config)))
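The schema edits in this patch follow one pattern: close each object with `"additionalProperties": False` so the grammar cannot invent extra keys, and raise `max_tokens` so otherwise-valid JSON is not truncated mid-object. A minimal standalone sketch of that pattern might look like the following; the model name and prompt are placeholders rather than values taken from the test suite, and the engine flags simply mirror the ones the tests pass through to `LLM`.

```python
# Minimal sketch of strict JSON-constrained generation; model and prompt are
# placeholders. The closed schema plus a generous max_tokens is the same
# combination the updated tests rely on.
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

person_schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"},
    },
    "required": ["name", "age"],
    "additionalProperties": False,
}

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
          guided_decoding_backend="xgrammar")

sampling_params = SamplingParams(
    temperature=1.0,
    max_tokens=4096,
    guided_decoding=GuidedDecodingParams(json=person_schema),
)

outputs = llm.generate(
    prompts=["Give an example JSON of a person that fits this schema: "
             f"{person_schema}"],
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)
```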