@@ -103,10 +103,13 @@ def convert_esm_te_to_hf(model_te: nn.Module, **config_kwargs) -> nn.Module:
103103 "micro_batch_size" ,
104104 "max_seq_length" ,
105105 "model_type" ,
106+ "auto_map" ,
106107 ]
107108 for key in te_specific_keys :
108109 hf_config_dict .pop (key , None )
109110
111+ hf_config_dict ["model_type" ] = "esm"
112+
110113 hf_config = EsmConfig (** hf_config_dict , ** config_kwargs )
111114
112115 with init_empty_weights ():
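
For orientation, a minimal standalone sketch of the same config cleanup (not the converter's actual code; the values and the auto_map entry below are made-up placeholders, only the te_specific_keys names come from the diff): dropping the TE/trainer-only fields and pinning model_type back to "esm" leaves a dict that a plain transformers.EsmConfig will accept.

import torch  # noqa: F401  (only to match the module's dependencies)
from transformers import EsmConfig

# Hypothetical serialized TE config dict with placeholder values.
te_config_dict = {
    "hidden_size": 320,
    "num_attention_heads": 20,
    "num_hidden_layers": 6,
    "micro_batch_size": 2,
    "max_seq_length": 1024,
    "model_type": "esm_te",
    "auto_map": {"AutoModel": "modeling_te.SomeTEModel"},  # placeholder custom-code mapping
}

te_specific_keys = ["micro_batch_size", "max_seq_length", "model_type", "auto_map"]
for key in te_specific_keys:
    te_config_dict.pop(key, None)

te_config_dict["model_type"] = "esm"
hf_config = EsmConfig(**te_config_dict)
assert hf_config.model_type == "esm"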
@@ -149,11 +152,11 @@ def _pack_qkv_weight(ctx: io.TransformCTX, query, key, value):
149152 """Pad the embedding layer to the new input dimension."""
150153 concat_weights = torch .cat ((query , key , value ), dim = 0 )
151154 input_shape = concat_weights .size ()
152- np = ctx .target .config .num_attention_heads
155+ num_heads = ctx .target .config .num_attention_heads
153156 # transpose weights
154157 # [sequence length, batch size, num_splits_model_parallel * attention head size * #attention heads]
155158 # --> [sequence length, batch size, attention head size * num_splits_model_parallel * #attention heads]
156- concat_weights = concat_weights .view (3 , np , - 1 , query .size ()[- 1 ])
159+ concat_weights = concat_weights .view (3 , num_heads , - 1 , query .size ()[- 1 ])
157160 concat_weights = concat_weights .transpose (0 , 1 ).contiguous ()
158161 concat_weights = concat_weights .view (* input_shape )
159162 return concat_weights
@@ -171,11 +174,11 @@ def _pack_qkv_bias(ctx: io.TransformCTX, query, key, value):
171174 """Pad the embedding layer to the new input dimension."""
172175 concat_biases = torch .cat ((query , key , value ), dim = 0 )
173176 input_shape = concat_biases .size ()
174- np = ctx .target .config .num_attention_heads
177+ num_heads = ctx .target .config .num_attention_heads
175178 # transpose biases
176179 # [num_splits_model_parallel * attention head size * #attention heads]
177180 # --> [attention head size * num_splits_model_parallel * #attention heads]
178- concat_biases = concat_biases .view (3 , np , - 1 )
181+ concat_biases = concat_biases .view (3 , num_heads , - 1 )
179182 concat_biases = concat_biases .transpose (0 , 1 ).contiguous ()
180183 concat_biases = concat_biases .view (* input_shape )
181184 return concat_biases
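
To see what the packing transform actually does to the memory layout, here is a small self-contained sketch (toy sizes, plain tensors, no io.TransformCTX) that applies the same view/transpose to integer-tagged biases; the packed result interleaves the three projections per attention head rather than keeping Q, K, and V in contiguous blocks.

import torch

num_heads, head_dim = 2, 3  # toy sizes, not taken from any real checkpoint

query = torch.arange(num_heads * head_dim)        # 0..5
key = torch.arange(num_heads * head_dim) + 100    # 100..105
value = torch.arange(num_heads * head_dim) + 200  # 200..205

concat = torch.cat((query, key, value), dim=0)    # [3 * num_heads * head_dim]
input_shape = concat.size()
packed = concat.view(3, num_heads, -1).transpose(0, 1).contiguous().view(*input_shape)

print(packed.tolist())
# [0, 1, 2, 100, 101, 102, 200, 201, 202, 3, 4, 5, 103, 104, 105, 203, 204, 205]
# i.e. head 0's q, k, v entries first, then head 1's: the fused per-head TE QKV layout.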
@@ -190,26 +193,20 @@ def _pack_qkv_bias(ctx: io.TransformCTX, query, key, value):
     ),
 )
 def _unpack_qkv_weight(ctx: io.TransformCTX, qkv_weight):
-    """Unpack the fused QKV weight into separate query, key, and value weights."""
-    np = ctx.source.config.num_attention_heads
-
-    # Reverse the packing transformation
-    # First, reshape to separate the interleaved Q, K, V
-    # [attention head size * num_splits_model_parallel * #attention heads]
-    # --> [num_splits_model_parallel * attention head size * #attention heads]
-    qkv_weight = qkv_weight.view(np, 3, -1, qkv_weight.size()[-1])  # Output: [num_heads, 3, head_dim, vocab_size]
-    qkv_weight = qkv_weight.transpose(0, 1).contiguous()  # Output: [3, num_heads, head_dim, vocab_size]
-
-    # Split into Q, K, V directly from the transposed tensor
-    # qkv_weight shape: [3, num_heads, head_dim, input_dim]
-    query = qkv_weight[0]  # [num_heads, head_dim, input_dim]
-    key = qkv_weight[1]  # [num_heads, head_dim, input_dim]
-    value = qkv_weight[2]  # [num_heads, head_dim, input_dim]
-
-    # Reshape to match HF format: [total_head_dim, input_dim]
-    query = query.view(-1, query.size()[-1])  # [num_heads * head_dim, input_dim]
-    key = key.view(-1, key.size()[-1])  # [num_heads * head_dim, input_dim]
-    value = value.view(-1, value.size()[-1])  # [num_heads * head_dim, input_dim]
+    """Unpack fused QKV weights into separate [hidden_size, input_dim] tensors for query/key/value."""
+    num_heads = ctx.source.config.num_attention_heads
+    total_rows, input_dim = qkv_weight.size()  # size: [num_heads * 3 * head_dim, input_dim]
+    assert total_rows % (3 * num_heads) == 0, (
+        f"QKV weight rows {total_rows} not divisible by 3*num_heads {3 * num_heads}"
+    )
+    head_dim = total_rows // (3 * num_heads)
+
+    qkv_weight = qkv_weight.view(num_heads, 3, head_dim, input_dim).transpose(0, 1).contiguous()  # size: [3, num_heads, head_dim, input_dim]
+    query, key, value = qkv_weight[0], qkv_weight[1], qkv_weight[2]  # size: [num_heads, head_dim, input_dim]
+
+    query = query.reshape(-1, input_dim)  # size: [num_heads * head_dim, input_dim]
+    key = key.reshape(-1, input_dim)  # size: [num_heads * head_dim, input_dim]
+    value = value.reshape(-1, input_dim)  # size: [num_heads * head_dim, input_dim]

     return query, key, value

@@ -223,25 +220,19 @@ def _unpack_qkv_weight(ctx: io.TransformCTX, qkv_weight):
     ),
 )
 def _unpack_qkv_bias(ctx: io.TransformCTX, qkv_bias):
-    """Unpack the fused QKV bias into separate query, key, and value biases."""
-    np = ctx.source.config.num_attention_heads
+    """Unpack fused QKV biases into separate [hidden_size] tensors for query/key/value."""
+    num_heads = ctx.source.config.num_attention_heads
+    total_size = qkv_bias.size(0)  # size: [num_heads * 3 * head_dim]
+    assert total_size % (3 * num_heads) == 0, (
+        f"QKV bias size {total_size} not divisible by 3*num_heads {3 * num_heads}"
+    )
+    head_dim = total_size // (3 * num_heads)

-    # Reverse the packing transformation
-    # First, reshape to separate the interleaved Q, K, V
-    # [num_splits_model_parallel * attention head size * #attention heads]
-    # --> [attention head size * num_splits_model_parallel * #attention heads]
-    qkv_bias = qkv_bias.view(np, 3, -1)
-    qkv_bias = qkv_bias.transpose(0, 1).contiguous()
-
-    # Split into Q, K, V directly from the transposed tensor
-    # qkv_bias shape: [3, num_heads, head_dim]
-    query = qkv_bias[0]  # [num_heads, head_dim]
-    key = qkv_bias[1]  # [num_heads, head_dim]
-    value = qkv_bias[2]  # [num_heads, head_dim]
-
-    # Reshape to match HF format: [total_head_dim]
-    query = query.view(-1)  # [num_heads * head_dim]
-    key = key.view(-1)  # [num_heads * head_dim]
-    value = value.view(-1)  # [num_heads * head_dim]
+    qkv_bias = qkv_bias.view(num_heads, 3, head_dim).transpose(0, 1).contiguous()  # size: [3, num_heads, head_dim]
+    query, key, value = qkv_bias[0], qkv_bias[1], qkv_bias[2]  # size: [num_heads, head_dim]
+
+    query = query.reshape(-1)  # size: [num_heads * head_dim]
+    key = key.reshape(-1)  # size: [num_heads * head_dim]
+    value = value.reshape(-1)  # size: [num_heads * head_dim]

     return query, key, value
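
As a quick consistency check that the unpack transforms invert the pack transforms, here is a standalone sketch (random toy shapes, no io.TransformCTX; the helpers pack_qkv_weight/unpack_qkv_weight are local to this example) that packs per-head Q/K/V weights and recovers them exactly.

import torch

def pack_qkv_weight(query, key, value, num_heads):
    # Mirrors the view/transpose in _pack_qkv_weight, minus the TransformCTX plumbing.
    concat = torch.cat((query, key, value), dim=0)
    input_shape = concat.size()
    concat = concat.view(3, num_heads, -1, query.size()[-1]).transpose(0, 1).contiguous()
    return concat.view(*input_shape)

def unpack_qkv_weight(qkv_weight, num_heads):
    # Mirrors the reshape in _unpack_qkv_weight.
    total_rows, input_dim = qkv_weight.size()
    head_dim = total_rows // (3 * num_heads)
    qkv = qkv_weight.view(num_heads, 3, head_dim, input_dim).transpose(0, 1).contiguous()
    return (qkv[0].reshape(-1, input_dim), qkv[1].reshape(-1, input_dim), qkv[2].reshape(-1, input_dim))

num_heads, head_dim, hidden = 4, 8, 32  # toy sizes
q, k, v = (torch.randn(num_heads * head_dim, hidden) for _ in range(3))

packed = pack_qkv_weight(q, k, v, num_heads)
q2, k2, v2 = unpack_qkv_weight(packed, num_heads)
assert torch.equal(q, q2) and torch.equal(k, k2) and torch.equal(v, v2)  # exact round trip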