|
15 | 15 | from vllm.model_executor.pooling_metadata import PoolingTensors
|
16 | 16 | from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
|
17 | 17 | from vllm.transformers_utils.config import (
|
| 18 | + get_classification_activation_function, |
18 | 19 | get_cross_encoder_activation_function)
|
19 | 20 | from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
|
20 | 21 |
|
@@ -388,15 +389,14 @@ def __init__(
|
388 | 389 | self.classifier = classifier
|
389 | 390 | self.pooler = pooler
|
390 | 391 |
|
391 |
| - if config.task == "score": |
392 |
| - self.default_activation_function = \ |
393 |
| - get_cross_encoder_activation_function(config.hf_config) |
394 |
| - elif config.task == "classify": |
395 |
| - self.default_activation_function = nn.Sigmoid() \ |
396 |
| - if config.hf_config.num_labels == 1 else nn.Softmax() |
397 |
| - else: |
398 |
| - raise NotImplementedError(f"task={config.task!r} is not supported" |
399 |
| - " with the classification pooler") |
| 392 | + self.classification_act_fn = get_classification_activation_function( |
| 393 | + config.hf_config) |
| 394 | + self.cross_encoder_act_fn = get_cross_encoder_activation_function( |
| 395 | + config.hf_config) |
| 396 | + |
| 397 | + def _get_act_fn(self, use_cross_encoder: bool): |
| 398 | + return (self.cross_encoder_act_fn |
| 399 | + if use_cross_encoder else self.classification_act_fn) |
400 | 400 |
|
401 | 401 | def get_prompt_lens(
|
402 | 402 | self,
|
@@ -446,8 +446,28 @@ def forward(
|
446 | 446 | # apply classifier once on the full batch if possible
|
447 | 447 | pooled_output = self.classifier(pooled_output)
|
448 | 448 |
|
449 |
| - # shape: (batch_size, num_labels) |
450 |
| - scores = self.default_activation_function(pooled_output) |
| 449 | + if isinstance(pooling_metadata, V0PoolingMetadata): |
| 450 | + use_cross_encoder_list = [ |
| 451 | + pooling_param.use_cross_encoder |
| 452 | + for _, pooling_param in pooling_metadata.seq_groups |
| 453 | + ] |
| 454 | + else: |
| 455 | + use_cross_encoder_list = [ |
| 456 | + pooling_param.use_cross_encoder |
| 457 | + for pooling_param in pooling_metadata.pooling_params |
| 458 | + ] |
| 459 | + |
| 460 | + # shape of scores: (batch_size, num_labels) |
| 461 | + if all(use_cross_encoder == use_cross_encoder_list[0] |
| 462 | + for use_cross_encoder in use_cross_encoder_list): |
| 463 | + act_fn = self._get_act_fn(use_cross_encoder_list[0]) |
| 464 | + scores = act_fn(pooled_output) |
| 465 | + else: |
| 466 | + scores = torch.stack([ |
| 467 | + self._get_act_fn(use_cross_encoder)(vecs) |
| 468 | + for use_cross_encoder, vecs in zip(use_cross_encoder_list, |
| 469 | + pooled_output) |
| 470 | + ]) |
451 | 471 |
|
452 | 472 | pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores]
|
453 | 473 | return PoolerOutput(outputs=pooled_outputs)
|
0 commit comments