|
54 | 54 | from transformers.activations import GELUActivation
|
55 | 55 |
|
56 | 56 | from vllm.config import VllmConfig
|
57 |
| -from vllm.distributed import (get_tensor_model_parallel_rank, |
58 |
| - get_tensor_model_parallel_world_size) |
| 57 | +from vllm.distributed import get_pp_group |
59 | 58 | from vllm.model_executor.layers.fused_moe import FusedMoE
|
60 | 59 | from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
61 | 60 | from vllm.model_executor.layers.vocab_parallel_embedding import (
|
62 | 61 | DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
|
63 | 62 | from vllm.model_executor.model_loader.weight_utils import (
|
64 | 63 | default_weight_loader, maybe_remap_kv_scale_name)
|
65 | 64 | from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model
|
66 |
| -from vllm.model_executor.models.interfaces import SupportsMultiModal |
| 65 | +from vllm.model_executor.models.interfaces import (SupportsMultiModal, |
| 66 | + SupportsPP) |
67 | 67 | from vllm.model_executor.models.moonvit import MoonVitPretrainedModel
|
68 | 68 | from vllm.model_executor.models.utils import merge_multimodal_embeddings
|
69 | 69 | from vllm.model_executor.sampling_metadata import SamplingMetadata
|
|
81 | 81 | from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
|
82 | 82 | from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
83 | 83 |
|
84 |
| -from .utils import is_pp_missing_parameter, maybe_prefix |
| 84 | +from .utils import PPMissingLayer, is_pp_missing_parameter, maybe_prefix |
85 | 85 |
|
86 | 86 |
|
87 | 87 | # For dummy input only
|
@@ -270,7 +270,8 @@ def get_replacement(item_idx: int):
|
270 | 270 | @MULTIMODAL_REGISTRY.register_processor(KimiVLMultiModalProcessor,
|
271 | 271 | info=KimiVLProcessingInfo,
|
272 | 272 | dummy_inputs=KimiVLDummyInputsBuilder)
|
273 |
| -class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal): |
| 273 | +class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, |
| 274 | + SupportsPP): |
274 | 275 |
|
275 | 276 | @classmethod
|
276 | 277 | def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
@@ -304,17 +305,21 @@ def __init__(
|
304 | 305 | prefix=maybe_prefix(prefix, "language_model"),
|
305 | 306 | )
|
306 | 307 | self.unpadded_vocab_size = config.text_config.vocab_size
|
307 |
| - self.lm_head = ParallelLMHead( |
308 |
| - self.unpadded_vocab_size, |
309 |
| - config.text_config.hidden_size, |
310 |
| - org_num_embeddings=self.config.text_config.vocab_size, |
311 |
| - padding_size=DEFAULT_VOCAB_PADDING_SIZE) |
| 308 | + if get_pp_group().is_last_rank: |
| 309 | + self.lm_head = ParallelLMHead( |
| 310 | + self.unpadded_vocab_size, |
| 311 | + config.text_config.hidden_size, |
| 312 | + org_num_embeddings=self.config.text_config.vocab_size, |
| 313 | + padding_size=DEFAULT_VOCAB_PADDING_SIZE, |
| 314 | + ) |
| 315 | + else: |
| 316 | + self.lm_head = PPMissingLayer() |
| 317 | + self.make_empty_intermediate_tensors = ( |
| 318 | + self.language_model.make_empty_intermediate_tensors) |
312 | 319 | logit_scale = getattr(config, "logit_scale", 1.0)
|
313 | 320 | self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
314 | 321 | config.vocab_size, logit_scale)
|
315 | 322 | self.media_placeholder: int = self.config.media_placeholder_token_id
|
316 |
| - self.tp_rank = get_tensor_model_parallel_rank() |
317 |
| - self.tp_world_size = get_tensor_model_parallel_world_size() |
318 | 323 |
|
319 | 324 | # ref: qwen2_vl.py
|
320 | 325 | def _validate_and_reshape_mm_tensor(self, mm_input: object,
|
|
0 commit comments