
Commit de8bfff

Ohad's changes to convert and test_convert for TE > HF checkpoint conversion (#1218)
Adds convert.py and test_convert.py from @ohadmo's PR and makes some readme updates. Takes many changes from #1121.

## Summary by CodeRabbit

* **New Features**
  * Added bidirectional conversion between Transformer Engine and HuggingFace ESM-2 formats, including TE→HF roundtrip support.
* **Documentation**
  * Expanded ESM-2 guide with Python-centric conversion examples, "Load and Test" and "Validating Converted Models" sections, and step-by-step deployment/upload workflows.
* **Tests**
  * Added end-to-end and unit tests validating round-trip conversions, parameter parity, config alignment, and padding/unpadding behavior.

Signed-off-by: Peter St. John <[email protected]>
1 parent 228f5a2 commit de8bfff

3 files changed (+345, -19 lines changed)

bionemo-recipes/models/esm2/README.md

Lines changed: 54 additions & 15 deletions
````diff
@@ -16,7 +16,7 @@ The ESM-2 implementation natively supports the following TransformerEngine-provi
 | **Sequence Packing / THD input format** | ✅ Supported |
 | **FP8 with THD input format** | ✅ Supported where FP8 is supported |
 | **Import from HuggingFace checkpoints** | ✅ Supported |
-| **Export to HuggingFace checkpoints** | 🚧 Under development |
+| **Export to HuggingFace checkpoints** | ✅ Supported |

 See [BioNemo Recipes](../../recipes/README.md) for more details on how to use these features to accelerate model
 training and inference.
````
````diff
@@ -70,24 +70,53 @@ Training recipes are available in the `bionemo-recipes/recipes/` directory:
 - **[esm2_accelerate_te](../../recipes/esm2_accelerate_te/)** - Trains the model using HuggingFace
   [Accelerate](https://huggingface.co/docs/accelerate/index).

-## Commands for converting checkpoints
+## Converting Between Model Formats

-### HF Transformers to TE conversion
+This section explains how to convert between Hugging Face Transformers and Transformer Engine (TE) ESM-2 model formats.
+The conversion is bidirectional: from Transformers to TE format for optimized inference, and back to
+Hugging Face Transformers format for sharing and deployment.

-Generate converted ESM-2 checkpoints from existing HuggingFace transformers checkpoints:
+### Converting from HF Transformers to TE

-```bash
-mkdir -p checkpoint_export
-docker build -t esm2 .
-docker run --rm -it --gpus all \
-  -v $PWD/checkpoint_export/:/workspace/bionemo/checkpoint_export \
-  -v $HOME/.cache/huggingface/:/root/.cache/huggingface \
-  esm2 python export.py
+```python
+from transformers import AutoModelForMaskedLM
+
+from esm.convert import convert_esm_hf_to_te
+
+hf_model = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
+te_model = convert_esm_hf_to_te(hf_model)
+te_model.save_pretrained("/path/to/te_checkpoint")
+```
+
+This loads the pre-trained ESM-2 checkpoint, converts it to TE format, and keeps the original model as a reference for later comparison.
+
+### Converting from TE back to HF Transformers
+
+```python
+from esm.convert import convert_esm_te_to_hf
+from esm.modeling_esm_te import NVEsmForMaskedLM
+
+te_model = NVEsmForMaskedLM.from_pretrained("/path/to/te_checkpoint")
+hf_model = convert_esm_te_to_hf(te_model)
+hf_model.save_pretrained("/path/to/hf_checkpoint")
 ```

-### TE to HF Transformers conversion
+### Load and Test the Exported Model

-(Coming soon)
+Load the exported model and perform validation:
+
+```python
+from transformers import AutoTokenizer
+
+model_hf_exported = AutoModelForMaskedLM.from_pretrained("/path/to/hf_checkpoint")
+tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
+```
+
+### Validating Converted Models
+
+See the commands in [Inference Examples](#inference-examples) above to load and test both the original and converted
+models to ensure loss and logit values are similar. See also the golden value tests in
+[test_modeling_esm_te.py](tests/test_modeling_esm_te.py) and [test_convert.py](tests/test_convert.py).

 ## Developer Guide

````
````diff
@@ -107,8 +136,18 @@ editable mode with `pip install -e .`, then run `pytest -v .` in the model direc

 ### Deploying converted checkpoints to HuggingFace Hub

-After running the checkpoint conversion steps listed in [Commands for converting checkpoints](#commands-for-converting-checkpoints),
-you can deploy the converted checkpoints to the HuggingFace Hub by running the following command:
+First, generate converted ESM-2 checkpoints from existing HuggingFace transformers checkpoints:
+
+```bash
+mkdir -p checkpoint_export
+docker build -t esm2 .
+docker run --rm -it --gpus all \
+  -v $PWD/checkpoint_export/:/workspace/bionemo/checkpoint_export \
+  -v $HOME/.cache/huggingface/:/root/.cache/huggingface \
+  esm2 python export.py
+```
+
+Now deploy the converted checkpoints to the HuggingFace Hub by running the following command for each model:

 ```bash
 huggingface-cli upload nvidia/${MODEL_NAME} $PWD/checkpoint_export/${MODEL_NAME}
````
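The per-model upload loop can equivalently be scripted with the `huggingface_hub` Python API. A sketch, assuming you are already authenticated; the model list is illustrative:

```python
from huggingface_hub import HfApi

api = HfApi()
for model_name in ["esm2_t6_8M_UR50D", "esm2_t33_650M_UR50D"]:  # illustrative subset
    # Mirrors `huggingface-cli upload nvidia/${MODEL_NAME} checkpoint_export/${MODEL_NAME}`
    api.upload_folder(
        folder_path=f"checkpoint_export/{model_name}",
        repo_id=f"nvidia/{model_name}",
        repo_type="model",
    )
```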

bionemo-recipes/models/esm2/src/esm/convert.py

Lines changed: 154 additions & 4 deletions
```diff
@@ -17,6 +17,7 @@
 from accelerate import init_empty_weights
 from nemo.lightning import io
 from torch import nn
+from transformers import EsmConfig, EsmForMaskedLM

 from esm.modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM

```
```diff
@@ -40,6 +41,9 @@
     "lm_head.layer_norm.bias": "lm_head.decoder.layer_norm_bias",
 }

+# Reverse mapping from TE to HF format by reversing the original mapping
+reverse_mapping = {v: k for k, v in mapping.items()}
+

 def convert_esm_hf_to_te(model_hf: nn.Module, **config_kwargs) -> nn.Module:
     """Convert a Hugging Face model to a Transformer Engine model.
```
```diff
@@ -69,6 +73,70 @@ def convert_esm_hf_to_te(model_hf: nn.Module, **config_kwargs) -> nn.Module:
     return output_model


+def convert_esm_te_to_hf(model_te: nn.Module, **config_kwargs) -> nn.Module:
+    """Convert a Transformer Engine model back to the original HuggingFace Facebook ESM-2 format.
+
+    This function converts from the NVIDIA Transformer Engine (TE) format back to the
+    weight format compatible with the original facebook/esm2_* series of checkpoints.
+    The TE model is also a HuggingFace model, but this conversion ensures compatibility
+    with the original Facebook ESM-2 model architecture and weight format hosted on Hugging Face.
+
+    Args:
+        model_te (nn.Module): The Transformer Engine model.
+        **config_kwargs: Additional configuration kwargs to be passed to EsmConfig.
+
+    Returns:
+        nn.Module: The Hugging Face model in the original Facebook ESM-2 format.
+    """
+    # Convert TE config to HF config
+    hf_config_dict = model_te.config.to_dict()
+
+    # Remove TE-specific config options
+    te_specific_keys = [
+        "qkv_weight_interleaved",
+        "encoder_activation",
+        "attn_input_format",
+        "fuse_qkv_params",
+        "micro_batch_size",
+        "max_seq_length",
+        "model_type",
+        "auto_map",
+    ]
+    for key in te_specific_keys:
+        hf_config_dict.pop(key, None)
+
+    hf_config_dict["model_type"] = "esm"
+
+    hf_config = EsmConfig(**hf_config_dict, **config_kwargs)
+
+    with init_empty_weights():
+        model_hf = EsmForMaskedLM(hf_config)
+
+    # Remove contact_head since it's not present in TE models
+    if hasattr(model_hf.esm, "contact_head"):
+        delattr(model_hf.esm, "contact_head")
+
+    output_model = io.apply_transforms(
+        model_te,
+        model_hf,
+        reverse_mapping,
+        [_unpack_qkv_weight, _unpack_qkv_bias, _unpad_embeddings, _unpad_decoder_weights, _unpad_bias],
+        state_dict_ignored_entries=[
+            "lm_head.decoder.weight",
+            "esm.contact_head.regression.weight",
+            "esm.contact_head.regression.bias",
+        ],
+    )
+
+    output_model.tie_weights()
+
+    # Note: contact_head parameters are not preserved in TE models.
+    # They are lost during HF -> TE conversion and cannot be recovered,
+    # so the converted model will not have the original contact_head weights.
+
+    return output_model
+
+
 @io.state_transform(
     source_key=(
         "esm.encoder.layer.*.attention.self.query.weight",
```
```diff
@@ -81,11 +149,11 @@ def _pack_qkv_weight(ctx: io.TransformCTX, query, key, value):
     """Pack separate query/key/value projection weights into TE's fused QKV weight."""
     concat_weights = torch.cat((query, key, value), dim=0)
     input_shape = concat_weights.size()
-    np = ctx.target.config.num_attention_heads
+    num_heads = ctx.target.config.num_attention_heads
     # transpose weights
     # [sequence length, batch size, num_splits_model_parallel * attention head size * #attention heads]
     # --> [sequence length, batch size, attention head size * num_splits_model_parallel * #attention heads]
-    concat_weights = concat_weights.view(3, np, -1, query.size()[-1])
+    concat_weights = concat_weights.view(3, num_heads, -1, query.size()[-1])
     concat_weights = concat_weights.transpose(0, 1).contiguous()
     concat_weights = concat_weights.view(*input_shape)
     return concat_weights
```
```diff
@@ -103,16 +171,78 @@ def _pack_qkv_bias(ctx: io.TransformCTX, query, key, value):
     """Pack separate query/key/value projection biases into TE's fused QKV bias."""
     concat_biases = torch.cat((query, key, value), dim=0)
     input_shape = concat_biases.size()
-    np = ctx.target.config.num_attention_heads
+    num_heads = ctx.target.config.num_attention_heads
     # transpose biases
     # [num_splits_model_parallel * attention head size * #attention heads]
     # --> [attention head size * num_splits_model_parallel * #attention heads]
-    concat_biases = concat_biases.view(3, np, -1)
+    concat_biases = concat_biases.view(3, num_heads, -1)
     concat_biases = concat_biases.transpose(0, 1).contiguous()
     concat_biases = concat_biases.view(*input_shape)
     return concat_biases


+@io.state_transform(
+    source_key="esm.encoder.layers.*.self_attention.layernorm_qkv.weight",
+    target_key=(
+        "esm.encoder.layer.*.attention.self.query.weight",
+        "esm.encoder.layer.*.attention.self.key.weight",
+        "esm.encoder.layer.*.attention.self.value.weight",
+    ),
+)
+def _unpack_qkv_weight(ctx: io.TransformCTX, qkv_weight):
+    """Unpack fused QKV weights into separate [hidden_size, input_dim] tensors for query/key/value."""
+    num_heads = ctx.source.config.num_attention_heads
+    total_rows, input_dim = qkv_weight.size()  # size: [num_heads * 3 * head_dim, input_dim]
+    assert total_rows % (3 * num_heads) == 0, (
+        f"QKV weight rows {total_rows} not divisible by 3*num_heads {3 * num_heads}"
+    )
+    head_dim = total_rows // (3 * num_heads)
+
+    qkv_weight = (
+        qkv_weight.view(num_heads, 3, head_dim, input_dim).transpose(0, 1).contiguous()
+    )  # size: [3, num_heads, head_dim, input_dim]
+    query, key, value = qkv_weight[0], qkv_weight[1], qkv_weight[2]  # size: [num_heads, head_dim, input_dim]
+
+    query = query.reshape(-1, input_dim)  # size: [num_heads * head_dim, input_dim]
+    key = key.reshape(-1, input_dim)  # size: [num_heads * head_dim, input_dim]
+    value = value.reshape(-1, input_dim)  # size: [num_heads * head_dim, input_dim]
+
+    return query, key, value
+
+
+@io.state_transform(
+    source_key="esm.encoder.layers.*.self_attention.layernorm_qkv.bias",
+    target_key=(
+        "esm.encoder.layer.*.attention.self.query.bias",
+        "esm.encoder.layer.*.attention.self.key.bias",
+        "esm.encoder.layer.*.attention.self.value.bias",
+    ),
+)
+def _unpack_qkv_bias(ctx: io.TransformCTX, qkv_bias):
+    """Unpack fused QKV biases into separate [hidden_size] tensors for query/key/value."""
+    num_heads = ctx.source.config.num_attention_heads
+    total_size = qkv_bias.size(0)  # size: [num_heads * 3 * head_dim]
+    assert total_size % (3 * num_heads) == 0, (
+        f"QKV bias size {total_size} not divisible by 3*num_heads {3 * num_heads}"
+    )
+    head_dim = total_size // (3 * num_heads)
+
+    qkv_bias = qkv_bias.view(num_heads, 3, head_dim).transpose(0, 1).contiguous()  # size: [3, num_heads, head_dim]
+    query, key, value = qkv_bias[0], qkv_bias[1], qkv_bias[2]  # size: [num_heads, head_dim]
+
+    query = query.reshape(-1)  # size: [num_heads * head_dim]
+    key = key.reshape(-1)  # size: [num_heads * head_dim]
+    value = value.reshape(-1)  # size: [num_heads * head_dim]
+
+    return query, key, value
+
+
+def _unpad_weights(ctx: io.TransformCTX, padded_embed):
+    """Remove padding from the embedding layer to get back to the original dimension."""
+    target_embedding_dimension = ctx.target.config.vocab_size
+    return padded_embed[:target_embedding_dimension]
+
+
 def _pad_weights(ctx: io.TransformCTX, source_embed):
     """Pad the embedding layer to the new input dimension."""
     target_embedding_dimension = ctx.target.config.padded_vocab_size
```
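The pack/unpack pair can be sanity-checked on dummy tensors. A minimal standalone sketch (shapes are arbitrary) showing that `_unpack_qkv_weight`'s view/transpose sequence inverts `_pack_qkv_weight`'s head-interleaved fused layout:

```python
import torch

num_heads, head_dim, input_dim = 4, 8, 16  # illustrative sizes
q = torch.randn(num_heads * head_dim, input_dim)
k = torch.randn(num_heads * head_dim, input_dim)
v = torch.randn(num_heads * head_dim, input_dim)

# Pack as in _pack_qkv_weight: concatenate, then interleave Q/K/V per head.
fused = torch.cat((q, k, v), dim=0)
fused = fused.view(3, num_heads, head_dim, input_dim).transpose(0, 1).contiguous()
fused = fused.view(3 * num_heads * head_dim, input_dim)

# Unpack as in _unpack_qkv_weight: undo the per-head interleaving.
unpacked = fused.view(num_heads, 3, head_dim, input_dim).transpose(0, 1).contiguous()
q2, k2, v2 = (t.reshape(-1, input_dim) for t in unpacked)

assert torch.equal(q2, q) and torch.equal(k2, k) and torch.equal(v2, v)
```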
```diff
@@ -134,6 +264,16 @@ def _pad_weights(ctx: io.TransformCTX, source_embed):
     target_key="lm_head.decoder.weight",
 )(_pad_weights)

+_unpad_embeddings = io.state_transform(
+    source_key="esm.embeddings.word_embeddings.weight",
+    target_key="esm.embeddings.word_embeddings.weight",
+)(_unpad_weights)
+
+_unpad_decoder_weights = io.state_transform(
+    source_key="lm_head.decoder.weight",
+    target_key="lm_head.decoder.weight",
+)(_unpad_weights)
+

 @io.state_transform(
     source_key="lm_head.bias",
```
```diff
@@ -148,3 +288,13 @@ def _pad_bias(ctx: io.TransformCTX, source_bias):
     )
     output_bias[:hf_embedding_dimension] = source_bias
     return output_bias
+
+
+@io.state_transform(
+    source_key="lm_head.decoder.bias",
+    target_key="lm_head.bias",
+)
+def _unpad_bias(ctx: io.TransformCTX, padded_bias):
+    """Remove padding from the bias to get back to the original dimension."""
+    target_embedding_dimension = ctx.target.config.vocab_size
+    return padded_bias[:target_embedding_dimension]
```
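The pad/unpad transforms are likewise inverses along the vocabulary dimension: padding grows the table to the TE config's `padded_vocab_size`, and `_unpad_weights`/`_unpad_bias` slice back to `vocab_size`. A standalone sketch (sizes are illustrative; zero-fill stands in for whatever initialization `_pad_weights` applies to the new rows):

```python
import torch

vocab_size, padded_vocab_size, hidden = 33, 64, 16  # illustrative sizes

# Pad: copy the original rows into a larger table, as _pad_weights does
# using ctx.target.config.padded_vocab_size.
embed = torch.randn(vocab_size, hidden)
padded = torch.zeros(padded_vocab_size, hidden)
padded[:vocab_size] = embed

# Unpad: slice back to ctx.target.config.vocab_size, as _unpad_weights does.
assert torch.equal(padded[:vocab_size], embed)
```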
