Skip to content

Commit d891fa5

Browse files
moyu026liudengjin
andauthored
add qwen2_vl bf16 (#2046)
Co-authored-by: liudengjin <[email protected]>
1 parent c7d1e4d commit d891fa5

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

mindnlp/transformers/models/qwen2_vl/modeling_qwen2_vl.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1230,8 +1230,10 @@ def forward(
12301230
pixel_values = pixel_values.type(self.visual.get_dtype())
12311231
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
12321232
image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
1233+
inputs_embeds = inputs_embeds.astype(mindspore.float16)
12331234
image_embeds = image_embeds.to(inputs_embeds.dtype)
12341235
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
1236+
inputs_embeds = inputs_embeds.astype(mindspore.bfloat16)
12351237

12361238
if pixel_values_videos is not None:
12371239
pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())

0 commit comments

Comments
 (0)