
Commit 8c14bcd

tc-mb and imning3 authored and committed
[model] Support MiniCPM-V 4.0 (vllm-project#22166)
Co-authored-by: imning3 <[email protected]>
Signed-off-by: Xiao Yu <[email protected]>
1 parent 40e0d79 commit 8c14bcd

File tree: 3 files changed (+140, -12 lines)
  docs/models/supported_models.md
  tests/models/registry.py
  vllm/model_executor/models/minicpmv.py

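For context, a minimal offline-inference sketch for the newly supported checkpoint, assuming the chat template bundled with `openbmb/MiniCPM-V-4` and vLLM's standard multimodal prompt format (the stop-sign image asset and the generation settings are placeholders, not part of this commit):

```python
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

MODEL = "openbmb/MiniCPM-V-4"

# Build the prompt with the checkpoint's own chat template; MiniCPM-V marks
# image slots with the "(<image>./</image>)" placeholder handled in
# minicpmv.py below.
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
messages = [{
    "role": "user",
    "content": "(<image>./</image>)\nWhat is shown in this image?",
}]
prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)

llm = LLM(model=MODEL, trust_remote_code=True)
image = ImageAsset("stop_sign").pil_image  # any PIL image works here

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": image},
    },
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```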

docs/models/supported_models.md

Lines changed: 1 addition & 1 deletion
@@ -622,7 +622,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
 | `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | | ✅︎ | ✅︎ |
 | `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I<sup>+</sup> + V<sup>+</sup> | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | | ✅︎ | ✅︎ |
 | `MiniCPMO` | MiniCPM-O | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>E+</sup> | `openbmb/MiniCPM-o-2_6`, etc. | ✅︎ | ✅︎ | ✅︎ |
-| `MiniCPMV` | MiniCPM-V | T + I<sup>E+</sup> + V<sup>E+</sup> | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc. | ✅︎ | | ✅︎ |
+| `MiniCPMV` | MiniCPM-V | T + I<sup>E+</sup> + V<sup>E+</sup> | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, `openbmb/MiniCPM-V-4`, etc. | ✅︎ | | ✅︎ |
 | `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + I<sup>E+</sup> | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ | ✅︎ |
 | `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `MllamaForConditionalGeneration` | Llama 3.2 | T + I<sup>+</sup> | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | |
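The `E+` superscripts in the new row follow the table's legend: embedding inputs and multiple items per text prompt are accepted for both images and videos. A hedged sketch of how a per-prompt budget might be raised when batching several items (the counts are arbitrary examples, not model limits):

```python
from vllm import LLM

# Allow up to 4 images and 1 video per prompt for MiniCPM-V 4.0; by default
# vLLM budgets a single item per modality.
llm = LLM(
    model="openbmb/MiniCPM-V-4",
    trust_remote_code=True,
    limit_mm_per_prompt={"image": 4, "video": 1},
)
```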

tests/models/registry.py

Lines changed: 1 addition & 1 deletion
@@ -427,7 +427,7 @@ def check_available_online(
     "MiniCPMO": _HfExamplesInfo("openbmb/MiniCPM-o-2_6",
                                 trust_remote_code=True),
     "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
-                                extras={"2.6": "openbmb/MiniCPM-V-2_6"},  # noqa: E501
+                                extras={"2.6": "openbmb/MiniCPM-V-2_6", "4.0": "openbmb/MiniCPM-V-4"},  # noqa: E501
                                 trust_remote_code=True),
     "MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo("MiniMaxAI/MiniMax-VL-01",  # noqa: E501
                                                            trust_remote_code=True,
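The registry entry keeps `openbmb/MiniCPM-Llama3-V-2_5` as the default example model and exposes alternate checkpoints through `extras`. A hypothetical, self-contained sketch of that lookup pattern (an illustration only, not vLLM's actual `_HfExamplesInfo` implementation):

```python
from dataclasses import dataclass, field


@dataclass
class ExampleInfo:
    """Hypothetical stand-in: a default HF repo plus version-keyed extras."""
    default: str
    extras: dict[str, str] = field(default_factory=dict)
    trust_remote_code: bool = False

    def resolve(self, version: str | None = None) -> str:
        return self.default if version is None else self.extras[version]


minicpmv = ExampleInfo(
    "openbmb/MiniCPM-Llama3-V-2_5",
    extras={"2.6": "openbmb/MiniCPM-V-2_6", "4.0": "openbmb/MiniCPM-V-4"},
    trust_remote_code=True,
)
assert minicpmv.resolve("4.0") == "openbmb/MiniCPM-V-4"
```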

vllm/model_executor/models/minicpmv.py

Lines changed: 138 additions & 10 deletions
@@ -38,6 +38,8 @@

 from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization.awq import AWQConfig
+from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
 from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
                                                   get_2d_sincos_pos_embed)
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
@@ -339,7 +341,9 @@ def get_model_version(self):

     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         mm_limits = {"image": None}
-        if self.get_model_version() == (2, 6):
+        if self.get_model_version() == (2,
+                                        6) or self.get_model_version() == (4,
+                                                                           0):
             mm_limits["video"] = None

         return mm_limits
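The effect of the hunk above, restated as a standalone sketch (not the committed code): video inputs are now advertised for both 2.6 and 4.0, and a set-membership test is an equivalent way to express the wrapped comparison:

```python
from typing import Optional


def supported_mm_limits(version: tuple[int, int]) -> dict[str, Optional[int]]:
    # None means the modality is accepted without a fixed per-prompt count.
    limits: dict[str, Optional[int]] = {"image": None}
    if version in {(2, 6), (4, 0)}:
        limits["video"] = None
    return limits


assert "video" in supported_mm_limits((4, 0))
assert "video" not in supported_mm_limits((2, 5))
```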
@@ -620,7 +624,8 @@ def _base_call_hf_processor(
         out_keys: set[str],
     ) -> dict[str, NestedTensors]:
         # This processor supports zipping prompt and mm_data together
-        if self.info.get_model_version() == (2, 6):
+        if self.info.get_model_version() == (
+                2, 6) or self.info.get_model_version() == (4, 0):
             inputs = super()._call_hf_processor(
                 prompt=prompts,  # type: ignore
                 mm_data=mm_data,
@@ -679,10 +684,18 @@ def _get_prompt_updates(
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> Sequence[PromptUpdate]:
-        placeholder = {
-            "image": self.info.image_pattern,
-            "video": self.info.video_pattern,
-        }
+        placeholders = [("image", self.info.image_pattern),
+                        ("video", self.info.video_pattern)]
+
+        # hard code for inconsistency of encode-decode image_pattern
+        additional_placeholders = []
+        tokenizer = self.info.get_tokenizer()
+        for modality, pattern in placeholders:
+            sub_pattern = tokenizer.decode(
+                tokenizer.encode(pattern, add_special_tokens=False))
+            if sub_pattern != pattern:
+                additional_placeholders.append((modality, sub_pattern))
+        placeholders += additional_placeholders

         def get_image_replacement(item_idx: int):
             images = mm_items.get_items(
@@ -714,9 +727,9 @@ def get_video_replacement(item_idx: int):

         return [
             PromptReplacement(modality=modality,
-                              target=placeholder[modality],
+                              target=pattern,
                               replacement=get_replacement[modality])
-            for modality in ("image", "video")
+            for modality, pattern in placeholders
         ]

     def _get_mm_fields_config(
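The two hunks above guard against tokenizers whose encode/decode round trip rewrites the placeholder string: if the decoded form differs, it is registered as an additional replacement target alongside the original. A toy illustration of that logic (the "./" normalization is invented for the example and is not what any real MiniCPM tokenizer does):

```python
class ToyTokenizer:
    """Pretend tokenizer that normalizes "./" to "/" when round-tripping."""

    def encode(self, text, add_special_tokens=False):
        return list(text.replace("./", "/"))

    def decode(self, ids):
        return "".join(ids)


placeholders = [("image", "(<image>./</image>)")]
tokenizer = ToyTokenizer()

additional = []
for modality, pattern in placeholders:
    round_tripped = tokenizer.decode(
        tokenizer.encode(pattern, add_special_tokens=False))
    if round_tripped != pattern:
        additional.append((modality, round_tripped))
placeholders += additional

print(placeholders)
# [('image', '(<image>./</image>)'), ('image', '(<image>/</image>)')]
```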
@@ -1262,11 +1275,124 @@ def get_vision_hidden_states(

         return self.resampler(vision_embedding, tgt_sizes)

+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(self,
+                                   skip_prefixes=["apm.", "audio", "tts"])
+        return loader.load_weights(weights)
+
+
+class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+        assert self.version == (4, 0)
+
+    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
+        if isinstance(quant_config, (AWQConfig, AWQMarlinConfig)):
+            return None
+        return quant_config
+
+    def init_llm(
+        self,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+    ) -> nn.Module:
+        return LlamaForCausalLM(vllm_config=vllm_config, prefix=prefix)
+
+    def init_vision_module(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> nn.Module:
+        quant_config = self._maybe_ignore_quant_config(quant_config)
+        model = Idefics2VisionTransformer(config.vision_config,
+                                          quant_config=quant_config,
+                                          prefix=prefix)
+        if self.config.drop_vision_last_layer:
+            model.encoder.layers = model.encoder.layers[:-1]
+        return model
+
+    def init_resampler(
+        self,
+        embed_dim: int,
+        vision_dim: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> nn.Module:
+        quant_config = self._maybe_ignore_quant_config(quant_config)
+        with set_default_torch_dtype(torch.float16):
+            # The resampler in 4.0 remains consistent with the one in 2.5/2.6.
+            resampler = Resampler2_5(num_queries=self.config.query_num,
+                                     embed_dim=embed_dim,
+                                     num_heads=embed_dim // 128,
+                                     kv_dim=vision_dim,
+                                     quant_config=quant_config,
+                                     prefix=prefix)
+
+        return resampler.to(device=current_platform.device_type,
+                            dtype=torch.get_default_dtype())
+
+    def get_vision_hidden_states(
+            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
+        pixel_values = data["pixel_values"]
+        tgt_sizes = data["tgt_sizes"]
+
+        B = len(pixel_values)
+        P = pixel_values[0].shape[-2]
+        L = max(item.shape[-1] for item in pixel_values)
+        device = pixel_values[0].device
+        dtype = pixel_values[0].dtype
+
+        all_pixel_values = torch.zeros((B, 3, P, L),
+                                       dtype=dtype,
+                                       device=device)
+        for i, pixel_values_item in enumerate(pixel_values):
+            L_item = pixel_values_item.shape[-1]
+            all_pixel_values[i, ..., :L_item] = pixel_values_item
+
+        num_patches = tgt_sizes.prod(-1)
+        max_patches = num_patches.max().item()
+        assert isinstance(max_patches, int)
+
+        patch_attn_mask = torch.zeros((B, max_patches),
+                                      dtype=torch.bool,
+                                      device=device)
+        for i, num_patches_item in enumerate(num_patches):
+            patch_attn_mask[i, :num_patches_item] = True
+
+        vision_embedding = self.vpm(
+            all_pixel_values,
+            patch_attention_mask=patch_attn_mask.unsqueeze(1),
+            tgt_sizes=tgt_sizes,
+        )
+
+        return self.resampler(vision_embedding, tgt_sizes)
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(self,
+                                   skip_prefixes=["apm.", "audio", "tts"])
+        return loader.load_weights(weights)
+

 _SUPPORT_VERSION = {
     (2, 0): MiniCPMV2_0,
     (2, 5): MiniCPMV2_5,
     (2, 6): MiniCPMV2_6,
+    (4, 0): MiniCPMV4_0,
 }

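A standalone sketch of the batching scheme used in the new `get_vision_hidden_states`: variable-width pixel slices are right-padded into one tensor and a boolean patch mask marks the real patches (the shapes below are made up for illustration):

```python
import torch

# Two images whose slices have different widths after preprocessing.
pixel_values = [torch.randn(3, 14, 28), torch.randn(3, 14, 70)]
tgt_sizes = torch.tensor([[1, 2], [1, 5]])  # (rows, cols) of patches per image

B = len(pixel_values)
P = pixel_values[0].shape[-2]
L = max(item.shape[-1] for item in pixel_values)

# Right-pad every slice to the widest one.
all_pixel_values = torch.zeros((B, 3, P, L), dtype=pixel_values[0].dtype)
for i, item in enumerate(pixel_values):
    all_pixel_values[i, ..., :item.shape[-1]] = item

# Mark which patch positions are real; the rest stay masked out.
num_patches = tgt_sizes.prod(-1)
patch_attn_mask = torch.zeros((B, int(num_patches.max())), dtype=torch.bool)
for i, n in enumerate(num_patches):
    patch_attn_mask[i, :n] = True

print(all_pixel_values.shape)       # torch.Size([2, 3, 14, 70])
print(patch_attn_mask.sum(dim=-1))  # tensor([2, 5])
```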

@@ -1294,8 +1420,10 @@ def __new__(cls, *, vllm_config: VllmConfig, prefix: str = ""):
         # Dispatch class based on version
         instance_cls = _SUPPORT_VERSION.get(version)
         if instance_cls is None:
-            raise ValueError(
-                "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
+            supported_versions = ", ".join(
+                [f"{v[0]}.{v[1]}" for v in sorted(_SUPPORT_VERSION.keys())])
+            raise ValueError(f"Currently, MiniCPMV only supports versions "
+                             f"{supported_versions}. Got version: {version}")

         # quant_config references base class members,
         # so update values before init is called
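A small sketch of what the new error path produces when an unsupported version is requested (standalone, with the model classes replaced by strings):

```python
support = {(2, 0): "MiniCPMV2_0", (2, 5): "MiniCPMV2_5",
           (2, 6): "MiniCPMV2_6", (4, 0): "MiniCPMV4_0"}

version = (3, 0)  # hypothetical unsupported version
if support.get(version) is None:
    supported_versions = ", ".join(f"{v[0]}.{v[1]}" for v in sorted(support))
    print(f"Currently, MiniCPMV only supports versions "
          f"{supported_versions}. Got version: {version}")
# Currently, MiniCPMV only supports versions 2.0, 2.5, 2.6, 4.0. Got version: (3, 0)
```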
