Commit 1ae584e

Fix ESM2 README
1 parent c59c7ba commit 1ae584e

File tree

5 files changed: 138 additions, 42 deletions

bionemo-recipes/models/esm2/.dockerignore

Lines changed: 2 additions & 1 deletion

@@ -1,3 +1,4 @@
 Dockerfile
 README.md
-checkpoint_export/
+hf_to_te_checkpoint_export/
+te_to_hf_checkpoint_export/

bionemo-recipes/models/esm2/README.md

Lines changed: 96 additions & 5 deletions
@@ -16,7 +16,7 @@ The ESM-2 implementation natively supports the following TransformerEngine-provided
 | **Sequence Packing / THD input format** | ✅ Supported |
 | **FP8 with THD input format** | ✅ Supported where FP8 is supported |
 | **Import from HuggingFace checkpoints** | ✅ Supported |
-| **Export to HuggingFace checkpoints** | 🚧 Under development |
+| **Export to HuggingFace checkpoints** | Under development |
 
 See [BioNemo Recipes](../../recipes/README.md) for more details on how to use these features to accelerate model
 training and inference.
@@ -77,17 +77,108 @@ Training recipes are available in the `bionemo-recipes/recipes/` directory:
 Generate converted ESM-2 checkpoints from existing HuggingFace transformers checkpoints:
 
 ```bash
-mkdir -p checkpoint_export
+mkdir -p hf_to_te_checkpoint_export
 docker build -t esm2 .
 docker run --rm -it --gpus all \
-  -v $PWD/checkpoint_export/:/workspace/bionemo/checkpoint_export \
+  -v $PWD/hf_to_te_checkpoint_export/:/workspace/bionemo/hf_to_te_checkpoint_export \
   -v $HOME/.cache/huggingface/:/root/.cache/huggingface \
-  esm2 python export.py
+  esm2 python export.py hf-to-te
 ```
 
 ### TE to HF Transformers conversion
 
-(Coming soon)
+```bash
+MODEL_TAG=esm2_t6_8M_UR50D  # specify which model to convert
+mkdir -p te_to_hf_checkpoint_export
+docker build -t esm2 .
+docker run --rm -it --gpus all \
+  -v $PWD/te_to_hf_checkpoint_export/:/workspace/bionemo/te_to_hf_checkpoint_export \
+  -v $PWD/hf_to_te_checkpoint_export/$MODEL_TAG:/workspace/bionemo/hf_to_te_checkpoint_export/$MODEL_TAG \
+  -v $HOME/.cache/huggingface/:/root/.cache/huggingface \
+  esm2 python export.py te-to-hf --checkpoint-path /workspace/bionemo/hf_to_te_checkpoint_export/$MODEL_TAG
+```
+
+## Developer Conversion Workflow
+
+This section explains how to convert between Hugging Face and Transformer Engine (TE) ESM2 model formats. The process demonstrates bidirectional conversion: from Hugging Face to TE format for optimized inference, and back to Hugging Face format for sharing and deployment. The workflow involves several key steps:
+
+### Step 1: Load Original Hugging Face Model
+
+First, load the original ESM2 model from Hugging Face:
+
+```python
+from transformers import AutoModelForMaskedLM
+
+model_hf_original = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
+```
+
+This loads the pre-trained ESM2 model that will serve as our reference for comparison.
+
+### Step 2: Export to Transformer Engine Format
+
+Convert the Hugging Face model to Transformer Engine format using the high-level export API:
+
+```python
+from pathlib import Path
+from esm.export import export_hf_checkpoint
+
+te_checkpoint_path = Path("te_checkpoint")
+export_hf_checkpoint("esm2_t6_8M_UR50D", te_checkpoint_path)
+```
+
+This creates a Transformer Engine checkpoint that can be used for optimized inference.
+
+### Step 3: Export Back to Hugging Face Format
+
+Convert the Transformer Engine checkpoint back to Hugging Face format:
+
+```python
+from esm.export import export_te_checkpoint
+
+hf_export_path = Path("hf_export")
+exported_model_path = te_checkpoint_path / "esm2_t6_8M_UR50D"
+export_te_checkpoint(str(exported_model_path), str(hf_export_path))
+```
+
+This step creates a new Hugging Face model that should be functionally equivalent to the original.
+
+### Step 4: Load and Test the Exported Model
+
+Load the exported model and the tokenizer used in the validation step:
+
+```python
+from transformers import AutoTokenizer
+model_hf_exported = AutoModelForMaskedLM.from_pretrained(str(hf_export_path))
+tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
+```
+
+### Step 5: Validate Model Equivalence
+
+Test the exported model against the original using masked language modeling:
+
+```python
+import torch
+from transformers import DataCollatorForLanguageModeling
+
+# Prepare test sequence
+sequence = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"
+batch = tokenizer([sequence], return_tensors="pt")
+collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
+inputs = collator([{"input_ids": batch["input_ids"][0]}])
+
+# Compare outputs
+with torch.no_grad():
+    outputs_original = model_hf_original(**inputs)
+    outputs_exported = model_hf_exported(**inputs)
+
+# Check differences
+logits_diff = torch.abs(outputs_original.logits - outputs_exported.logits).max()
+print(f"Max logits difference: {logits_diff:.2e}")
+
+if outputs_original.loss is not None and outputs_exported.loss is not None:
+    loss_diff = abs(outputs_original.loss - outputs_exported.loss)
+    print(f"Loss difference: {loss_diff:.2e}")
+```
 
 ## Developer Guide
 

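Editor's note: as a complementary check to the logits comparison in Step 5 of the workflow above, the two models' state dicts can be compared tensor by tensor, which is essentially what the test added in `tests/test_export.py` (below) does. A minimal sketch, assuming the `model_hf_original` and `model_hf_exported` objects from the workflow:

```python
import torch

# Tensor-level comparison between the original and round-tripped models.
# TE bookkeeping buffers (*_extra_state), rotary inv_freq buffers, and the
# contact head are skipped, mirroring tests/test_export.py in this commit.
original_state_dict = model_hf_original.state_dict()
exported_state_dict = model_hf_exported.state_dict()

for key in original_state_dict:
    if key.endswith("_extra_state") or key.endswith("inv_freq") or "contact_head" in key:
        continue
    torch.testing.assert_close(
        original_state_dict[key], exported_state_dict[key], atol=1e-5, rtol=1e-5
    )
print("All compared parameters match within tolerance.")
```
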
bionemo-recipes/models/esm2/src/esm/export.py

Lines changed: 1 addition & 1 deletion
@@ -155,7 +155,7 @@ def export_te_checkpoint(te_checkpoint_path: str, output_path: str):
 
     model_hf = AutoModelForMaskedLM.from_pretrained(
         output_path,
-        torch_dtype=torch.bfloat16,
+        dtype=torch.bfloat16,
         trust_remote_code=False,
     )
     del model_hf

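Editor's note on the one-line change above: newer `transformers` releases accept `dtype=` in `from_pretrained` and deprecate the older `torch_dtype=` spelling. A minimal standalone sketch of the reload step this line belongs to, assuming the exported checkpoint lives in a local `hf_export` directory (an illustrative path, not one defined by this commit):

```python
import torch
from transformers import AutoModelForMaskedLM

# Reload the freshly exported checkpoint in bfloat16 to confirm it deserializes
# cleanly; "hf_export" is an assumed output directory used for illustration.
model_hf = AutoModelForMaskedLM.from_pretrained(
    "hf_export",
    dtype=torch.bfloat16,
    trust_remote_code=False,
)
print(next(model_hf.parameters()).dtype)  # expect torch.bfloat16
del model_hf  # the reload is only a smoke test
```
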
bionemo-recipes/models/esm2/tests/test_convert_reverse.py renamed to bionemo-recipes/models/esm2/tests/test_convert.py

Lines changed: 0 additions & 35 deletions
@@ -13,10 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import tempfile
-from pathlib import Path
 
-import pytest
 import torch
 from transformers import AutoModelForMaskedLM
 
@@ -41,38 +38,6 @@ def test_convert_te_to_hf_roundtrip():
        torch.testing.assert_close(original_state_dict[key], converted_state_dict[key], atol=1e-5, rtol=1e-5)
 
 
-@pytest.mark.parametrize("model_name", ["esm2_t6_8M_UR50D"])
-def test_export_te_checkpoint_to_hf(model_name):
-    """Test the export function that saves TE checkpoint as HF format."""
-    from esm.export import export_hf_checkpoint, export_te_checkpoint
-
-    with tempfile.TemporaryDirectory() as temp_dir:
-        temp_path = Path(temp_dir)
-
-        model_hf_original = AutoModelForMaskedLM.from_pretrained(f"facebook/{model_name}")
-
-        # Use export_hf_checkpoint to create TE checkpoint
-        te_checkpoint_path = temp_path / "te_checkpoint"
-        export_hf_checkpoint(model_name, te_checkpoint_path)
-        te_model_path = te_checkpoint_path / model_name
-
-        hf_export_path = temp_path / "hf_export"
-        export_te_checkpoint(str(te_model_path), str(hf_export_path))
-
-        model_hf_exported = AutoModelForMaskedLM.from_pretrained(str(hf_export_path))
-
-        original_state_dict = model_hf_original.state_dict()
-        exported_state_dict = model_hf_exported.state_dict()
-
-        # assert original_state_dict.keys() == exported_state_dict.keys()
-        original_keys = {k for k in original_state_dict.keys() if "contact_head" not in k}
-        exported_keys = {k for k in exported_state_dict.keys() if "contact_head" not in k}
-        assert original_keys == exported_keys
-
-        for key in original_state_dict.keys():
-            if not key.endswith("_extra_state") and not key.endswith("inv_freq") and "contact_head" not in key:
-                torch.testing.assert_close(original_state_dict[key], exported_state_dict[key], atol=1e-5, rtol=1e-5)
-
 
 def test_qkv_unpacking():
     """Test that QKV unpacking works correctly."""

bionemo-recipes/models/esm2/tests/test_export.py

Lines changed: 39 additions & 0 deletions
@@ -14,6 +14,13 @@
 # limitations under the License.
 
 
+from pathlib import Path
+import pytest
+import tempfile
+import torch
+from transformers import AutoModelForMaskedLM
+
+
 def test_export_hf_checkpoint(tmp_path):
     from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer
 
@@ -63,3 +70,35 @@ def test_export_hf_checkpoint(tmp_path):
     assert "**Benchmark Score:** 0.37" in readme_contents, (
         f"README.md does not contain the expected CASP14 score line: {readme_contents}"
     )
+
+@pytest.mark.parametrize("model_name", ["esm2_t6_8M_UR50D"])
+def test_export_te_checkpoint_to_hf(model_name):
+    """Test the export function that saves TE checkpoint as HF format."""
+    from esm.export import export_hf_checkpoint, export_te_checkpoint
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        temp_path = Path(temp_dir)
+
+        model_hf_original = AutoModelForMaskedLM.from_pretrained(f"facebook/{model_name}")
+
+        # Use export_hf_checkpoint to create TE checkpoint
+        te_checkpoint_path = temp_path / "te_checkpoint"
+        export_hf_checkpoint(model_name, te_checkpoint_path)
+        te_model_path = te_checkpoint_path / model_name
+
+        hf_export_path = temp_path / "hf_export"
+        export_te_checkpoint(str(te_model_path), str(hf_export_path))
+
+        model_hf_exported = AutoModelForMaskedLM.from_pretrained(str(hf_export_path))
+
+        original_state_dict = model_hf_original.state_dict()
+        exported_state_dict = model_hf_exported.state_dict()
+
+        # assert original_state_dict.keys() == exported_state_dict.keys()
+        original_keys = {k for k in original_state_dict.keys() if "contact_head" not in k}
+        exported_keys = {k for k in exported_state_dict.keys() if "contact_head" not in k}
+        assert original_keys == exported_keys
+
+        for key in original_state_dict.keys():
+            if not key.endswith("_extra_state") and not key.endswith("inv_freq") and "contact_head" not in key:
+                torch.testing.assert_close(original_state_dict[key], exported_state_dict[key], atol=1e-5, rtol=1e-5)

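Editor's note: beyond the logits and state-dict checks, the exported checkpoint can be exercised directly for masked-token prediction. A purely illustrative sketch, assuming the `hf_export` directory produced in Step 3 of the workflow above and the same test sequence:

```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

# Load the exported checkpoint and the matching ESM-2 tokenizer.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = AutoModelForMaskedLM.from_pretrained("hf_export")
model.eval()

# Mask a single residue in the test sequence and predict it.
sequence = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"
masked = sequence[:10] + tokenizer.mask_token + sequence[11:]
inputs = tokenizer(masked, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

mask_pos = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
top_ids = torch.topk(logits[0, mask_pos], k=5, dim=-1).indices[0]
print("True residue:", sequence[10])
print("Top-5 predictions:", tokenizer.convert_ids_to_tokens(top_ids.tolist()))
```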