[Model] Add Ernie4.5 VL Model Support #22514
Merged · +2,463 −0
Commits (27, all authored by CSWYF3634076)

dee4372  [Model] Add Ernie4.5 VL
af1d864  [Model] Add Ernie4.5 VL v2 annotation organization
99773ff  [Model] Add Ernie4.5 VL v2 annotation organization
8442d5d  [Model] Add Ernie4.5 VL v3 fix variable name
42004bc  [Model] Add Ernie4.5 VL v4 fix code-assist issue
9a509cd  fix format by pre-commit
8d3d62b  [Model] Add Ernie4.5 VL v5 fix format by pre-commit
c227368  [Model] Add Ernie4.5 VL v5 fix format by pre-commit
080f818  [Model] Add Ernie4.5 VL v5 fix format by pre-commit
01f2231  [Model] Add Ernie4.5 VL v5 add trust_remote_code tag
d4ee345  [Model] Add Ernie4.5 VL v6 rename and fix comments
7ea25db  [Model] Add Ernie4.5 VL v7 vit qkv replace with QKVParallelinear
e124b87  [Model] Add Ernie4.5 VL v8 delete processor file
0fb8105  [Model] Add Ernie4.5 VL v9 pixel_values norm
35fe906  [Model] Add Ernie4.5 VL v9 delete _get_image_processor_kwargs
02754b7  Merge branch 'main' into vl
4943465  Merge branch 'main' into vl
9c6a49d  [Model] Add Ernie4.5 VL v10 adapt main
7e5ac16  [Model] Add Ernie4.5 VL v11 test file
0bedaa6  [Model] Add Ernie4.5 VL v12 pre-commit
98bd72f  Merge branch 'main' into vl
69a2902  Merge branch 'main' into vl
faad7fe  [Model] Add Ernie4.5 VL v13 no test_common
a4a1817  [Model] Add Ernie4.5 VL v14 add model_id to test_common
4c5abbb  [Model] Add Ernie4.5 VL v15 skip test_can_initialize due to processor…
3b70302  [Model] Add Ernie4.5 VL v16 add decord to test.in
a08137c  [Model] Add Ernie4.5 VL v17 update test.txt by pre-commit
vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py (72 additions, 0 deletions)
@@ -0,0 +1,72 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import torch

from .common import apply_rotary_emb_dispatch
from .mrope import MRotaryEmbedding


class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
    """3D rotary positional embedding. 3D is t:time h:height w:width"""

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        assert positions.ndim == 1 or positions.ndim == 2
        assert key is not None

        num_tokens = positions.shape[-1]
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        if positions.ndim == 2:
            assert self.mrope_section

            section_h = self.mrope_section[0]  # 22
            section_w = self.mrope_section[1]  # 22
            section_t = self.mrope_section[2]  # 20
            assert section_h == section_w
            # Split according to [h w h w h w h w... t t t...]
            section_cos_t = cos[..., -section_t:]
            section_cos_h = cos[..., :section_h + section_w:2]
            section_cos_w = cos[..., 1:section_h + section_w:2]

            cos_t, cos_h, cos_w = section_cos_t[0], section_cos_h[
                1], section_cos_w[2]
            cos_hw = torch.stack([cos_h, cos_w],
                                 dim=-1).reshape(cos_h.shape[:-1] +
                                                 (cos_h.shape[-1] * 2, ))
            cos = torch.cat([cos_hw, cos_t], dim=-1)

            section_sin_t = sin[..., -section_t:]
            section_sin_h = sin[..., :section_h + section_w:2]
            section_sin_w = sin[..., 1:section_h + section_w:2]

            sin_t, sin_h, sin_w = section_sin_t[0], section_sin_h[
                1], section_sin_w[2]
            sin_hw = torch.stack([sin_h, sin_w],
                                 dim=-1).reshape(sin_h.shape[:-1] +
                                                 (sin_h.shape[-1] * 2, ))
            sin = torch.cat([sin_hw, sin_t], dim=-1)

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., :self.rotary_dim]
        query_pass = query[..., self.rotary_dim:]
        query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin,
                                              self.is_neox_style)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., :self.rotary_dim]
        key_pass = key[..., self.rotary_dim:]
        key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin,
                                            self.is_neox_style)
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key
vllm/model_executor/layers/rotary_embedding/mrope.py
@@ -393,6 +393,15 @@ def get_input_positions_tensor(
                context_len=context_len,
                seq_len=seq_len,
            )
        elif hf_config.model_type in ["ernie4_5_moe_vl", "ernie4_5_vl"]:
            return cls._ernie_get_input_positions_tensor(
                input_tokens=input_tokens,
                hf_config=hf_config,
                image_grid_thw=image_grid_thw,
                video_grid_thw=video_grid_thw,
                context_len=context_len,
                seq_len=seq_len,
            )
        else:
            return cls._vl_get_input_positions_tensor(
                input_tokens=input_tokens,
@@ -513,6 +522,120 @@ def _glm4v_get_input_positions_tensor(
                                len(input_tokens)).item()
        return llm_positions, mrope_position_delta

Review comment on lines +525 to +526 (the _ernie_get_input_positions_tensor signature): I also have a feeling that a lot of operations in here can be optimized by vectorization, but we can revisit this later.
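Sketching one direction that suggestion could take (illustrative only, not part of this PR; the token ids are invented placeholders): the per-token classification loop in the method below can be expressed with cumulative sums over the video start/end markers.

import torch

# Hypothetical vectorized version of the image/video/text classification
# loop; token ids here are placeholders, not Ernie4.5 VL's real ids.
IMAGE_ID, VID_START, VID_END = 100, 101, 102
tokens = torch.tensor([1, 101, 100, 100, 102, 100, 3])

# A token is "inside a video" when more video-start than video-end markers
# precede it (same role as video_check_flg in the loop below).
inside_video = (torch.cumsum((tokens == VID_START).int(), 0) -
                torch.cumsum((tokens == VID_END).int(), 0)) > 0
is_patch = tokens == IMAGE_ID
is_video_patch = is_patch & inside_video   # positions 2, 3
is_image_patch = is_patch & ~inside_video  # position 5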
    @classmethod
    def _ernie_get_input_positions_tensor(
        cls,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Union[list[list[int]], torch.Tensor],
        video_grid_thw: Union[list[list[int]], torch.Tensor],
        context_len: int = 0,
        seq_len: Optional[int] = None,
    ) -> tuple[torch.Tensor, int]:
        """Get mrope input positions and delta value for Ernie VL."""

        image_token_id = hf_config.im_patch_id
        video_start_token_id = hf_config.video_start_token_id
        video_end_token_id = hf_config.video_end_token_id
        spatial_conv_size = hf_config.spatial_conv_size
        temporal_conv_size = hf_config.temporal_conv_size
        llm_pos_ids_list: list = []

        if not (image_grid_thw is None and video_grid_thw is None):
            if isinstance(image_grid_thw, torch.Tensor):
                image_grid_thw = image_grid_thw.tolist()

            input_token_type: list[str] = []
            video_check_flg = False
            for token in input_tokens:
                if token == video_start_token_id:
                    video_check_flg = True
                elif token == video_end_token_id:
                    video_check_flg = False

                if (token == image_token_id) and (video_check_flg is False):
                    input_token_type.append("image")
                elif (token == image_token_id) and (video_check_flg is True):
                    input_token_type.append("video")
                else:
                    input_token_type.append("text")

            input_type_group: list[tuple[str, int, int]] = []
            for key, group_iter in itertools.groupby(
                    enumerate(input_token_type), lambda x: x[1]):
                group_list = list(group_iter)
                start_index = group_list[0][0]
                end_index = group_list[-1][0] + 1
                input_type_group.append((key, start_index, end_index))

            video_frame_num = 1
            mm_data_idx = 0
            for modality_type, start_idx, end_idx in input_type_group:
                st_idx = llm_pos_ids_list[-1].max() + 1 if len(
                    llm_pos_ids_list) > 0 else 0
                if modality_type == "image":
                    t, h, w = (
                        image_grid_thw[mm_data_idx][0],
                        image_grid_thw[mm_data_idx][1],
                        image_grid_thw[mm_data_idx][2],
                    )
                    llm_grid_t, llm_grid_h, llm_grid_w = \
                        t, h // spatial_conv_size, w // spatial_conv_size

                    t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
                        -1, llm_grid_h * llm_grid_w).flatten()
                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
                        llm_grid_t, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
                        llm_grid_t, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(
                        torch.stack([t_index, h_index, w_index]) + st_idx)
                    mm_data_idx += 1

                elif modality_type == "video":
                    t, h, w = (
                        video_grid_thw[mm_data_idx][0],
                        video_grid_thw[mm_data_idx][1],
                        video_grid_thw[mm_data_idx][2],
                    )
                    llm_grid_t, llm_grid_h, llm_grid_w = (
                        t // temporal_conv_size, h // spatial_conv_size,
                        w // spatial_conv_size)

                    for t_idx in range(llm_grid_t):
                        t_index = torch.tensor(t_idx).view(-1, 1).expand(
                            -1, llm_grid_h * llm_grid_w).flatten()
                        h_index = torch.arange(llm_grid_h).view(
                            1, -1, 1).expand(1, -1, llm_grid_w).flatten()
                        w_index = torch.arange(llm_grid_w).view(
                            1, 1, -1).expand(1, llm_grid_h, -1).flatten()
                        llm_pos_ids_list.append(
                            torch.stack([t_index, h_index, w_index]) + st_idx)

                    mm_data_idx += 1
                    video_frame_num += 1

                else:
                    text_len = end_idx - start_idx
                    llm_pos_ids_list.append(
                        torch.arange(text_len).view(1, -1).expand(3, -1) +
                        st_idx)
                    video_frame_num = 1

        else:
            text_len = len(input_tokens)
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1))

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        llm_positions = llm_positions[:, context_len:seq_len]
        mrope_position_delta = (llm_positions.max() + 1 -
                                len(input_tokens)).item()
        return llm_positions, mrope_position_delta
    @classmethod
    def _vl_get_input_positions_tensor(
        cls,
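As a rough smoke test of the new helper (hypothetical values throughout: the token ids and conv sizes are invented, the SimpleNamespace merely mimics the attributes the method reads off the HF config, and the import path follows the file layout above):

from types import SimpleNamespace

from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding

cfg = SimpleNamespace(
    im_patch_id=100,           # placeholder image-patch token id
    video_start_token_id=101,  # placeholder video-start id
    video_end_token_id=102,    # placeholder video-end id
    spatial_conv_size=2,
    temporal_conv_size=2,
)

# "text text [4 image patches] text": a 1x4x4 grid collapses to 1x2x2 patches.
tokens = [1, 2] + [100] * 4 + [3]
positions, delta = MRotaryEmbedding._ernie_get_input_positions_tensor(
    input_tokens=tokens,
    hf_config=cfg,
    image_grid_thw=[[1, 4, 4]],
    video_grid_thw=[],
    context_len=0,
    seq_len=len(tokens),
)
print(positions.shape)  # torch.Size([3, 7]); rows are t/h/w position ids
print(delta)            # -2: max position + 1 minus sequence length

The image block gets 3D coordinates offset past the preceding text, and the trailing text resumes from the block's maximum position, which is why mrope_position_delta comes out negative here.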