diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index ec736aa236ff..1ef1e2aa5344 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -535,7 +535,7 @@ Specified using `--task generate`.
| `GLM4VForCausalLM`^ | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220` etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + IE+ + VE+ | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4MoeForCausalLM` | GLM-4.5 | T + IE+ + VE+ | `zai-org/GLM-4.5`, etc. | ✅︎ | ✅︎ | ✅︎ |
-| `Glm4v_moeForConditionalGeneration` | GLM-4.5V | T + IE+ + VE+ | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + IE+ + VE+ | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
| `H2OVLChatModel` | H2OVL | T + IE+ | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎\* |
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3` etc. | ✅︎ | | ✅︎ |
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index 825abeaf7e75..f088309a644d 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -283,6 +283,41 @@ def run_glm4_1v(questions: list[str], modality: str) -> ModelRequestData:
prompts=prompts,
)
+# GLM-4.5V
+def run_glm4_5v(questions: list[str], modality: str) -> ModelRequestData:
+ model_name = "zai-org/GLM-4.5V"
+
+ engine_args = EngineArgs(
+ model=model_name,
+ max_model_len=4096,
+ max_num_seqs=2,
+ mm_processor_kwargs={
+ "size": {"shortest_edge": 12544, "longest_edge": 47040000},
+ "fps": 1,
+ },
+ limit_mm_per_prompt={modality: 1},
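+        # GLM-4.5V is a large MoE VLM; tensor_parallel_size=4 is just an
+        # example setting and can be adjusted to the available hardware.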
+ tensor_parallel_size=4,
+ )
+
+ if modality == "image":
+ placeholder = "<|begin_of_image|><|image|><|end_of_image|>"
+ elif modality == "video":
+ placeholder = "<|begin_of_video|><|video|><|end_of_video|>"
+
+ prompts = [
+ (
+ "[gMASK]<|system|>\nYou are a helpful assistant.<|user|>\n"
+ f"{placeholder}"
+ f"{question}<|assistant|>assistant\n"
+ )
+ for question in questions
+ ]
+
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompts=prompts,
+ )
+
# H2OVL-Mississippi
def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
@@ -1120,6 +1155,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
"gemma3": run_gemma3,
"glm4v": run_glm4v,
"glm4_1v": run_glm4_1v,
+ "glm4_5v": run_glm4_5v,
"h2ovl_chat": run_h2ovl,
"idefics3": run_idefics3,
"internvl_chat": run_internvl,
diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py
index af74877fdbf3..bed8d68bbf4b 100644
--- a/vllm/model_executor/models/glm4_1v.py
+++ b/vllm/model_executor/models/glm4_1v.py
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/Glm4v/modeling_Glm4v.py
@@ -25,6 +26,7 @@
# limitations under the License.
"""Inference-only GLM-4V model compatible with HuggingFace weights."""
import math
+import os
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
@@ -54,9 +56,6 @@
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.gptq_marlin import (
- GPTQMarlinConfig)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -66,9 +65,9 @@
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
- PromptUpdate)
+ PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend
+from vllm.platforms import _Backend, current_platform
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
@@ -81,6 +80,12 @@
merge_multimodal_embeddings)
from .vision import get_vit_attn_backend
+is_hpu = current_platform.is_hpu()
+if is_hpu:
+ import habana_frameworks.torch.core as htcore
+ from habana_frameworks.torch.hpex.kernels import FusedSDPA
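+    # htcore.mark_step() flushes the accumulated lazy-mode graph on HPU;
+    # FusedSDPA provides the fused scaled-dot-product-attention kernel used
+    # by the Gaudi attention paths below.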
+
+
logger = init_logger(__name__)
# For profile run
@@ -169,6 +174,33 @@ class Glm4vVideoEmbeddingInputs(TypedDict):
# === Vision Encoder === #
+class AttentionLongSequence:
+
+ @staticmethod
+ def forward(q, k, v, mask, q_block_size, softmax_mode):
+ """
+ Support long sequence at prompt phase
+ """
+        q_len = q.size(-2)
+        assert q_len % q_block_size == 0
+        q_tiles = q_len // q_block_size
+ attn_output = torch.zeros_like(q)
+
+ for i in range(q_tiles):
+ s, e = i * q_block_size, (i + 1) * q_block_size
+ row_q = q[:, :, s:e, :]
+ row_mask = mask[:, :, s:e, :]
+ attn_output[:, :,
+ s:e, :] = FusedSDPA.apply(row_q, k, v, row_mask, 0.0,
+ False, None, softmax_mode)
+ # TODO: markstep after a couple of iterations
+ # need to experiment the optimal number.
+ if i % 75 == 0:
+ htcore.mark_step()
+ return attn_output
+
class Glm4vVisionMLP(nn.Module):
@@ -178,6 +210,7 @@ def __init__(
hidden_features: int,
bias: bool = False,
quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
):
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
@@ -185,13 +218,12 @@ def __init__(
output_sizes=[hidden_features] * 2,
bias=bias,
quant_config=quant_config,
- )
- self.down_proj = RowParallelLinear(
- hidden_features,
- in_features,
- bias=bias,
- quant_config=quant_config,
- )
+ prefix=f"{prefix}.gate_up_proj")
+ self.down_proj = RowParallelLinear(hidden_features,
+ in_features,
+ bias=bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.down_proj")
self.act_fn = SiluAndMul()
def forward(self, x: torch.Tensor):
@@ -269,6 +301,9 @@ def __init__(
raise RuntimeError(
f"GLM-4V does not support {self.attn_backend} backend now.")
+        self.softmax_mode = 'fp32' if os.environ.get(
+            'VLLM_FP32_SOFTMAX_VISION',
+            'false').lower() in ['true', '1'] else 'None'
+
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
@@ -306,6 +341,7 @@ def forward(
rotary_pos_emb: torch.Tensor,
max_seqlen: Optional[int] = None, # Only used for Flash Attention
seqlens: Optional[list[int]] = None, # Only used for xFormers
+ fullattn_mask: Optional[torch.Tensor] = None, # Only used for gaudi
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
@@ -342,6 +378,27 @@ def forward(
context_layer = rearrange(output,
"(b s) ... -> b s ...",
b=batch_size)
+ elif self.attn_backend == _Backend.TORCH_SDPA and is_hpu:
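+            # HPU path: attend over the full padded sequence in one fused call
+            # using a precomputed block-diagonal mask, instead of iterating
+            # over images.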
+            assert cu_seqlens.shape[0] <= 3, \
+                "Only one image plus padding is supported"
+ assert fullattn_mask is not None, \
+ "Should call to here from Glm4vVisionTransformerStaticShape"
+
+ q1, k1, v1 = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
+ (batch_size, _, seq_len_N_t, _) = q1.shape
+ (batch_size, _, seq_len_N_s, _) = k1.shape
+ mask_shape = (batch_size, 1, seq_len_N_t, seq_len_N_s)
+ attn_mask = fullattn_mask.reshape(
+ batch_size, 1, seq_len_N_t, seq_len_N_s,
+ -1)[:, :, :, :, 0] # reshapes the mask to be Bx1xNxN
+ assert attn_mask.shape == mask_shape
+
+            if q1.shape[2] <= 65536:  # TODO: investigate this crossover point
+ fused_out = FusedSDPA.apply(q1, k1, v1, attn_mask, 0.0,
+ False, None, self.softmax_mode)
+ else:
+ fused_out = AttentionLongSequence.forward(
+ q1, k1, v1, attn_mask, 64, self.softmax_mode)
+            context_layer = rearrange(fused_out, "b h s d -> b s h d")
elif self.attn_backend == _Backend.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
outputs = []
@@ -406,6 +463,7 @@ def __init__(
mlp_hidden_dim,
bias=False,
quant_config=quant_config,
+ prefix=f"{prefix}.mlp",
)
def forward(
@@ -415,6 +473,7 @@ def forward(
rotary_pos_emb: torch.Tensor,
max_seqlen: Optional[int] = None, # Only used for Flash Attention
seqlens: Optional[list[int]] = None, # Only used for xFormers
+ fullattn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
@@ -422,8 +481,8 @@ def forward(
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
seqlens=seqlens,
+ fullattn_mask=fullattn_mask,
)
-
x = x + self.mlp(self.norm2(x))
return x
@@ -467,25 +526,30 @@ def __init__(
context_dim: int,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
+ prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = d_model
self.proj = ColumnParallelLinear(self.hidden_size,
self.hidden_size,
bias=bias,
- gather_output=True)
+ gather_output=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.proj")
self.post_projection_norm = nn.LayerNorm(self.hidden_size)
self.gate_up_proj = MergedColumnParallelLinear(
input_size=self.hidden_size,
output_sizes=[context_dim] * 2,
bias=bias,
quant_config=quant_config,
+ prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
context_dim,
self.hidden_size,
bias=bias,
quant_config=quant_config,
+ prefix=f"{prefix}.down_proj",
)
self.act_fn = SiluAndMul()
self.extra_activation_func = nn.GELU()
@@ -527,7 +591,6 @@ def forward(self, embeddings, lengths, image_shapes, h_coords,
# Move coordinates to correct device
h_coords, w_coords = h_coords.to(device), w_coords.to(device)
-
# Handle empty sequence case
if total_seq == 0:
adapted_pos_embed = torch.empty(0,
@@ -540,6 +603,7 @@ def forward(self, embeddings, lengths, image_shapes, h_coords,
lengths = torch.tensor(lengths,
device=device,
dtype=torch.long)
+
if not isinstance(image_shapes, torch.Tensor):
image_shapes = torch.tensor(image_shapes,
device=device,
@@ -569,7 +633,6 @@ def forward(self, embeddings, lengths, image_shapes, h_coords,
w_coords = w_coords.to(device=device, dtype=torch.float32)
norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
norm_h = ((h_coords + 0.5) / target_h) * 2 - 1
-
# Create sampling grid
grid = (torch.stack((norm_w, norm_h),
dim=-1).unsqueeze(0).unsqueeze(2))
@@ -675,6 +738,7 @@ def __init__(
context_dim=vision_config.intermediate_size,
quant_config=quant_config,
bias=False,
+ prefix=f"{prefix}.merger",
)
self.embeddings = Glm4vVisionEmbeddings(vision_config)
@@ -718,6 +782,7 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
).permute(0, 2, 1, 3).flatten())
pos_ids.append(
torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
@@ -751,7 +816,6 @@ def forward(
grid_thw[:, 0]).cumsum(
dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
-
# pre-compute seqlens for attn mask to reduce cuMemcpy operations
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
x = self.embeddings(x, seqlens, grid_thw, image_type_ids[:, 0],
@@ -810,6 +874,184 @@ def load_weights(self, weights: Iterable[tuple[str,
loaded_params.add(name)
return loaded_params
+
+
+class Glm4vVisionTransformerStaticShape(Glm4vVisionTransformer):
+    """
+    Here we override some of the methods of Glm4vVisionTransformer
+    to make the model more friendly to static shapes. Specifically,
+    we split the forward method into:
+    - pre_attn (dynamic)
+    - forward (static shape)
+    - post_attn (dynamic)
+    and callers should use get_image_embeds instead of forward, allowing
+    the forward method to run with HPU Graphs, whereas the
+    pre_attn and post_attn methods are allowed to be dynamic.
+    """
+
+ def pad_multimodal_data(self, pixel_values, image_grid_thw,
+ vision_buckets):
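+        # Pad the flattened patch sequence up to the next vision bucket size
+        # and append a dummy [1, W, H] entry to image_grid_thw describing the
+        # padding, so downstream shape bookkeeping stays consistent.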
+ desired_number_of_pixels = vision_buckets.get_multimodal_bucket(
+ pixel_values.shape[0])
+ padding_len = desired_number_of_pixels - pixel_values.shape[0]
+ if padding_len <= 0:
+ return pixel_values, image_grid_thw
+
+ logger_msg = "Padding current number pixel " \
+ + str(pixel_values.shape[0]) + " to "+ str(desired_number_of_pixels)
+ logger.info(logger_msg)
+
+ constant_value = -100
+ pixel_values = torch.cat([
+ pixel_values,
+ torch.ones((padding_len, pixel_values.shape[1]),
+ device=pixel_values.device) * constant_value
+ ])
+
+        # ensure the padded W and H are divisible by self.spatial_merge_size
+        if padding_len % (self.spatial_merge_size**2) != 0:
+            raise ValueError("The padding length is not aligned to "
+                             f"{self.spatial_merge_size}**2")
+
+        aligned_padding_len = padding_len // (self.spatial_merge_size**2)
+        padding_h = aligned_padding_len * self.spatial_merge_size
+ padding_w = padding_len // padding_h
+ while padding_w < 8 and ((padding_h // 2) % self.spatial_merge_size) == 0:
+ padding_w *= 2
+ padding_h //= 2
+
+        image_grid_thw = torch.cat([
+            image_grid_thw,
+            torch.tensor([[1, padding_w, padding_h]],
+                         device=image_grid_thw.device)
+        ])
+
+ assert image_grid_thw.prod(-1).sum() == desired_number_of_pixels
+ return pixel_values, image_grid_thw
+
+ def forward(self,
+ x: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: torch.Tensor,
+ max_seqlen: Optional[int] = None, # Only used for Flash Attention
+ fullattn_mask: Optional[torch.Tensor] = None, # Only used for Gaudi
+ ) -> torch.Tensor:
+ hidden_states = x.unsqueeze(1)
+        for blk in self.blocks:
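+            # Flush the accumulated lazy graph once per transformer block.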
+ htcore.mark_step()
+ hidden_states = blk(hidden_states,
+ cu_seqlens=cu_seqlens,
+ rotary_pos_emb=rotary_pos_emb,
+ max_seqlen=max_seqlen,
+ fullattn_mask=fullattn_mask)
+
+ return hidden_states
+
+
+ def create_block_diagonal_attention_mask_outerprod(self, indices):
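+        # Build a block-diagonal boolean attention mask from the cumulative
+        # sequence lengths: each row of range_indices marks the positions of
+        # one segment, and the outer product of that membership matrix with
+        # itself is True only for token pairs within the same segment.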
+ maxsize = indices[-1]
+ range_to_max_for_each_img = torch.arange(
+ maxsize,
+ device=indices.device).unsqueeze(0).repeat(indices.shape[0] - 1, 1)
+
+ lesser = range_to_max_for_each_img < indices[1:].unsqueeze(1)
+ greater_eq = range_to_max_for_each_img >= indices[:-1].unsqueeze(1)
+ range_indices = torch.logical_and(lesser, greater_eq).float()
+ # can reduce sum externally or as batchmatmul
+ if range_indices.shape[-1] > 40000:
+ log_msg = "einsum running on CPU :" + str(range_indices.shape)
+ logger.info(log_msg)
+ range_indices = range_indices.to("cpu")
+ res = torch.einsum('bi,bj->ij', range_indices, range_indices)
+ res = res.to("hpu")
+ else:
+ res = torch.einsum('bi,bj->ij', range_indices, range_indices)
+
+ return res.bool()
+
+ def pre_attn(self, x: torch.Tensor, grid_thw: torch.Tensor):
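+        # Dynamic-shape preprocessing (patch embedding, rotary position
+        # embeddings, cu_seqlens) that runs outside the static-shape forward.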
+ hidden_states = x.to(device=self.device, dtype=self.dtype)
+        # hidden_states is [patch_num, patch_pixels]
+
+ hidden_states = self.patch_embed(hidden_states)
+ # hidden_states is [patch_num, patch_dim]
+
+ hidden_states = self.post_conv_layernorm(hidden_states)
+ # hidden_states is [patch_num, patch_dim]
+
+ # compute position embedding
+ rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
+ # rotary_pos_emb is [patch_num, patch_dim/num_heads/2]
+        # image_type_ids is [patch_num, 2] -> (w_index, h_index)
+
+ # compute cu_seqlens
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
+ grid_thw[:, 0]).cumsum(
+ dim=0, dtype=torch.int32)
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
+        # cu_seqlens is a 1-D cumulative tensor:
+        # [0, W1xH1, W1xH1*2, ..., W1xH1*T, W1xH1*T + W2xH2, ...]
+
+ max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+ # seqlens is a list of cu_seqlens values
+
+ hidden_states = self.embeddings(hidden_states, seqlens, grid_thw,
+ image_type_ids[:, 0], image_type_ids[:, 1])
+ # Add a pos embed on hidden_states, shape unchanged
+
+ return (hidden_states, rotary_pos_emb, cu_seqlens, max_seqlen)
+
+ def post_attn(self, hidden_states: torch.Tensor):
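+        # Dynamic-shape postprocessing: final layer norm, spatial merge via
+        # the downsample layer, and projection through the patch merger.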
+ hidden_states = self.post_layernorm(hidden_states)
+        hidden_states = hidden_states.view(-1, self.spatial_merge_size,
+                                           self.spatial_merge_size,
+                                           hidden_states.shape[-1])
+        hidden_states = hidden_states.permute(0, 3, 1, 2)
+
+        hidden_states = self.downsample(hidden_states).view(
+            -1, self.out_hidden_size)
+
+ hidden_states = self.merger(hidden_states)
+
+ return hidden_states
+
+
+ def get_image_embeds(
+ self,
+ pixel_values: torch.Tensor,
+ grid_thw: torch.Tensor,
+ vision_buckets,
+ ) -> torch.Tensor:
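+        # Encode each image independently: pad its patches to a bucket size,
+        # run the static-shape transformer between mark_step boundaries, then
+        # drop the padded embeddings before concatenating per-image results.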
+        # the incoming patch count must be aligned to spatial_merge_size**2
+        num_patches = pixel_values.shape[0]
+        assert num_patches % (self.spatial_merge_size**2) == 0, \
+            "Patch count is not aligned to spatial_merge_size**2."
+
+ offset = 0
+ results = []
+ for img_idx in range(grid_thw.shape[0]):
+ img_shape = grid_thw[img_idx, :].unsqueeze(0)
+ curr_img_size = img_shape.prod()
+
+ pixel_values_curr_img = pixel_values[offset:offset + curr_img_size, :]
+ offset += curr_img_size
+
+            pixel_values_curr_img_padded, img_shape_padded = \
+                self.pad_multimodal_data(pixel_values_curr_img,
+                                         img_shape,
+                                         vision_buckets=vision_buckets)
+ pixel_values_curr_img_padded, rot_pos_emb, cu_seqlens, max_seqlen = \
+ self.pre_attn(pixel_values_curr_img_padded, img_shape_padded)
+
+            fullatt_block_attn_mask = \
+                self.create_block_diagonal_attention_mask_outerprod(cu_seqlens)
+            assert pixel_values_curr_img_padded.shape[0] == cu_seqlens[-1] \
+                == rot_pos_emb.shape[0]
+
+ htcore.mark_step()
+
+ hidden_states = self.forward(pixel_values_curr_img_padded,
+ rotary_pos_emb=rot_pos_emb,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ fullattn_mask=fullatt_block_attn_mask)
+ htcore.mark_step()
+
+ image_embeds = self.post_attn(hidden_states)
+
+ # slice image_embeds to remove the padded parts
+            pad_index = (img_shape_padded[0].prod() //
+                         (self.spatial_merge_size**2))
+ results += [image_embeds[:pad_index, :]]
+
+ results_cat = torch.concat(results)
+ return results_cat
class Glm4vProcessingInfo(BaseProcessingInfo):
@@ -950,7 +1192,7 @@ def _get_video_second_idx(self, metadata: dict[str, Any],
total_frames: int) -> list[int]:
video_processor = self.get_video_processor()
- video_fps = metadata.get("fps", 2.0)
+ video_fps = metadata.get("fps", video_processor.fps)
meta_frames = metadata.get("total_num_frames", total_frames)
max_frame_idx = meta_frames - 1
duration = metadata.get("duration",
@@ -1074,7 +1316,6 @@ def _call_hf_processor(
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
- tok_kwargs: Mapping[str, object],
) -> BatchFeature:
mm_data = dict(mm_data)
processor = self.info.get_hf_processor(**mm_kwargs)
@@ -1121,7 +1362,6 @@ def _call_hf_processor(
prompt="<|begin_of_video|><|video|><|end_of_video|>",
mm_data=video_mm_data,
mm_kwargs=mm_kwargs,
- tok_kwargs=tok_kwargs,
)
input_ids = video_outputs.pop("input_ids")
input_ids[input_ids == processor.image_token_id] = (
@@ -1133,11 +1373,8 @@ def _call_hf_processor(
video_placeholder,
)
- grid_t = len(video_outputs["video_grid_thw"])
- _, grid_h, grid_w = video_outputs["video_grid_thw"][0]
- grid_thw = torch.tensor([[grid_t, grid_h, grid_w]])
+ video_grid_thw_lst.append(video_outputs["video_grid_thw"])
- video_grid_thw_lst.append(grid_thw)
pixel_values_videos_lst.append(
video_outputs["pixel_values_videos"])
video_outputs = dict(
@@ -1151,7 +1388,6 @@ def _call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
- tok_kwargs=tok_kwargs,
)
combined_outputs = dict(
processed_outputs,
@@ -1213,7 +1449,10 @@ def get_video_replacement_glm4v(item_idx: int):
placeholder.append(eoi_token_id)
placeholder.extend(frame_idx)
placeholder.append(eov_token_id)
- return placeholder
+ return PromptUpdateDetails.select_token_id(
+ placeholder,
+ embed_token_id=hf_processor.video_token_id,
+ )
return [
PromptReplacement(
@@ -1242,10 +1481,7 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
"k_proj",
"v_proj",
],
- "gate_up_proj": [
- "gate_proj",
- "up_proj",
- ],
+ "gate_up_proj": ["gate_up_proj"]
}
# To ensure correct weight loading and mapping.
@@ -1265,10 +1501,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.config = config
self.multimodal_config = multimodal_config
- self.visual = Glm4vVisionTransformer(
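+        # On Gaudi, use the static-shape vision transformer variant so the
+        # visual encoder's forward pass can run under HPU Graphs.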
+        if is_hpu:
+            vision_transformer_cls = Glm4vVisionTransformerStaticShape
+        else:
+            vision_transformer_cls = Glm4vVisionTransformer
+
+        self.visual = vision_transformer_cls(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-5),
- quant_config=self._maybe_ignore_quant_config(quant_config),
+ quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
)
@@ -1288,13 +1529,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
- def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
- # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
- # seems to avoid vision encoder sections for some models.
- if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
- return None
- return quant_config
-
def _validate_and_reshape_mm_tensor(self, mm_input: object,
name: str) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)):
@@ -1395,7 +1629,15 @@ def _process_image_input(
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
else:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
- image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
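+            # On HPU, route through the bucket-padded static-shape vision path.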
+ if is_hpu:
+ assert isinstance(self.visual, Glm4vVisionTransformerStaticShape)
+ image_embeds = self.visual.get_image_embeds(
+ pixel_values,
+ grid_thw=grid_thw,
+ vision_buckets=self.vision_buckets,
+ )
+ else:
+ image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
merge_size = self.visual.spatial_merge_size
sizes = grid_thw.prod(-1) // merge_size // merge_size
@@ -1416,8 +1658,15 @@ def _process_video_input(
else:
pixel_values_videos = video_input["pixel_values_videos"].type(
self.visual.dtype)
- video_embeds = self.visual(pixel_values_videos,
- grid_thw=flat_grid_thw)
+ if is_hpu:
+ assert isinstance(self.visual, Glm4vVisionTransformerStaticShape)
+ video_embeds = self.visual.get_image_embeds(
+ pixel_values_videos,
+ grid_thw=grid_thw,
+ vision_buckets=self.vision_buckets,
+ )
+ else:
+                video_embeds = self.visual(pixel_values_videos,
+                                           grid_thw=flat_grid_thw)
# Split concatenated embeddings for each video item.
merge_size = self.visual.spatial_merge_size
@@ -1589,7 +1838,26 @@ def get_mm_mapping(self) -> MultiModelKeys:
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
- language_model="language_model",
+ language_model="language_model.model",
connector="visual.merger.",
tower_model="visual.",
)
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ Glm4vMultiModalProcessor,
+ info=Glm4vProcessingInfo,
+ dummy_inputs=Glm4vDummyInputsBuilder,
+)
+class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
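+    # GLM-4.5V reuses the GLM-4.1V implementation; only the packed-module
+    # mapping differs, so separate gate_proj/up_proj weights can be stacked
+    # into gate_up_proj at load time.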
+ packed_modules_mapping = {
+ "qkv_proj": [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ ],
+ "gate_up_proj": [
+ "gate_proj",
+ "up_proj",
+ ],
+ }
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 47b191007ee1..b9d32d3a1fdd 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -192,7 +192,7 @@
"Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501
- "Glm4v_moeForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501
+ "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"), # noqa: E501
"GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"), # noqa: E501
"H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
"InternVLChatModel": ("internvl", "InternVLChatModel"),
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 5b726d0028c4..9283307cb890 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -242,7 +242,8 @@ def _uses_mrope(config: PretrainedConfig) -> bool:
def uses_mrope(config: PretrainedConfig) -> bool:
"""Detect if the model with this config uses M-ROPE."""
- return _uses_mrope(config) or thinker_uses_mrope(config)
+ return _uses_mrope(config) or _uses_mrope(
+ config.get_text_config()) or thinker_uses_mrope(config)
def thinker_uses_mrope(config: PretrainedConfig) -> bool:
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 32b0c4d711b6..4cec4a02f00a 100644
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -412,7 +412,7 @@ def __init__(self, model, vllm_config, is_causal, sampler):
# This is to ensure that we keeps
# the static and dynamic parts distinct.
if htorch.utils.internal.is_lazy():
- if self.model_is_mrope and hasattr(self.model, 'visual'):
+            if False:  # self.model_is_mrope and hasattr(self.model, 'visual')
logger.info("[Multimodal] Wrapping Visual Model")
self.model.visual = htorch.hpu.wrap_in_hpu_graph(
self.model.visual, disable_tensor_cache=True)