Skip to content

Commit 0e36abf

Browse files
authored
[Bugfix] Correct max tokens for non-contiguous embeds (#21798)
Signed-off-by: Alexandre Milesi <[email protected]> Co-authored-by: Alexandre Milesi <[email protected]>
1 parent 452b2a3 commit 0e36abf

File tree

2 files changed

+29
-4
lines changed

2 files changed

+29
-4
lines changed

vllm/multimodal/profiling.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,11 +180,14 @@ def _get_dummy_mm_inputs(
180180
def _get_mm_num_tokens(
181181
self,
182182
mm_inputs: MultiModalInputs,
183+
mm_embeddings_only: bool = True,
183184
) -> Mapping[str, int]:
184185
placeholders_by_modality = mm_inputs["mm_placeholders"]
185186

186187
return {
187-
modality: sum(item.get_num_embeds() for item in placeholders)
188+
modality:
189+
sum(item.get_num_embeds() if mm_embeddings_only else item.length
190+
for item in placeholders)
188191
for modality, placeholders in placeholders_by_modality.items()
189192
}
190193

@@ -253,10 +256,11 @@ def get_decoder_dummy_data(
253256
multi_modal_placeholders=mm_inputs["mm_placeholders"],
254257
)
255258

256-
def get_mm_max_tokens(
259+
def _get_mm_max_tokens(
257260
self,
258261
seq_len: int,
259262
mm_counts: Optional[Mapping[str, int]] = None,
263+
mm_embeddings_only: bool = True,
260264
) -> Mapping[str, int]:
261265
if mm_counts is None:
262266
mm_counts = self.get_mm_limits()
@@ -285,4 +289,25 @@ def get_mm_max_tokens(
285289
return max_tokens_per_item
286290

287291
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
288-
return self._get_mm_num_tokens(mm_inputs)
292+
return self._get_mm_num_tokens(mm_inputs,
293+
mm_embeddings_only=mm_embeddings_only)
294+
295+
def get_mm_max_contiguous_tokens(
296+
self,
297+
seq_len: int,
298+
mm_counts: Optional[Mapping[str, int]] = None,
299+
):
300+
"""
301+
Returns the maximum length of the multimodal (image placeholders+text)
302+
tokens, including any break/text tokens in-between image embeddings.
303+
304+
<im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>
305+
Returns 9, even when the number of image embeddings is 6.
306+
307+
This is important to take into account when profiling and
308+
initializing the encoder cache size.
309+
"""
310+
311+
return self._get_mm_max_tokens(seq_len,
312+
mm_counts,
313+
mm_embeddings_only=False)

vllm/multimodal/registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def get_max_tokens_per_item_by_modality(
129129
seq_len = model_config.max_model_len
130130
mm_limits = self.get_mm_limits_per_prompt(model_config)
131131

132-
return profiler.get_mm_max_tokens(
132+
return profiler.get_mm_max_contiguous_tokens(
133133
seq_len,
134134
{
135135
modality: 1

0 commit comments

Comments (0)