Skip to content

Commit da543d1

Browse files
Authored commit:
[Model Runner V2] Minor refactoring for EncoderRunner (vllm-project#35628)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
1 parent 87d319c commit da543d1

File tree

2 files changed

+13
-13
lines changed

2 files changed

+13
-13
lines changed

vllm/v1/worker/gpu/mm/encoder_runner.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
class EncoderRunner:
1414
def __init__(
1515
self,
16+
model: SupportsMultiModal,
1617
max_num_tokens: int,
1718
hidden_size: int,
1819
encoder_cache: EncoderCache,
1920
dtype: torch.dtype,
2021
device: torch.device,
2122
):
23+
self.model = model
2224
self.max_num_tokens = max_num_tokens
2325
self.hidden_size = hidden_size
2426
self.encoder_cache = encoder_cache
@@ -48,25 +50,17 @@ def prepare_mm_inputs(
4850
@torch.inference_mode()
4951
def execute_mm_encoder(
5052
self,
51-
model: SupportsMultiModal,
52-
mm_hashes: list[str],
5353
mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
5454
) -> list[torch.Tensor]:
55-
if not mm_hashes:
56-
return []
57-
5855
encoder_outputs: list[torch.Tensor] = []
5956
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
6057
mm_kwargs, device=self.device, pin_memory=False
6158
):
62-
curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
59+
curr_group_outputs = self.model.embed_multimodal(**mm_kwargs_group)
6360
sanity_check_mm_encoder_outputs(
6461
curr_group_outputs, expected_num_items=num_items
6562
)
6663
encoder_outputs.extend(curr_group_outputs)
67-
68-
# Cache the encoder outputs by mm_hash
69-
self.encoder_cache.encoder_outputs.update(zip(mm_hashes, encoder_outputs))
7064
return encoder_outputs
7165

7266
def gather_mm_embeddings(
@@ -146,12 +140,11 @@ def gather_mm_embeddings(
146140
@torch.inference_mode()
147141
def get_inputs_embeds(
148142
self,
149-
model: SupportsMultiModal,
150143
input_ids: torch.Tensor,
151144
mm_embeds: list[torch.Tensor],
152145
is_mm_embed: torch.Tensor,
153146
) -> torch.Tensor:
154-
x = model.embed_input_ids(
147+
x = self.model.embed_input_ids(
155148
input_ids, multimodal_embeddings=mm_embeds, is_multimodal=is_mm_embed
156149
)
157150
# Copy to the pre-allocated buffer for CUDA graphs.

vllm/v1/worker/gpu/model_states/default.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ def __init__(
4141

4242
if self.supports_mm_inputs:
4343
assert encoder_cache is not None
44+
self.encoder_cache = encoder_cache
4445
self.encoder_runner = EncoderRunner(
46+
model=self.model,
4547
max_num_tokens=self.max_num_tokens,
4648
hidden_size=self.inputs_embeds_size,
4749
encoder_cache=encoder_cache,
@@ -82,7 +84,12 @@ def get_mm_embeddings(
8284
mm_hashes, mm_kwargs = self.encoder_runner.prepare_mm_inputs(
8385
scheduled_encoder_inputs
8486
)
85-
self.encoder_runner.execute_mm_encoder(self.model, mm_hashes, mm_kwargs)
87+
if mm_kwargs:
88+
# Execute the multimodal encoder.
89+
encoder_outputs = self.encoder_runner.execute_mm_encoder(mm_kwargs)
90+
# Cache the encoder outputs by mm_hash
91+
self.encoder_cache.encoder_outputs.update(zip(mm_hashes, encoder_outputs))
92+
8693
mm_embeds, is_mm_embed = self.encoder_runner.gather_mm_embeddings(
8794
input_batch.req_ids,
8895
input_batch.num_tokens,
@@ -92,7 +99,7 @@ def get_mm_embeddings(
9299
req_states.num_computed_prefill_tokens[input_batch.idx_mapping_np],
93100
)
94101
inputs_embeds = self.encoder_runner.get_inputs_embeds(
95-
self.model, input_batch.input_ids, mm_embeds, is_mm_embed
102+
input_batch.input_ids, mm_embeds, is_mm_embed
96103
)
97104
return inputs_embeds[: input_batch.num_tokens_after_padding]
98105

0 commit comments

Comments (0)