
Commit ef075c5

Fix EXAONE 4.0 policy container mapping issue
1 parent a54c394 commit ef075c5

File tree: 10 files changed, +1251 -0 lines changed

deepspeed/inference/v2/engine_factory.py

Lines changed: 7 additions & 0 deletions

@@ -24,6 +24,7 @@
     QwenPolicy,
     Qwen2Policy,
     Qwen2MoePolicy,
+    ExaonePolicy,
 )
 from .model_implementations.inference_policy_base import POLICIES, InferenceV2Policy
 from .model_implementations.flat_model_helpers import make_metadata_filename, ModelMetadata

@@ -129,6 +130,12 @@ def build_hf_engine(path: str,
         policy = Qwen2Policy(model_config, checkpoint_engine=checkpoint_engine)
     elif model_config.model_type == "qwen2_moe":
         policy = Qwen2MoePolicy(model_config, checkpoint_engine=checkpoint_engine)
+    elif model_config.model_type == "exaone4":
+        # Ensure we're using the correct version of transformers for EXAONE 4.0
+        import transformers
+        assert version.parse(transformers.__version__) >= version.parse("4.54.0"), \
+            f"EXAONE 4.0 requires transformers >= 4.54.0, you have version {transformers.__version__}"
+        policy = ExaonePolicy(model_config, checkpoint_engine=checkpoint_engine)
     else:
         raise ValueError(f"Unsupported model type {model_config.model_type}")
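
A minimal usage sketch (not part of this commit) of the new "exaone4" branch above. It assumes RaggedInferenceEngineConfig is exposed from deepspeed.inference.v2 as in this repository; the checkpoint path and tp_size are illustrative.

import transformers
from packaging import version

from deepspeed.inference.v2 import RaggedInferenceEngineConfig
from deepspeed.inference.v2.engine_factory import build_hf_engine

# Mirror the version guard added above before paying the cost of loading the checkpoint.
assert version.parse(transformers.__version__) >= version.parse("4.54.0"), \
    "EXAONE 4.0 requires transformers >= 4.54.0"

# Illustrative path and parallelism settings; build_hf_engine now routes
# model_type == "exaone4" to ExaonePolicy instead of raising ValueError.
engine_config = RaggedInferenceEngineConfig(tensor_parallel={"tp_size": 1})
engine = build_hf_engine("/path/to/exaone-4.0-checkpoint", engine_config)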

deepspeed/inference/v2/model_implementations/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -19,3 +19,4 @@
 from .qwen import *
 from .qwen_v2 import *
 from .qwen_v2_moe import *
+from .exaone import *

Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .container import ExaoneTransformerContainer, ExaoneNonTransformerContainer
from .model import ExaoneInferenceModel
from .policy import ExaonePolicy

Lines changed: 92 additions & 0 deletions

@@ -0,0 +1,92 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Create a container object to save model-specific tensors for EXAONE 4.0

from ..common_parameters import *
from ..layer_container_base import LayerContainer
"""
HF EXAONE 4.0 model structure:

Exaone4ForCausalLM(
  (model): Exaone4Model(
    (embed_tokens): Embedding(102400, 5120)
    (layers): ModuleList(
      (0-63): 64 x Exaone4DecoderLayer(
        (self_attn): Exaone4Attention(
          (q_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (k_proj): Linear(in_features=5120, out_features=1024, bias=False)
          (v_proj): Linear(in_features=5120, out_features=1024, bias=False)
          (o_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (rotary_emb): Exaone4RotaryEmbedding()
        )
        (mlp): Exaone4MLP(
          (gate_proj): Linear(in_features=5120, out_features=27392, bias=False)
          (up_proj): Linear(in_features=5120, out_features=27392, bias=False)
          (down_proj): Linear(in_features=27392, out_features=5120, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Exaone4RMSNorm()
        (post_attention_layernorm): Exaone4RMSNorm()
      )
    )
    (norm): Exaone4RMSNorm()
  )
  (lm_head): Linear(in_features=5120, out_features=102400, bias=False)
)

Key EXAONE 4.0 features:
- Hybrid attention: sliding_attention (local) vs full_attention (global) layers
- Grouped Query Attention: 40 query heads, 8 key-value heads
- QK-Reorder-Norm: RMSNorm applied after Q/K projections
- SiLU activation in MLP
"""


class ExaoneTransformerContainer(LayerContainer):
    """
    Transformer layer container for the EXAONE 4.0 model.
    Handles both sliding_attention and full_attention layer types.
    """
    qkv_w: UnfusedQKVParameter
    attn_out_w: AttentionOutputParameter
    mlp_1_w: GatedMLPParameter
    mlp_2_w: MLP2Parameter
    attn_norm_gamma: NormParameter
    mlp_norm_gamma: NormParameter

    PARAM_MAPPING = {
        # Attention parameters - Q, K, V projections
        "self_attn.q_proj.weight": "qkv_w.q_params",
        "self_attn.k_proj.weight": "qkv_w.k_params",
        "self_attn.v_proj.weight": "qkv_w.v_params",
        "self_attn.o_proj.weight": "attn_out_w.params",

        # MLP parameters - gate, up, down projections
        "mlp.gate_proj.weight": "mlp_1_w.gate_params",
        "mlp.up_proj.weight": "mlp_1_w.up_params",
        "mlp.down_proj.weight": "mlp_2_w.params",

        # Normalization parameters
        "input_layernorm.weight": "attn_norm_gamma.params",
        "post_attention_layernorm.weight": "mlp_norm_gamma.params",
    }


class ExaoneNonTransformerContainer(LayerContainer):
    """
    Non-Transformer layer container for the EXAONE 4.0 model.
    Contains embedding, final normalization, and output projection parameters.
    """
    word_emb: EmbeddingParameter
    word_unembed: UnembedParameter
    final_norm: NormParameter

    PARAM_MAPPING = {
        # Embedding and output parameters
        "model.embed_tokens.weight": "word_emb.params",
        "model.norm.weight": "final_norm.params",
        "lm_head.weight": "word_unembed.params",
    }
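
The PARAM_MAPPING tables above are declarative: each per-layer Hugging Face parameter name maps to a dotted destination on the container (attribute, then sub-parameter). A minimal sketch of how such a table can be resolved; route_param is a hypothetical helper for illustration only, the real plumbing lives in LayerContainer and the common_parameters classes.

# Illustrative only: resolve a PARAM_MAPPING-style entry into its destination.
PARAM_MAPPING = {
    "self_attn.q_proj.weight": "qkv_w.q_params",
    "self_attn.k_proj.weight": "qkv_w.k_params",
    "self_attn.v_proj.weight": "qkv_w.v_params",
    "self_attn.o_proj.weight": "attn_out_w.params",
}


def route_param(hf_name: str):
    """Split a mapped destination into (container attribute, sub-parameter)."""
    dest = PARAM_MAPPING[hf_name]
    attr, sub = dest.split(".", 1)
    return attr, sub


assert route_param("self_attn.k_proj.weight") == ("qkv_w", "k_params")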

Lines changed: 271 additions & 0 deletions

@@ -0,0 +1,271 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Iterable, Optional, Tuple

import torch

import deepspeed.comm as dist

from ...allocator import empty_from
from ...inference_utils import ActivationType, DtypeEnum
from .. import *
from ...modules.configs import *
from ...modules.interfaces import *
from ...ragged import RaggedBatchWrapper

from .container import ExaoneNonTransformerContainer, ExaoneTransformerContainer


class ExaoneInferenceModel(DSTransformerModelBase):
    """
    Inference model implementation for ragged batching for EXAONE 4.0 models.

    Key features:
    - Hybrid attention: sliding_attention (local) vs full_attention (global) layers
    - QK-Reorder-Norm: RMSNorm applied after Q/K projections
    - Conditional RoPE: Skip RoPE for full_attention layers
    - Grouped Query Attention: 40 query heads, 8 key-value heads
    """

    _non_transformer: Optional[ExaoneNonTransformerContainer]
    """
    Embed + unembed container. Specializing the type annotation.
    """

    _transformer: Optional[Iterable[ExaoneTransformerContainer]]
    """
    Per-layer transformer container. Specializing the type annotation.
    """

    # EXAONE 4.0 specific attributes
    _layer_types: Optional[list] = None
    """
    Layer types for hybrid attention: 'sliding_attention' or 'full_attention'
    """

    def __init__(self, config, engine_config, base_mp_group):
        super().__init__(config, engine_config, base_mp_group)

        # Store layer types for hybrid attention handling
        if hasattr(self._config, 'layer_types'):
            self._layer_types = self._config.layer_types
        else:
            # Fallback: infer from sliding_window_pattern (LLLG = 3 local, 1 global)
            pattern = getattr(self._config, 'sliding_window_pattern', 'LLLG')
            layer_types = []
            for i in range(self.num_layers):
                if pattern[i % len(pattern)] == 'G':
                    layer_types.append('full_attention')
                else:
                    layer_types.append('sliding_attention')
            self._layer_types = layer_types

    """
    Properties inherited from `DSInferenceModelBase`
    """

    @property
    def max_sequence_length(self) -> int:
        return self._config.max_position_embeddings

    """
    Properties inherited from `DSTransformerModelBase`
    """

    @property
    def num_layers(self) -> int:
        return self._config.num_hidden_layers

    @property
    def model_dim(self) -> int:
        return self._config.hidden_size

    @property
    def vocab_size(self) -> int:
        return self._config.vocab_size

    @property
    def head_size(self) -> int:
        return getattr(self._config, 'head_dim', self.model_dim // self.n_heads)

    @property
    def n_heads(self) -> int:
        return self._config.num_attention_heads

    @property
    def intermediate_dim(self) -> int:
        return self._config.intermediate_size

    @property
    def n_heads_kv(self) -> int:
        return self._config.num_key_value_heads

    @property
    def activation_dtype(self) -> DtypeEnum:
        if self._config.torch_dtype == torch.float16:
            return DtypeEnum.fp16
        elif self._config.torch_dtype == torch.bfloat16:
            return DtypeEnum.bf16
        else:
            raise NotImplementedError("Only fp16 and bf16 are supported")

    @property
    def mlp_activation_fn(self) -> ActivationType:
        activation = self._config.hidden_act.lower()
        # EXAONE 4.0 uses gated SiLU activation like LLaMA
        if activation == "silu":
            return ActivationType.SiGLU
        elif activation == "gelu":
            return ActivationType.GEGLU
        elif activation == "relu":
            return ActivationType.ReGLU
        else:
            raise NotImplementedError(f"Activation {activation} not supported")

    @property
    def norm_type(self) -> NormTypeEnum:
        return NormTypeEnum.RMSNorm

    @property
    def positional_embedding_type(self) -> PositionalEmbeddingType:
        return PositionalEmbeddingType.rotate_half

    @property
    def positional_embedding_config(self) -> Optional[RotateHalfConfig]:
        return RotateHalfConfig(theta_base=self._config.rope_theta)

    """
    Helper methods for EXAONE 4.0 specific features
    """

    def is_global_attention_layer(self, layer_idx: int) -> bool:
        """Check if layer uses global (full) attention vs local (sliding) attention"""
        if self._layer_types and layer_idx < len(self._layer_types):
            return self._layer_types[layer_idx] == 'full_attention'
        return False

    def should_apply_rope(self, layer_idx: int) -> bool:
        """EXAONE 4.0 skips RoPE for global attention layers"""
        return not self.is_global_attention_layer(layer_idx)

    """
    Forward implementations
    """

    def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor:
        """
        Performs the embedding lookup prior to running the transformer of the model.

        Arguments:
            ragged_batch (RaggedBatchWrapper): The batch to embed.

        Returns:
            torch.Tensor: The embedded batch.
        """
        embed = self.embed(ragged_batch, self._non_transformer.word_emb)

        if embed.shape[-1] != self.model_dim:
            raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}")

        return embed

    def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor,
                                   ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Executes one transformer layer with EXAONE 4.0 specific features:
        - Hybrid attention (sliding vs full)
        - QK-Reorder-Norm (RMSNorm after Q/K projections)
        - Conditional RoPE (skip for global layers)

        Arguments:
            layer_idx (int): The index of the layer to execute.
            residual (torch.Tensor): The residual tensor from the previous layer.
            hidden_states (torch.Tensor): The hidden states from the previous layer.
            ragged_batch_info (RaggedBatchWrapper): The batch metadata.
        """
        cur_params = self._transformer[layer_idx]
        kv_cache = self.state_manager.get_cache(layer_idx)

        # QKV projection
        hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=None)

        # EXAONE 4.0 attention with hybrid pattern and conditional RoPE
        # NOTE: The attention module should handle QK-Reorder-Norm internally
        # and respect the RoPE configuration based on layer type
        if self.is_global_attention_layer(layer_idx):
            # Global attention: full attention, no RoPE
            hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info, apply_rotary_pos_emb=False)
        else:
            # Local attention: sliding window, with RoPE
            hidden_states = self.attn(hidden_states,
                                      kv_cache,
                                      ragged_batch_info,
                                      apply_rotary_pos_emb=True,
                                      sliding_window=getattr(self._config, 'sliding_window', 4096))

        # Attention output projection
        hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=None)

        if self.tp_size > 1:
            dist.all_reduce(hidden_states, group=self._base_mp_group)

        # Post-attention normalization
        residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma, beta=None)

        # MLP forward pass (gated SiLU)
        hidden_states = self.mlp_1(hidden_states, cur_params.mlp_1_w, b=None)
        hidden_states = self.mlp_2(hidden_states, cur_params.mlp_2_w, b=None)

        if self.tp_size > 1:
            dist.all_reduce(hidden_states, group=self._base_mp_group)

        # Prepare for next layer normalization
        if layer_idx != self.num_layers - 1:
            next_params = self._transformer[layer_idx + 1]
            residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma, beta=None)
        else:
            # On last layer, just perform the residual add
            residual.add_(hidden_states)

        return residual, hidden_states

    def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor:
        """
        Performs unembedding of the hidden states to logits. This will only sample the final
        token of each sequence.
        """
        logits = self.unembed(hidden_states,
                              self._non_transformer.word_unembed,
                              ragged_batch_info,
                              gamma=self._non_transformer.final_norm)

        if self.tp_size > 1:
            comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1]))
            full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size))

            dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group)

            full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size))

            return full_logits
        else:
            return logits

    def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor:
        """
        Forward pass for EXAONE 4.0 model with hybrid attention support.
        """
        residual = self._forward_embed(wrapped_batch)

        # Initial normalization
        residual, hidden_states = self.norm(residual, None, self._transformer[0].attn_norm_gamma, beta=None)

        # Forward through all transformer layers
        for layer_idx in range(self.num_layers):
            residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states,
                                                                      wrapped_batch)

        return self._forward_unembed(residual, wrapped_batch)
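
The hybrid-attention bookkeeping above is plain Python over the config. A standalone sketch, mirroring the fallback in ExaoneInferenceModel.__init__ and the should_apply_rope helper; the function names here are illustrative, not engine API.

def infer_layer_types(num_layers, pattern="LLLG"):
    """Expand a sliding_window_pattern string into per-layer attention types."""
    return [
        "full_attention" if pattern[i % len(pattern)] == "G" else "sliding_attention"
        for i in range(num_layers)
    ]


def should_apply_rope(layer_types, layer_idx):
    """EXAONE 4.0 applies RoPE only on local (sliding window) layers."""
    return layer_types[layer_idx] != "full_attention"


layer_types = infer_layer_types(8)  # LLLG LLLG -> global layers at indices 3 and 7
assert layer_types[3] == "full_attention" and layer_types[7] == "full_attention"
assert should_apply_rope(layer_types, 0) and not should_apply_rope(layer_types, 3)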
