The ESM-2 implementation natively supports the following TransformerEngine-provided features:
| **Sequence Packing / THD input format** | ✅ Supported |
| **FP8 with THD input format** | ✅ Supported where FP8 is supported |
| **Import from HuggingFace checkpoints** | ✅ Supported |
| **Export to HuggingFace checkpoints** | ✅ Supported |

See [BioNemo Recipes](../../recipes/README.md) for more details on how to use these features to accelerate model training and inference.
### HF Transformers to TE conversion

Generate converted ESM-2 checkpoints from existing HuggingFace transformers checkpoints:

```bash
# Build the image, mount an output directory for the converted checkpoint,
# and reuse the local HuggingFace cache inside the container.
mkdir -p hf_to_te_checkpoint_export
docker build -t esm2 .
docker run --rm -it --gpus all \
  -v $PWD/hf_to_te_checkpoint_export/:/workspace/bionemo/hf_to_te_checkpoint_export \
  -v $HOME/.cache/huggingface/:/root/.cache/huggingface \
  esm2 python export.py hf-to-te
```
### TE to HF Transformers conversion

Convert a previously exported TE checkpoint back into a HuggingFace transformers checkpoint:

```bash
MODEL_TAG=esm2_t6_8M_UR50D  # specify which model to convert
mkdir -p te_to_hf_checkpoint_export
docker build -t esm2 .
docker run --rm -it --gpus all \
  -v $PWD/te_to_hf_checkpoint_export/:/workspace/bionemo/te_to_hf_checkpoint_export \
  -v $PWD/hf_to_te_checkpoint_export/$MODEL_TAG:/workspace/bionemo/hf_to_te_checkpoint_export/$MODEL_TAG \
  -v $HOME/.cache/huggingface/:/root/.cache/huggingface \
  esm2 python export.py te-to-hf --checkpoint-path /workspace/bionemo/hf_to_te_checkpoint_export/$MODEL_TAG
```
## Developer Conversion Workflow

This section walks through the bidirectional conversion between Hugging Face and Transformer Engine (TE) ESM-2 model formats: from Hugging Face to TE format for optimized inference, and back to Hugging Face format for sharing and deployment. The workflow involves the following steps:
### Step 1: Load Original Hugging Face Model

First, load the original ESM-2 model from Hugging Face:

```python
from transformers import AutoModelForMaskedLM

model_hf_original = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
```

This loads the pre-trained ESM-2 model that will serve as our reference for comparison.
### Step 2: Export to Transformer Engine Format

Convert the Hugging Face model to Transformer Engine format using the high-level export API:

```python
from pathlib import Path
from esm.export import export_hf_checkpoint

te_checkpoint_path = Path("te_checkpoint")
export_hf_checkpoint("esm2_t6_8M_UR50D", te_checkpoint_path)
```

This creates a Transformer Engine checkpoint that can be used for optimized inference.
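As a quick sanity check, the exported directory can be loaded like any other local HuggingFace checkpoint. The snippet below is a minimal sketch rather than part of the export API: it assumes the checkpoint lands at `te_checkpoint/esm2_t6_8M_UR50D` (the same layout Step 3 relies on) and that the exported directory bundles the TE modeling code and tokenizer files, so `trust_remote_code=True` is passed at load time.

```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

# Assumed output layout; Step 3 below points at the same subdirectory.
te_model_dir = te_checkpoint_path / "esm2_t6_8M_UR50D"

# trust_remote_code=True lets transformers import the TE model classes assumed
# to be bundled with the exported checkpoint; TE layers expect a GPU.
model_te = AutoModelForMaskedLM.from_pretrained(te_model_dir, trust_remote_code=True).cuda().eval()
tokenizer_te = AutoTokenizer.from_pretrained(te_model_dir)

sequence = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"
inputs = tokenizer_te(sequence, return_tensors="pt").to("cuda")

with torch.no_grad():
    logits = model_te(**inputs).logits

print(logits.shape)  # (1, sequence_length, vocab_size)
```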
### Step 3: Export Back to Hugging Face Format

Convert the Transformer Engine checkpoint back to Hugging Face format:

```python
from esm.export import export_te_checkpoint

hf_export_path = Path("hf_export")
exported_model_path = te_checkpoint_path / "esm2_t6_8M_UR50D"
export_te_checkpoint(str(exported_model_path), str(hf_export_path))
```

This step creates a new Hugging Face model that should be functionally equivalent to the original.
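Because the re-exported checkpoint is a plain Hugging Face model directory, it can be shared like any other transformers checkpoint. The sketch below uses the standard `huggingface_hub` client; the repository id is a placeholder, and an authenticated login (e.g. `huggingface-cli login`) is assumed.

```python
from huggingface_hub import HfApi

api = HfApi()

# Placeholder repository id; replace with your own namespace.
repo_id = "your-org/esm2_t6_8M_UR50D-te-roundtrip"
api.create_repo(repo_id, exist_ok=True)

# Upload the exported directory (config, weights, tokenizer files) as-is.
api.upload_folder(folder_path=str(hf_export_path), repo_id=repo_id)
```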
### Step 4: Load and Test the Exported Model

Load the exported model, along with the tokenizer that will be used for validation:

```python
from transformers import AutoTokenizer

model_hf_exported = AutoModelForMaskedLM.from_pretrained(str(hf_export_path))
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
```
### Step 5: Validate Model Equivalence

Test the exported model against the original using masked language modeling:

```python
import torch
from transformers import DataCollatorForLanguageModeling

# Prepare test sequence
sequence = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"
batch = tokenizer([sequence], return_tensors="pt")
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
inputs = collator([{"input_ids": batch["input_ids"][0]}])

# Compare outputs
with torch.no_grad():
    outputs_original = model_hf_original(**inputs)
    outputs_exported = model_hf_exported(**inputs)

# Check differences
logits_diff = torch.abs(outputs_original.logits - outputs_exported.logits).max()
print(f"Max logits difference: {logits_diff:.2e}")

if outputs_original.loss is not None and outputs_exported.loss is not None:
    loss_diff = abs(outputs_original.loss - outputs_exported.loss)
    print(f"Loss difference: {loss_diff:.2e}")
```
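For an automated pass/fail check rather than printed differences, the same comparison can be wrapped in `torch.testing.assert_close`. The tolerances below are illustrative assumptions, not values prescribed by the export code; an appropriate threshold depends on dtype and hardware.

```python
# Raises AssertionError if the round-tripped model drifts beyond the tolerances.
torch.testing.assert_close(
    outputs_exported.logits,
    outputs_original.logits,
    atol=1e-4,  # assumed absolute tolerance
    rtol=1e-4,  # assumed relative tolerance
)
print("Exported model matches the original within tolerance.")
```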
## Developer Guide