@@ -103,10 +103,13 @@ def convert_esm_te_to_hf(model_te: nn.Module, **config_kwargs) -> nn.Module:
         "micro_batch_size",
         "max_seq_length",
         "model_type",
+        "auto_map",
     ]
     for key in te_specific_keys:
         hf_config_dict.pop(key, None)
 
+    hf_config_dict["model_type"] = "esm"
+
     hf_config = EsmConfig(**hf_config_dict, **config_kwargs)
 
     with init_empty_weights():
@@ -149,11 +152,11 @@ def _pack_qkv_weight(ctx: io.TransformCTX, query, key, value):
     """Pad the embedding layer to the new input dimension."""
     concat_weights = torch.cat((query, key, value), dim=0)
     input_shape = concat_weights.size()
-    np = ctx.target.config.num_attention_heads
+    num_heads = ctx.target.config.num_attention_heads
     # transpose weights
     # [sequence length, batch size, num_splits_model_parallel * attention head size * #attention heads]
     # --> [sequence length, batch size, attention head size * num_splits_model_parallel * #attention heads]
-    concat_weights = concat_weights.view(3, np, -1, query.size()[-1])
+    concat_weights = concat_weights.view(3, num_heads, -1, query.size()[-1])
     concat_weights = concat_weights.transpose(0, 1).contiguous()
     concat_weights = concat_weights.view(*input_shape)
     return concat_weights
@@ -171,11 +174,11 @@ def _pack_qkv_bias(ctx: io.TransformCTX, query, key, value):
     """Pad the embedding layer to the new input dimension."""
    concat_biases = torch.cat((query, key, value), dim=0)
     input_shape = concat_biases.size()
-    np = ctx.target.config.num_attention_heads
+    num_heads = ctx.target.config.num_attention_heads
     # transpose biases
     # [num_splits_model_parallel * attention head size * #attention heads]
     # --> [attention head size * num_splits_model_parallel * #attention heads]
-    concat_biases = concat_biases.view(3, np, -1)
+    concat_biases = concat_biases.view(3, num_heads, -1)
     concat_biases = concat_biases.transpose(0, 1).contiguous()
     concat_biases = concat_biases.view(*input_shape)
     return concat_biases
@@ -190,26 +193,20 @@ def _pack_qkv_bias(ctx: io.TransformCTX, query, key, value):
     ),
 )
 def _unpack_qkv_weight(ctx: io.TransformCTX, qkv_weight):
-    """Unpack the fused QKV weight into separate query, key, and value weights."""
-    np = ctx.source.config.num_attention_heads
-
-    # Reverse the packing transformation
-    # First, reshape to separate the interleaved Q, K, V
-    # [attention head size * num_splits_model_parallel * #attention heads]
-    # --> [num_splits_model_parallel * attention head size * #attention heads]
-    qkv_weight = qkv_weight.view(np, 3, -1, qkv_weight.size()[-1])  # Output: [num_heads, 3, head_dim, vocab_size]
-    qkv_weight = qkv_weight.transpose(0, 1).contiguous()  # Output: [3, num_heads, head_dim, vocab_size]
-
-    # Split into Q, K, V directly from the transposed tensor
-    # qkv_weight shape: [3, num_heads, head_dim, input_dim]
-    query = qkv_weight[0]  # [num_heads, head_dim, input_dim]
-    key = qkv_weight[1]  # [num_heads, head_dim, input_dim]
-    value = qkv_weight[2]  # [num_heads, head_dim, input_dim]
-
-    # Reshape to match HF format: [total_head_dim, input_dim]
-    query = query.view(-1, query.size()[-1])  # [num_heads * head_dim, input_dim]
-    key = key.view(-1, key.size()[-1])  # [num_heads * head_dim, input_dim]
-    value = value.view(-1, value.size()[-1])  # [num_heads * head_dim, input_dim]
+    """Unpack fused QKV weights into separate [hidden_size, input_dim] tensors for query/key/value."""
+    num_heads = ctx.source.config.num_attention_heads
+    total_rows, input_dim = qkv_weight.size()  # size: [num_heads * 3 * head_dim, input_dim]
+    assert total_rows % (3 * num_heads) == 0, (
+        f"QKV weight rows {total_rows} not divisible by 3*num_heads {3 * num_heads}"
+    )
+    head_dim = total_rows // (3 * num_heads)
+
+    qkv_weight = qkv_weight.view(num_heads, 3, head_dim, input_dim).transpose(0, 1).contiguous()  # size: [3, num_heads, head_dim, input_dim]
+    query, key, value = qkv_weight[0], qkv_weight[1], qkv_weight[2]  # size: [num_heads, head_dim, input_dim]
+
+    query = query.reshape(-1, input_dim)  # size: [num_heads * head_dim, input_dim]
+    key = key.reshape(-1, input_dim)  # size: [num_heads * head_dim, input_dim]
+    value = value.reshape(-1, input_dim)  # size: [num_heads * head_dim, input_dim]
 
     return query, key, value
 
@@ -223,25 +220,19 @@ def _unpack_qkv_weight(ctx: io.TransformCTX, qkv_weight):
     ),
 )
 def _unpack_qkv_bias(ctx: io.TransformCTX, qkv_bias):
-    """Unpack the fused QKV bias into separate query, key, and value biases."""
-    np = ctx.source.config.num_attention_heads
+    """Unpack fused QKV biases into separate [hidden_size] tensors for query/key/value."""
+    num_heads = ctx.source.config.num_attention_heads
+    total_size = qkv_bias.size(0)  # size: [num_heads * 3 * head_dim]
+    assert total_size % (3 * num_heads) == 0, (
+        f"QKV bias size {total_size} not divisible by 3*num_heads {3 * num_heads}"
+    )
+    head_dim = total_size // (3 * num_heads)
 
-    # Reverse the packing transformation
-    # First, reshape to separate the interleaved Q, K, V
-    # [num_splits_model_parallel * attention head size * #attention heads]
-    # --> [attention head size * num_splits_model_parallel * #attention heads]
-    qkv_bias = qkv_bias.view(np, 3, -1)
-    qkv_bias = qkv_bias.transpose(0, 1).contiguous()
-
-    # Split into Q, K, V directly from the transposed tensor
-    # qkv_bias shape: [3, num_heads, head_dim]
-    query = qkv_bias[0]  # [num_heads, head_dim]
-    key = qkv_bias[1]  # [num_heads, head_dim]
-    value = qkv_bias[2]  # [num_heads, head_dim]
-
-    # Reshape to match HF format: [total_head_dim]
-    query = query.view(-1)  # [num_heads * head_dim]
-    key = key.view(-1)  # [num_heads * head_dim]
-    value = value.view(-1)  # [num_heads * head_dim]
+    qkv_bias = qkv_bias.view(num_heads, 3, head_dim).transpose(0, 1).contiguous()  # size: [3, num_heads, head_dim]
+    query, key, value = qkv_bias[0], qkv_bias[1], qkv_bias[2]  # size: [num_heads, head_dim]
+
+    query = query.reshape(-1)  # size: [num_heads * head_dim]
+    key = key.reshape(-1)  # size: [num_heads * head_dim]
+    value = value.reshape(-1)  # size: [num_heads * head_dim]
 
     return query, key, value
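A minimal round-trip sketch (not part of the change), assuming hypothetical values for num_heads, head_dim, and input_dim, showing that the view/transpose used in _unpack_qkv_weight reverses the per-head interleaving performed by _pack_qkv_weight (the real transforms additionally read these sizes from ctx.*.config):

import torch

# Illustrative dimensions only; any consistent values work.
num_heads, head_dim, input_dim = 4, 8, 16
query = torch.randn(num_heads * head_dim, input_dim)
key = torch.randn(num_heads * head_dim, input_dim)
value = torch.randn(num_heads * head_dim, input_dim)

# Pack: concatenate Q/K/V, then interleave rows per attention head,
# mirroring the view/transpose sequence in _pack_qkv_weight.
packed = torch.cat((query, key, value), dim=0)
packed = packed.view(3, num_heads, head_dim, input_dim).transpose(0, 1).contiguous()
packed = packed.view(3 * num_heads * head_dim, input_dim)

# Unpack: undo the interleaving, mirroring _unpack_qkv_weight.
unpacked = packed.view(num_heads, 3, head_dim, input_dim).transpose(0, 1).contiguous()
q2, k2, v2 = (t.reshape(-1, input_dim) for t in unpacked)

assert torch.equal(q2, query) and torch.equal(k2, key) and torch.equal(v2, value)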