|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +# DeepSpeed Team |
| 5 | + |
| 6 | +from typing import Iterable, Optional, Tuple |
| 7 | + |
| 8 | +import torch |
| 9 | + |
| 10 | +import deepspeed.comm as dist |
| 11 | + |
| 12 | +from ...allocator import empty_from |
| 13 | +from ...inference_utils import ActivationType, DtypeEnum |
| 14 | +from .. import * |
| 15 | +from ...modules.configs import * |
| 16 | +from ...modules.interfaces import * |
| 17 | +from ...ragged import RaggedBatchWrapper |
| 18 | + |
| 19 | +from .container import ExaoneNonTransformerContainer, ExaoneTransformerContainer |
| 20 | + |
| 21 | + |
| 22 | +class ExaoneInferenceModel(DSTransformerModelBase): |
| 23 | + """ |
| 24 | + Inference model implementation for ragged batching for EXAONE 4.0 models. |
| 25 | +
|
| 26 | + Key features: |
| 27 | + - Hybrid attention: sliding_attention (local) vs full_attention (global) layers |
| 28 | + - QK-Reorder-Norm: RMSNorm applied after Q/K projections |
| 29 | + - Conditional RoPE: Skip RoPE for full_attention layers |
| 30 | + - Grouped Query Attention: 40 query heads, 8 key-value heads |
| 31 | + """ |
| 32 | + |
| 33 | + _non_transformer: Optional[ExaoneNonTransformerContainer] |
| 34 | + """ |
| 35 | + Embed + unembed container. Specializing the type annotation. |
| 36 | + """ |
| 37 | + |
| 38 | + _transformer: Optional[Iterable[ExaoneTransformerContainer]] |
| 39 | + """ |
| 40 | + Per-layer transformer container. Specializing the type annotation. |
| 41 | + """ |
| 42 | + |
| 43 | + # EXAONE 4.0 specific attributes |
| 44 | + _layer_types: Optional[list] = None |
| 45 | + """ |
| 46 | + Layer types for hybrid attention: 'sliding_attention' or 'full_attention' |
| 47 | + """ |
| 48 | + |
| 49 | + def __init__(self, config, engine_config, base_mp_group): |
| 50 | + super().__init__(config, engine_config, base_mp_group) |
| 51 | + |
| 52 | + # Store layer types for hybrid attention handling |
| 53 | + if hasattr(self._config, 'layer_types'): |
| 54 | + self._layer_types = self._config.layer_types |
| 55 | + else: |
| 56 | + # Fallback: infer from sliding_window_pattern (LLLG = 3 local, 1 global) |
| 57 | + pattern = getattr(self._config, 'sliding_window_pattern', 'LLLG') |
| 58 | + layer_types = [] |
| 59 | + for i in range(self.num_layers): |
| 60 | + if pattern[i % len(pattern)] == 'G': |
| 61 | + layer_types.append('full_attention') |
| 62 | + else: |
| 63 | + layer_types.append('sliding_attention') |
| 64 | + self._layer_types = layer_types |
| 65 | + |
| 66 | + """ |
| 67 | + Properties inherited from `DSInferenceModelBase` |
| 68 | + """ |
| 69 | + |
| 70 | + @property |
| 71 | + def max_sequence_length(self) -> int: |
| 72 | + return self._config.max_position_embeddings |
| 73 | + |
| 74 | + """ |
| 75 | + Properties inherited from `DSTransformerModelBase` |
| 76 | + """ |
| 77 | + |
| 78 | + @property |
| 79 | + def num_layers(self) -> int: |
| 80 | + return self._config.num_hidden_layers |
| 81 | + |
| 82 | + @property |
| 83 | + def model_dim(self) -> int: |
| 84 | + return self._config.hidden_size |
| 85 | + |
| 86 | + @property |
| 87 | + def vocab_size(self) -> int: |
| 88 | + return self._config.vocab_size |
| 89 | + |
| 90 | + @property |
| 91 | + def head_size(self) -> int: |
| 92 | + return getattr(self._config, 'head_dim', self.model_dim // self.n_heads) |
| 93 | + |
| 94 | + @property |
| 95 | + def n_heads(self) -> int: |
| 96 | + return self._config.num_attention_heads |
| 97 | + |
| 98 | + @property |
| 99 | + def intermediate_dim(self) -> int: |
| 100 | + return self._config.intermediate_size |
| 101 | + |
| 102 | + @property |
| 103 | + def n_heads_kv(self) -> int: |
| 104 | + return self._config.num_key_value_heads |
| 105 | + |
| 106 | + @property |
| 107 | + def activation_dtype(self) -> DtypeEnum: |
| 108 | + if self._config.torch_dtype == torch.float16: |
| 109 | + return DtypeEnum.fp16 |
| 110 | + elif self._config.torch_dtype == torch.bfloat16: |
| 111 | + return DtypeEnum.bf16 |
| 112 | + else: |
| 113 | + raise NotImplementedError("Only fp16 and bf16 are supported") |
| 114 | + |
| 115 | + @property |
| 116 | + def mlp_activation_fn(self) -> ActivationType: |
| 117 | + activation = self._config.hidden_act.lower() |
| 118 | + # EXAONE 4.0 uses gated SiLU activation like LLaMA |
| 119 | + if activation == "silu": |
| 120 | + return ActivationType.SiGLU |
| 121 | + elif activation == "gelu": |
| 122 | + return ActivationType.GEGLU |
| 123 | + elif activation == "relu": |
| 124 | + return ActivationType.ReGLU |
| 125 | + else: |
| 126 | + raise NotImplementedError(f"Activation {activation} not supported") |
| 127 | + |
| 128 | + @property |
| 129 | + def norm_type(self) -> NormTypeEnum: |
| 130 | + return NormTypeEnum.RMSNorm |
| 131 | + |
| 132 | + @property |
| 133 | + def positional_embedding_type(self) -> PositionalEmbeddingType: |
| 134 | + return PositionalEmbeddingType.rotate_half |
| 135 | + |
| 136 | + @property |
| 137 | + def positional_embedding_config(self) -> Optional[RotateHalfConfig]: |
| 138 | + return RotateHalfConfig(theta_base=self._config.rope_theta) |
| 139 | + |
| 140 | + """ |
| 141 | + Helper methods for EXAONE 4.0 specific features |
| 142 | + """ |
| 143 | + |
| 144 | + def is_global_attention_layer(self, layer_idx: int) -> bool: |
| 145 | + """Check if layer uses global (full) attention vs local (sliding) attention""" |
| 146 | + if self._layer_types and layer_idx < len(self._layer_types): |
| 147 | + return self._layer_types[layer_idx] == 'full_attention' |
| 148 | + return False |
| 149 | + |
| 150 | + def should_apply_rope(self, layer_idx: int) -> bool: |
| 151 | + """EXAONE 4.0 skips RoPE for global attention layers""" |
| 152 | + return not self.is_global_attention_layer(layer_idx) |
| 153 | + |
| 154 | + """ |
| 155 | + Forward implementations |
| 156 | + """ |
| 157 | + |
| 158 | + def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor: |
| 159 | + """ |
| 160 | + Performs the embedding lookup prior to running the transformer of the model. |
| 161 | +
|
| 162 | + Arguments: |
| 163 | + ragged_batch (RaggedBatchWrapper): The batch to embed. |
| 164 | +
|
| 165 | + Returns: |
| 166 | + torch.Tensor: The embedded batch. |
| 167 | + """ |
| 168 | + embed = self.embed(ragged_batch, self._non_transformer.word_emb) |
| 169 | + |
| 170 | + if embed.shape[-1] != self.model_dim: |
| 171 | + raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}") |
| 172 | + |
| 173 | + return embed |
| 174 | + |
| 175 | + def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor, |
| 176 | + ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]: |
| 177 | + """ |
| 178 | + Executes one transformer layer with EXAONE 4.0 specific features: |
| 179 | + - Hybrid attention (sliding vs full) |
| 180 | + - QK-Reorder-Norm (RMSNorm after Q/K projections) |
| 181 | + - Conditional RoPE (skip for global layers) |
| 182 | +
|
| 183 | + Arguments: |
| 184 | + layer_idx (int): The index of the layer to execute. |
| 185 | + residual (torch.Tensor): The residual tensor from the previous layer. |
| 186 | + hidden_states (torch.Tensor): The hidden states from the previous layer. |
| 187 | + ragged_batch_info (RaggedBatchWrapper): The batch metadata. |
| 188 | + """ |
| 189 | + cur_params = self._transformer[layer_idx] |
| 190 | + kv_cache = self.state_manager.get_cache(layer_idx) |
| 191 | + |
| 192 | + # QKV projection |
| 193 | + hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=None) |
| 194 | + |
| 195 | + # EXAONE 4.0 attention with hybrid pattern and conditional RoPE |
| 196 | + # NOTE: The attention module should handle QK-Reorder-Norm internally |
| 197 | + # and respect the RoPE configuration based on layer type |
| 198 | + if self.is_global_attention_layer(layer_idx): |
| 199 | + # Global attention: full attention, no RoPE |
| 200 | + hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info, apply_rotary_pos_emb=False) |
| 201 | + else: |
| 202 | + # Local attention: sliding window, with RoPE |
| 203 | + hidden_states = self.attn(hidden_states, |
| 204 | + kv_cache, |
| 205 | + ragged_batch_info, |
| 206 | + apply_rotary_pos_emb=True, |
| 207 | + sliding_window=getattr(self._config, 'sliding_window', 4096)) |
| 208 | + |
| 209 | + # Attention output projection |
| 210 | + hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=None) |
| 211 | + |
| 212 | + if self.tp_size > 1: |
| 213 | + dist.all_reduce(hidden_states, group=self._base_mp_group) |
| 214 | + |
| 215 | + # Post-attention normalization |
| 216 | + residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma, beta=None) |
| 217 | + |
| 218 | + # MLP forward pass (gated SiLU) |
| 219 | + hidden_states = self.mlp_1(hidden_states, cur_params.mlp_1_w, b=None) |
| 220 | + hidden_states = self.mlp_2(hidden_states, cur_params.mlp_2_w, b=None) |
| 221 | + |
| 222 | + if self.tp_size > 1: |
| 223 | + dist.all_reduce(hidden_states, group=self._base_mp_group) |
| 224 | + |
| 225 | + # Prepare for next layer normalization |
| 226 | + if layer_idx != self.num_layers - 1: |
| 227 | + next_params = self._transformer[layer_idx + 1] |
| 228 | + residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma, beta=None) |
| 229 | + else: |
| 230 | + # On last layer, just perform the residual add |
| 231 | + residual.add_(hidden_states) |
| 232 | + |
| 233 | + return residual, hidden_states |
| 234 | + |
| 235 | + def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor: |
| 236 | + """ |
| 237 | + Performs unembedding of the hidden states to logits. This will only sample the final |
| 238 | + token of each sequence. |
| 239 | + """ |
| 240 | + logits = self.unembed(hidden_states, |
| 241 | + self._non_transformer.word_unembed, |
| 242 | + ragged_batch_info, |
| 243 | + gamma=self._non_transformer.final_norm) |
| 244 | + |
| 245 | + if self.tp_size > 1: |
| 246 | + comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1])) |
| 247 | + full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size)) |
| 248 | + |
| 249 | + dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group) |
| 250 | + |
| 251 | + full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size)) |
| 252 | + |
| 253 | + return full_logits |
| 254 | + else: |
| 255 | + return logits |
| 256 | + |
| 257 | + def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor: |
| 258 | + """ |
| 259 | + Forward pass for EXAONE 4.0 model with hybrid attention support. |
| 260 | + """ |
| 261 | + residual = self._forward_embed(wrapped_batch) |
| 262 | + |
| 263 | + # Initial normalization |
| 264 | + residual, hidden_states = self.norm(residual, None, self._transformer[0].attn_norm_gamma, beta=None) |
| 265 | + |
| 266 | + # Forward through all transformer layers |
| 267 | + for layer_idx in range(self.num_layers): |
| 268 | + residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states, |
| 269 | + wrapped_batch) |
| 270 | + |
| 271 | + return self._forward_unembed(residual, wrapped_batch) |
0 commit comments