from vllm.v1.sample.metadata import SamplingMetadata

from .utils import AutoWeightsLoader, maybe_prefix
-from .llama_eagle3 import LlamaModel as LlamaEagle3Model

logger = init_logger(__name__)

@@ -87,9 +86,7 @@ def forward(


@support_torch_compile
-class Eagle3HunYuanModel(LlamaEagle3Model):
-    # Most function are same as Llama Eagle 3 support.
-    # only different is from init layer.
+class Eagle3HunYuanModel(nn.Module):

    def __init__(
        self,
@@ -98,8 +95,7 @@ def __init__(
        start_layer_id: int = 0,
        prefix: str = "",
    ) -> None:
-        # llama's init will setup layers, which cuase conflict
-        nn.Module.__init__(self)
+        super().__init__()
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
        self.vocab_size = self.config.vocab_size
@@ -130,7 +126,59 @@ def __init__(
            eps=self.config.rms_norm_eps,
        )

-class Eagle3HunYuanDenseV1ForCausalLM(HunYuanDenseV1ForCausalLM):
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
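+        # Embed the draft tokens; the decoder layer below consumes both
+        # these embeddings and the target model's hidden states.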
+        input_embeds = self.embed_tokens(input_ids)
+        assert hidden_states.shape[-1] == input_embeds.shape[-1]
+
+        residual = None
+        hidden_states, residual = self.layers[0](
+            positions,
+            input_embeds,
+            hidden_states,
+            residual,
+        )
+
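+        # RMSNorm with a residual input returns both the normalized output
+        # and the pre-norm hidden states.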
+        hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
+        return hidden_states, hidden_prenorm
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
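+        # Checkpoints store q/k/v and gate/up as separate tensors; vLLM
+        # fuses them into qkv_proj and gate_up_proj, so remap shards here.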
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
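+            # The Eagle3 draft checkpoint names its single decoder layer
+            # 'midlayer.'; remap it to this module's 'layers.0.'.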
+            if 'midlayer.' in name:
+                name = name.replace('midlayer.', 'layers.0.')
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
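+            # for/else: no stacked mapping matched, so load directly.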
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
+
+class Eagle3HunYuanDenseV1ForCausalLM(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)