from accelerate import init_empty_weights
from nemo.lightning import io
from torch import nn
+from transformers import EsmConfig, EsmForMaskedLM

from esm.modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM

    "lm_head.layer_norm.bias": "lm_head.decoder.layer_norm_bias",
}

+# Reverse mapping (TE -> HF), built by inverting the HF -> TE mapping above
+reverse_mapping = {v: k for k, v in mapping.items()}
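+# For example, the mapping entry shown above inverts to:
+#     reverse_mapping["lm_head.decoder.layer_norm_bias"] == "lm_head.layer_norm.bias"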
+

def convert_esm_hf_to_te(model_hf: nn.Module, **config_kwargs) -> nn.Module:
    """Convert a Hugging Face model to a Transformer Engine model.
@@ -69,6 +73,70 @@ def convert_esm_hf_to_te(model_hf: nn.Module, **config_kwargs) -> nn.Module:
    return output_model


+def convert_esm_te_to_hf(model_te: nn.Module, **config_kwargs) -> nn.Module:
+    """Convert a Transformer Engine model back to the original Hugging Face Facebook ESM-2 format.
+
+    This function converts an NVIDIA Transformer Engine (TE) model back to the weight
+    format of the original facebook/esm2_* checkpoints. The TE model is itself a
+    Hugging Face model, but this conversion restores the architecture and weight
+    layout of the original Facebook ESM-2 models hosted on Hugging Face.
+
+    Args:
+        model_te (nn.Module): The Transformer Engine model.
+        **config_kwargs: Additional configuration kwargs passed to EsmConfig.
+
+    Returns:
+        nn.Module: The Hugging Face model in the original Facebook ESM-2 format.
+    """
+    # Convert TE config to HF config
+    hf_config_dict = model_te.config.to_dict()
+
+    # Remove TE-specific config options
+    te_specific_keys = [
+        "qkv_weight_interleaved",
+        "encoder_activation",
+        "attn_input_format",
+        "fuse_qkv_params",
+        "micro_batch_size",
+        "max_seq_length",
+        "model_type",
+        "auto_map",
+    ]
+    for key in te_specific_keys:
+        hf_config_dict.pop(key, None)
+
+    hf_config_dict["model_type"] = "esm"
+
+    hf_config = EsmConfig(**hf_config_dict, **config_kwargs)
+
+    with init_empty_weights():
+        model_hf = EsmForMaskedLM(hf_config)
+
+    # Remove contact_head since it's not present in TE models
+    if hasattr(model_hf.esm, "contact_head"):
+        delattr(model_hf.esm, "contact_head")
+
+    output_model = io.apply_transforms(
+        model_te,
+        model_hf,
+        reverse_mapping,
+        [_unpack_qkv_weight, _unpack_qkv_bias, _unpad_embeddings, _unpad_decoder_weights, _unpad_bias],
+        state_dict_ignored_entries=[
+            "lm_head.decoder.weight",
+            "esm.contact_head.regression.weight",
+            "esm.contact_head.regression.bias",
+        ],
+    )
+
+    output_model.tie_weights()
+
+    # Note: contact_head parameters are not preserved in TE models. They are lost
+    # during the HF -> TE conversion and cannot be recovered, so the converted model
+    # will not contain the original contact_head weights.
+
+    return output_model
+
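+# Illustrative usage sketch (checkpoint paths are hypothetical; assumes NVEsmForMaskedLM
+# exposes the standard Hugging Face from_pretrained/save_pretrained API):
+#
+#     te_model = NVEsmForMaskedLM.from_pretrained("path/to/te_checkpoint")
+#     hf_model = convert_esm_te_to_hf(te_model)
+#     hf_model.save_pretrained("path/to/hf_checkpoint")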
+

@io.state_transform(
    source_key=(
        "esm.encoder.layer.*.attention.self.query.weight",
@@ -81,11 +149,11 @@ def _pack_qkv_weight(ctx: io.TransformCTX, query, key, value):
    """Pad the embedding layer to the new input dimension."""
    concat_weights = torch.cat((query, key, value), dim=0)
    input_shape = concat_weights.size()
-    np = ctx.target.config.num_attention_heads
+    num_heads = ctx.target.config.num_attention_heads
    # transpose weights
    # [sequence length, batch size, num_splits_model_parallel * attention head size * #attention heads]
    # --> [sequence length, batch size, attention head size * num_splits_model_parallel * #attention heads]
-    concat_weights = concat_weights.view(3, np, -1, query.size()[-1])
+    concat_weights = concat_weights.view(3, num_heads, -1, query.size()[-1])
    concat_weights = concat_weights.transpose(0, 1).contiguous()
    concat_weights = concat_weights.view(*input_shape)
    return concat_weights
@@ -103,16 +171,78 @@ def _pack_qkv_bias(ctx: io.TransformCTX, query, key, value):
    """Pad the embedding layer to the new input dimension."""
    concat_biases = torch.cat((query, key, value), dim=0)
    input_shape = concat_biases.size()
-    np = ctx.target.config.num_attention_heads
+    num_heads = ctx.target.config.num_attention_heads
    # transpose biases
    # [num_splits_model_parallel * attention head size * #attention heads]
    # --> [attention head size * num_splits_model_parallel * #attention heads]
-    concat_biases = concat_biases.view(3, np, -1)
+    concat_biases = concat_biases.view(3, num_heads, -1)
    concat_biases = concat_biases.transpose(0, 1).contiguous()
    concat_biases = concat_biases.view(*input_shape)
    return concat_biases


+@io.state_transform(
+    source_key="esm.encoder.layers.*.self_attention.layernorm_qkv.weight",
+    target_key=(
+        "esm.encoder.layer.*.attention.self.query.weight",
+        "esm.encoder.layer.*.attention.self.key.weight",
+        "esm.encoder.layer.*.attention.self.value.weight",
+    ),
+)
+def _unpack_qkv_weight(ctx: io.TransformCTX, qkv_weight):
+    """Unpack fused QKV weights into separate [hidden_size, input_dim] tensors for query/key/value."""
+    num_heads = ctx.source.config.num_attention_heads
+    total_rows, input_dim = qkv_weight.size()  # size: [num_heads * 3 * head_dim, input_dim]
+    assert total_rows % (3 * num_heads) == 0, (
+        f"QKV weight rows {total_rows} not divisible by 3*num_heads {3 * num_heads}"
+    )
+    head_dim = total_rows // (3 * num_heads)
+
+    qkv_weight = (
+        qkv_weight.view(num_heads, 3, head_dim, input_dim).transpose(0, 1).contiguous()
+    )  # size: [3, num_heads, head_dim, input_dim]
+    query, key, value = qkv_weight[0], qkv_weight[1], qkv_weight[2]  # size: [num_heads, head_dim, input_dim]
+
+    query = query.reshape(-1, input_dim)  # size: [num_heads * head_dim, input_dim]
+    key = key.reshape(-1, input_dim)  # size: [num_heads * head_dim, input_dim]
+    value = value.reshape(-1, input_dim)  # size: [num_heads * head_dim, input_dim]
+
+    return query, key, value
+
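+# Shape sketch of the unpacking above (illustrative dimensions: 8 heads, head_dim 16, input_dim 320):
+#
+#     fused = torch.randn(8 * 3 * 16, 320)
+#     q, k, v = fused.view(8, 3, 16, 320).transpose(0, 1).contiguous()
+#     assert q.reshape(-1, 320).shape == (8 * 16, 320)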
+
+@io.state_transform(
+    source_key="esm.encoder.layers.*.self_attention.layernorm_qkv.bias",
+    target_key=(
+        "esm.encoder.layer.*.attention.self.query.bias",
+        "esm.encoder.layer.*.attention.self.key.bias",
+        "esm.encoder.layer.*.attention.self.value.bias",
+    ),
+)
+def _unpack_qkv_bias(ctx: io.TransformCTX, qkv_bias):
+    """Unpack fused QKV biases into separate [hidden_size] tensors for query/key/value."""
+    num_heads = ctx.source.config.num_attention_heads
+    total_size = qkv_bias.size(0)  # size: [num_heads * 3 * head_dim]
+    assert total_size % (3 * num_heads) == 0, (
+        f"QKV bias size {total_size} not divisible by 3*num_heads {3 * num_heads}"
+    )
+    head_dim = total_size // (3 * num_heads)
+
+    qkv_bias = qkv_bias.view(num_heads, 3, head_dim).transpose(0, 1).contiguous()  # size: [3, num_heads, head_dim]
+    query, key, value = qkv_bias[0], qkv_bias[1], qkv_bias[2]  # size: [num_heads, head_dim]
+
+    query = query.reshape(-1)  # size: [num_heads * head_dim]
+    key = key.reshape(-1)  # size: [num_heads * head_dim]
+    value = value.reshape(-1)  # size: [num_heads * head_dim]
+
+    return query, key, value
+
+
+def _unpad_weights(ctx: io.TransformCTX, padded_embed):
+    """Remove padding from the embedding layer to get back to the original dimension."""
+    target_embedding_dimension = ctx.target.config.vocab_size
+    return padded_embed[:target_embedding_dimension]
+
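+# Illustrative unpadding sketch (made-up sizes; assumes ESM-2's vocab_size of 33 and a
+# hypothetical padded_vocab_size of 40 with hidden size 320):
+#
+#     padded = torch.randn(40, 320)
+#     assert padded[:33].shape == (33, 320)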
+

def _pad_weights(ctx: io.TransformCTX, source_embed):
    """Pad the embedding layer to the new input dimension."""
    target_embedding_dimension = ctx.target.config.padded_vocab_size
@@ -134,6 +264,16 @@ def _pad_weights(ctx: io.TransformCTX, source_embed):
    target_key="lm_head.decoder.weight",
)(_pad_weights)

+_unpad_embeddings = io.state_transform(
+    source_key="esm.embeddings.word_embeddings.weight",
+    target_key="esm.embeddings.word_embeddings.weight",
+)(_unpad_weights)
+
+_unpad_decoder_weights = io.state_transform(
+    source_key="lm_head.decoder.weight",
+    target_key="lm_head.decoder.weight",
+)(_unpad_weights)
+

@io.state_transform(
    source_key="lm_head.bias",
@@ -148,3 +288,13 @@ def _pad_bias(ctx: io.TransformCTX, source_bias):
    )
    output_bias[:hf_embedding_dimension] = source_bias
    return output_bias
+
+
+@io.state_transform(
+    source_key="lm_head.decoder.bias",
+    target_key="lm_head.bias",
+)
+def _unpad_bias(ctx: io.TransformCTX, padded_bias):
+    """Remove padding from the bias to get back to the original dimension."""
+    target_embedding_dimension = ctx.target.config.vocab_size
+    return padded_bias[:target_embedding_dimension]