12
12
from vllm .compilation .decorators import support_torch_compile
13
13
from vllm .config import CacheConfig , PoolerConfig , VllmConfig
14
14
from vllm .distributed import get_tensor_model_parallel_world_size
15
- from vllm .forward_context import get_forward_context
16
15
from vllm .model_executor .layers .activation import get_act_fn
17
16
from vllm .model_executor .layers .linear import (ColumnParallelLinear ,
18
17
QKVParallelLinear ,
@@ -60,7 +59,6 @@ def __init__(self, config: BertConfig):
60
59
def forward (
61
60
self ,
62
61
input_ids : torch .Tensor ,
63
- seq_lens : torch .Tensor ,
64
62
position_ids : torch .Tensor ,
65
63
token_type_ids : Optional [torch .Tensor ] = None ,
66
64
) -> torch .Tensor :
@@ -119,7 +117,6 @@ def forward(
119
117
return pooled_output
120
118
121
119
122
- @support_torch_compile
123
120
class BertEncoder (nn .Module ):
124
121
125
122
def __init__ (self , vllm_config : VllmConfig , prefix : str = "" ):
@@ -337,6 +334,7 @@ def forward(self, hidden_states: torch.Tensor,
337
334
return hidden_states
338
335
339
336
337
+ @support_torch_compile
340
338
class BertModel (nn .Module , SupportsQuant ):
341
339
342
340
is_pooling_model = True
@@ -368,13 +366,9 @@ def forward(
368
366
if inputs_embeds is not None :
369
367
hidden_states = inputs_embeds
370
368
else :
371
- attn_metadata = get_forward_context ().attn_metadata
372
- assert hasattr (attn_metadata , "seq_lens_tensor" )
373
- hidden_states = self .embeddings (
374
- input_ids = input_ids ,
375
- seq_lens = attn_metadata .seq_lens_tensor ,
376
- position_ids = position_ids ,
377
- token_type_ids = token_type_ids )
369
+ hidden_states = self .embeddings (input_ids = input_ids ,
370
+ position_ids = position_ids ,
371
+ token_type_ids = token_type_ids )
378
372
return self .encoder (hidden_states )
379
373
380
374
def _load_weights (self , weights : Iterable [tuple [str , torch .Tensor ]]):
@@ -447,7 +441,7 @@ def load_weights(self, weights: Iterable[tuple[str,
447
441
return loaded_params
448
442
449
443
450
- class BertEmbeddingModel (nn .Module , SupportsV0Only , SupportsQuant ):
444
+ class BertEmbeddingModel (nn .Module , SupportsQuant ):
451
445
"""A model that uses Bert to provide embedding functionalities.
452
446
453
447
This class encapsulates the BertModel and provides an interface for
@@ -474,11 +468,13 @@ def forward(
474
468
self ,
475
469
input_ids : Optional [torch .Tensor ],
476
470
positions : torch .Tensor ,
471
+ token_type_ids : Optional [torch .Tensor ] = None ,
477
472
intermediate_tensors : Optional [IntermediateTensors ] = None ,
478
473
inputs_embeds : Optional [torch .Tensor ] = None ,
479
474
) -> torch .Tensor :
480
475
return self .model (input_ids = input_ids ,
481
476
position_ids = positions ,
477
+ token_type_ids = token_type_ids ,
482
478
inputs_embeds = inputs_embeds ,
483
479
intermediate_tensors = intermediate_tensors )
484
480
0 commit comments