
Commit 1cf3753

[MODEL] Apertus and XIELU (#23068)
Signed-off-by: EduardDurech <[email protected]>
Co-authored-by: AllenHaoHuang <[email protected]>
1 parent: 4f7cde7

6 files changed: 696 additions, 1 deletion

tests/models/language/generation/test_common.py

Lines changed: 2 additions & 1 deletion
@@ -92,7 +92,8 @@
     pytest.param(
         "allenai/OLMoE-1B-7B-0924-Instruct",
         marks=[pytest.mark.cpu_model],
-    )
+    ),
+    pytest.param("swiss-ai/Apertus-8B"),  # apertus
 ])
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("num_logprobs", [5])
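The new entry rides on the existing model parametrization, so the Apertus case can be exercised in isolation by keyword selection. A minimal invocation sketch (assuming pytest is installed and this runs from the repository root):

    # Select only the new case; "-k Apertus" matches the
    # swiss-ai/Apertus-8B parametrization added above.
    import pytest

    pytest.main([
        "tests/models/language/generation/test_common.py",
        "-k", "Apertus",
    ])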

tests/models/registry.py

Lines changed: 3 additions & 0 deletions
@@ -137,6 +137,9 @@ def check_available_online(
 # yapf: disable
 _TEXT_GENERATION_EXAMPLE_MODELS = {
     # [Decoder-only]
+    "ApertusForCausalLM": _HfExamplesInfo("swiss-ai/Apertus-8B",
+                                          min_transformers_version="4.56.0",
+                                          trust_remote_code=True),
     "AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B",
                                    trust_remote_code=True),
     "AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B",

tests/models/test_registry.py

Lines changed: 3 additions & 0 deletions
@@ -24,6 +24,9 @@
 
 @pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs())
 def test_registry_imports(model_arch):
+    # Skip if transformers version is incompatible
+    model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
+    model_info.check_transformers_version(on_fail="skip")
     # Ensure all model classes can be imported successfully
     model_cls = ModelRegistry._try_load_model_cls(model_arch)
     assert model_cls is not None
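The on_fail="skip" path is what lets architectures gated on min_transformers_version (such as ApertusForCausalLM, pinned to 4.56.0 above) stay registered without failing CI on older transformers. A rough sketch of what such a guard reduces to (the real logic lives in _HfExamplesInfo.check_transformers_version; this helper is hypothetical):

    # Hypothetical version guard; mirrors the intent of
    # check_transformers_version(on_fail="skip"), not its exact code.
    import pytest
    import transformers
    from packaging.version import Version

    def skip_if_transformers_below(min_version: str) -> None:
        if Version(transformers.__version__) < Version(min_version):
            pytest.skip(f"requires transformers>={min_version}")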

vllm/model_executor/layers/activation.py

Lines changed: 111 additions & 0 deletions
@@ -10,11 +10,14 @@
 
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
+from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.utils import LazyDict
 
+logger = init_logger(__name__)
+
 
 @CustomOp.register("fatrelu_and_mul")
 class FatreluAndMul(CustomOp):
@@ -363,6 +366,112 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
         return self.forward_native(x)
 
 
+@CustomOp.register("xielu")
+class XIELU(CustomOp):
+    """
+    Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010
+    If the user has installed nickjbrowning/XIELU, we import the xIELU CUDA kernel;
+    otherwise, we emit a single warning and fall back to the Python implementation
+    """
+
+    def __init__(
+        self,
+        alpha_p_init: float = 0.8,
+        alpha_n_init: float = 0.8,
+        beta: float = 0.5,
+        eps: float = -1e-6,
+        dtype: torch.dtype = torch.bfloat16,
+        with_vector_loads: bool = False,
+    ):
+        super().__init__()
+        self.alpha_p = nn.Parameter(
+            torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) -
+                      1).unsqueeze(0))
+        self.alpha_n = nn.Parameter(
+            torch.log(
+                torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) -
+                1).unsqueeze(0))
+        self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
+        self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
+        self.with_vector_loads = with_vector_loads
+        # Temporary until xIELU CUDA fully implemented
+        self._beta_scalar = float(self.beta.detach().cpu().float().item())
+        self._eps_scalar = float(self.eps.detach().cpu().float().item())
+
+        self._xielu_cuda_obj = None
+        try:
+            import xielu.ops  # noqa: F401
+
+            self._xielu_cuda_obj = torch.classes.xielu.XIELU()
+            msg = "Using experimental xIELU CUDA."
+            try:
+                from torch._dynamo import allow_in_graph
+
+                self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
+                msg += " Enabled torch._dynamo for xIELU CUDA."
+            except Exception as err:
+                msg += (f" Could not enable torch._dynamo for xIELU ({err}) - "
+                        "this may result in slower performance.")
+                self._xielu_cuda_fn = self._xielu_cuda
+            logger.warning_once(msg)
+        except Exception as err:
+            logger.warning_once(
+                "CUDA-fused xIELU not available (%s) –"
+                " falling back to a Python version.\n"
+                "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
+                str(err),
+            )
+
+    def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
+        alpha_p = nn.functional.softplus(self.alpha_p)
+        alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
+        return torch.where(
+            x > 0,
+            alpha_p * x * x + self.beta * x,
+            (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n +
+            self.beta * x,
+        )
+
+    def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        """Firewall function to prevent torch.compile from seeing .item()"""
+        assert self._xielu_cuda_obj is not None, (
+            "XIELU CUDA object must not be None")
+        original_shape = x.shape
+        # CUDA kernel expects 3D tensors, reshape if needed
+        while x.dim() < 3:
+            x = x.unsqueeze(0)
+        if x.dim() > 3:
+            x = x.view(-1, 1, x.size(-1))
+        if original_shape != x.shape:
+            logger.warning_once(
+                "Warning: xIELU input tensor expects 3 dimensions"
+                " but got (shape: %s). Reshaping to (shape: %s).",
+                original_shape,
+                x.shape,
+            )
+        result = self._xielu_cuda_obj.forward(
+            x,
+            self.alpha_p,
+            self.alpha_n,
+            # Temporary until xIELU CUDA fully implemented ->
+            # self.{beta,eps}.item()
+            self._beta_scalar,
+            self._eps_scalar,
+            self.with_vector_loads,
+        )
+        return result.view(original_shape)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self._xielu_cuda_obj is not None and input.is_cuda:
+            if not torch._dynamo.is_compiling():
+                return self._xielu_cuda_fn(input)
+            else:
+                logger.warning_once(
+                    "torch._dynamo is compiling, using Python version of xIELU."
+                )
+        return self._xielu_python(input)
+
+
 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.
 
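For reference, the piecewise function computed by _xielu_python above can be written out explicitly. With the softplus reparameterization applied (alpha_p = softplus(p_p), alpha_n = beta + softplus(p_n)), the code evaluates:

\[
\mathrm{xIELU}(x) =
\begin{cases}
\alpha_p x^{2} + \beta x, & x > 0,\\
\alpha_n \left( e^{\min(x,\,\varepsilon)} - 1 - x \right) + \beta x, & x \le 0,
\end{cases}
\]

with defaults alpha_p = alpha_n = 0.8, beta = 0.5, and eps = -1e-6; the negative eps clamps the exponent's argument strictly below zero. This restates the implementation; the paper's own notation may differ.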
@@ -426,6 +535,8 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
     lambda: nn.Tanh(),
     "sigmoid":
     lambda: nn.Sigmoid(),
+    "xielu":
+    lambda: XIELU(),
 })
 
 
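As a quick sanity check of the fallback path, here is a self-contained sketch that restates _xielu_python with the softplus reparameterization folded in at the default initialization (by construction the effective alpha_p equals alpha_p_init and the effective alpha_n equals alpha_n_init); it deliberately avoids constructing the XIELU CustomOp, which normally lives inside an engine:

    import torch

    def xielu_reference(x: torch.Tensor,
                        alpha_p: float = 0.8,
                        alpha_n: float = 0.8,
                        beta: float = 0.5,
                        eps: float = -1e-6) -> torch.Tensor:
        # Elementwise xIELU: quadratic-plus-linear for x > 0, a clamped
        # expm1-based branch for x <= 0 (matches _xielu_python above).
        e = torch.tensor(eps, dtype=x.dtype)
        return torch.where(
            x > 0,
            alpha_p * x * x + beta * x,
            alpha_n * (torch.expm1(torch.min(x, e)) - x) + beta * x,
        )

    x = torch.linspace(-3.0, 3.0, steps=7)
    print(xielu_reference(x))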