@@ -38,6 +38,8 @@
 
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization.awq import AWQConfig
+from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
 from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
                                                   get_2d_sincos_pos_embed)
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
@@ -339,7 +341,9 @@ def get_model_version(self):
 
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         mm_limits = {"image": None}
-        if self.get_model_version() == (2, 6):
+        if self.get_model_version() == (2,
+                                        6) or self.get_model_version() == (4,
+                                                                           0):
             mm_limits["video"] = None
 
         return mm_limits
@@ -620,7 +624,8 @@ def _base_call_hf_processor(
         out_keys: set[str],
     ) -> dict[str, NestedTensors]:
         # This processor supports zipping prompt and mm_data together
-        if self.info.get_model_version() == (2, 6):
+        if self.info.get_model_version() == (
+                2, 6) or self.info.get_model_version() == (4, 0):
             inputs = super()._call_hf_processor(
                 prompt=prompts,  # type: ignore
                 mm_data=mm_data,
@@ -679,10 +684,18 @@ def _get_prompt_updates(
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> Sequence[PromptUpdate]:
-        placeholder = {
-            "image": self.info.image_pattern,
-            "video": self.info.video_pattern,
-        }
+        placeholders = [("image", self.info.image_pattern),
+                        ("video", self.info.video_pattern)]
+
+        # Also match the decoded pattern, which may differ after encode-decode
+        additional_placeholders = []
+        tokenizer = self.info.get_tokenizer()
+        for modality, pattern in placeholders:
+            sub_pattern = tokenizer.decode(
+                tokenizer.encode(pattern, add_special_tokens=False))
+            if sub_pattern != pattern:
+                additional_placeholders.append((modality, sub_pattern))
+        placeholders += additional_placeholders
 
         def get_image_replacement(item_idx: int):
             images = mm_items.get_items(
@@ -714,9 +727,9 @@ def get_video_replacement(item_idx: int):
 
         return [
             PromptReplacement(modality=modality,
-                              target=placeholder[modality],
+                              target=pattern,
                               replacement=get_replacement[modality])
-            for modality in ("image", "video")
+            for modality, pattern in placeholders
         ]
 
     def _get_mm_fields_config(
@@ -1262,11 +1275,124 @@ def get_vision_hidden_states(
 
         return self.resampler(vision_embedding, tgt_sizes)
 
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
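+        # Weights under these prefixes (audio / TTS modules) are not used by
+        # this vision-language model, so skip them during loading.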
+        loader = AutoWeightsLoader(self,
+                                   skip_prefixes=["apm.", "audio", "tts"])
+        return loader.load_weights(weights)
+
+
+class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
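+    # vLLM's fused linear modules and the per-projection checkpoint weights
+    # that are packed into them (consumed when applying LoRA adapters).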
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+        assert self.version == (4, 0)
+
+    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
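+        # AWQ checkpoints typically leave the vision encoder and resampler
+        # unquantized; returning None builds those modules without a
+        # quantization config.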
+        if isinstance(quant_config, (AWQConfig, AWQMarlinConfig)):
+            return None
+        return quant_config
+
+    def init_llm(
+        self,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+    ) -> nn.Module:
+        return LlamaForCausalLM(vllm_config=vllm_config, prefix=prefix)
+
+    def init_vision_module(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> nn.Module:
+        quant_config = self._maybe_ignore_quant_config(quant_config)
+        model = Idefics2VisionTransformer(config.vision_config,
+                                          quant_config=quant_config,
+                                          prefix=prefix)
+        if self.config.drop_vision_last_layer:
+            model.encoder.layers = model.encoder.layers[:-1]
+        return model
+
+    def init_resampler(
+        self,
+        embed_dim: int,
+        vision_dim: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> nn.Module:
+        quant_config = self._maybe_ignore_quant_config(quant_config)
+        with set_default_torch_dtype(torch.float16):
+            # The resampler in 4.0 remains consistent with the one in 2.5/2.6.
+            resampler = Resampler2_5(num_queries=self.config.query_num,
+                                     embed_dim=embed_dim,
+                                     num_heads=embed_dim // 128,
+                                     kv_dim=vision_dim,
+                                     quant_config=quant_config,
+                                     prefix=prefix)
+
+        return resampler.to(device=current_platform.device_type,
+                            dtype=torch.get_default_dtype())
+
+    def get_vision_hidden_states(
+            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
+        pixel_values = data["pixel_values"]
+        tgt_sizes = data["tgt_sizes"]
+
+        B = len(pixel_values)
+        P = pixel_values[0].shape[-2]
+        L = max(item.shape[-1] for item in pixel_values)
+        device = pixel_values[0].device
+        dtype = pixel_values[0].dtype
+
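+        # Right-pad every flattened patch sequence to the longest one in the
+        # batch so the vision encoder can run on a single dense tensor.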
+        all_pixel_values = torch.zeros((B, 3, P, L),
+                                       dtype=dtype,
+                                       device=device)
+        for i, pixel_values_item in enumerate(pixel_values):
+            L_item = pixel_values_item.shape[-1]
+            all_pixel_values[i, ..., :L_item] = pixel_values_item
+
+        num_patches = tgt_sizes.prod(-1)
+        max_patches = num_patches.max().item()
+        assert isinstance(max_patches, int)
+
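+        # Mark the valid (non-padded) patch positions of each image for the
+        # vision encoder's attention mask.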
+        patch_attn_mask = torch.zeros((B, max_patches),
+                                      dtype=torch.bool,
+                                      device=device)
+        for i, num_patches_item in enumerate(num_patches):
+            patch_attn_mask[i, :num_patches_item] = True
+
+        vision_embedding = self.vpm(
+            all_pixel_values,
+            patch_attention_mask=patch_attn_mask.unsqueeze(1),
+            tgt_sizes=tgt_sizes,
+        )
+
+        return self.resampler(vision_embedding, tgt_sizes)
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(self,
+                                   skip_prefixes=["apm.", "audio", "tts"])
+        return loader.load_weights(weights)
+
 
 _SUPPORT_VERSION = {
     (2, 0): MiniCPMV2_0,
     (2, 5): MiniCPMV2_5,
     (2, 6): MiniCPMV2_6,
+    (4, 0): MiniCPMV4_0,
 }
 
 
@@ -1294,8 +1420,10 @@ def __new__(cls, *, vllm_config: VllmConfig, prefix: str = ""):
         # Dispatch class based on version
         instance_cls = _SUPPORT_VERSION.get(version)
         if instance_cls is None:
-            raise ValueError(
-                "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
+            supported_versions = ", ".join(
+                [f"{v[0]}.{v[1]}" for v in sorted(_SUPPORT_VERSION.keys())])
+            raise ValueError(f"Currently, MiniCPMV only supports versions "
+                             f"{supported_versions}. Got version: {version}")
 
         # quant_config references base class members,
         # so update values before init is called