Skip to content

Commit fda9537

Browse files
[Model] Support Pipeline Parallelism for moonshotai/Kimi-VL-A3B-Thinking-2506 (#23114)
Signed-off-by: zjy0516 <[email protected]> Co-authored-by: Cyrus Leung <[email protected]>
1 parent 90bbe0a commit fda9537

File tree

2 files changed

+18
-13
lines changed

2 files changed

+18
-13
lines changed

docs/models/supported_models.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
626626
| `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> + V<sup>E+</sup> | `internlm/Intern-S1`, etc. | ✅︎ | ✅︎ | ✅︎ |
627627
| `InternVLChatModel` | InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ |
628628
| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | | | ✅︎ |
629-
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | | ✅︎ |
629+
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ | ✅︎ |
630630
| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ | ✅︎ |
631631
| `Llama_Nemotron_Nano_VL` | Llama Nemotron Nano VL | T + I<sup>E+</sup> | `nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1` | ✅︎ | ✅︎ | ✅︎ |
632632
| `LlavaForConditionalGeneration` | LLaVA-1.5, Pixtral (HF Transformers) | T + I<sup>E+</sup> | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), `mistral-community/pixtral-12b`, etc. | | ✅︎ | ✅︎ |

vllm/model_executor/models/kimi_vl.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,16 +54,16 @@
5454
from transformers.activations import GELUActivation
5555

5656
from vllm.config import VllmConfig
57-
from vllm.distributed import (get_tensor_model_parallel_rank,
58-
get_tensor_model_parallel_world_size)
57+
from vllm.distributed import get_pp_group
5958
from vllm.model_executor.layers.fused_moe import FusedMoE
6059
from vllm.model_executor.layers.logits_processor import LogitsProcessor
6160
from vllm.model_executor.layers.vocab_parallel_embedding import (
6261
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
6362
from vllm.model_executor.model_loader.weight_utils import (
6463
default_weight_loader, maybe_remap_kv_scale_name)
6564
from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model
66-
from vllm.model_executor.models.interfaces import SupportsMultiModal
65+
from vllm.model_executor.models.interfaces import (SupportsMultiModal,
66+
SupportsPP)
6767
from vllm.model_executor.models.moonvit import MoonVitPretrainedModel
6868
from vllm.model_executor.models.utils import merge_multimodal_embeddings
6969
from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -81,7 +81,7 @@
8181
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
8282
from vllm.utils.tensor_schema import TensorSchema, TensorShape
8383

84-
from .utils import is_pp_missing_parameter, maybe_prefix
84+
from .utils import PPMissingLayer, is_pp_missing_parameter, maybe_prefix
8585

8686

8787
# For dummy input only
@@ -270,7 +270,8 @@ def get_replacement(item_idx: int):
270270
@MULTIMODAL_REGISTRY.register_processor(KimiVLMultiModalProcessor,
271271
info=KimiVLProcessingInfo,
272272
dummy_inputs=KimiVLDummyInputsBuilder)
273-
class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal):
273+
class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
274+
SupportsPP):
274275

275276
@classmethod
276277
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
@@ -304,17 +305,21 @@ def __init__(
304305
prefix=maybe_prefix(prefix, "language_model"),
305306
)
306307
self.unpadded_vocab_size = config.text_config.vocab_size
307-
self.lm_head = ParallelLMHead(
308-
self.unpadded_vocab_size,
309-
config.text_config.hidden_size,
310-
org_num_embeddings=self.config.text_config.vocab_size,
311-
padding_size=DEFAULT_VOCAB_PADDING_SIZE)
308+
if get_pp_group().is_last_rank:
309+
self.lm_head = ParallelLMHead(
310+
self.unpadded_vocab_size,
311+
config.text_config.hidden_size,
312+
org_num_embeddings=self.config.text_config.vocab_size,
313+
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
314+
)
315+
else:
316+
self.lm_head = PPMissingLayer()
317+
self.make_empty_intermediate_tensors = (
318+
self.language_model.make_empty_intermediate_tensors)
312319
logit_scale = getattr(config, "logit_scale", 1.0)
313320
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
314321
config.vocab_size, logit_scale)
315322
self.media_placeholder: int = self.config.media_placeholder_token_id
316-
self.tp_rank = get_tensor_model_parallel_rank()
317-
self.tp_world_size = get_tensor_model_parallel_world_size()
318323

319324
# ref: qwen2_vl.py
320325
def _validate_and_reshape_mm_tensor(self, mm_input: object,

0 commit comments

Comments
 (0)