Skip to content

Commit 2cf12c6

Browse files
committed
review fix
Signed-off-by: Ohad Mosafi <[email protected]>
1 parent 05fbec7 commit 2cf12c6

File tree

3 files changed

+8
-8
lines changed

3 files changed

+8
-8
lines changed

bionemo-recipes/models/esm2/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ The ESM-2 implementation natively supports the following TransformerEngine-provi
1616
| **Sequence Packing / THD input format** | ✅ Supported |
1717
| **FP8 with THD input format** | ✅ Supported where FP8 is supported |
1818
| **Import from HuggingFace checkpoints** | ✅ Supported |
19-
| **Export to HuggingFace checkpoints** |Under development |
19+
| **Export to HuggingFace checkpoints** | ✅ Supported |
2020

2121
See [BioNemo Recipes](../../recipes/README.md) for more details on how to use these features to accelerate model
2222
training and inference.
@@ -137,7 +137,7 @@ from esm.export import export_te_checkpoint
137137

138138
hf_export_path = Path("hf_export")
139139
exported_model_path = te_checkpoint_path / "esm2_t6_8M_UR50D"
140-
export_te_checkpoint(str(exported_model_path), str(hf_export_path))
140+
export_te_checkpoint(str(exported_model_path), hf_export_path)
141141
```
142142

143143
This step creates a new Hugging Face model that should be functionally equivalent to the original.

bionemo-recipes/models/esm2/src/esm/export.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def export_hf_checkpoint(tag: str, export_path: Path):
119119
torch.cuda.empty_cache()
120120

121121

122-
def export_te_checkpoint(te_checkpoint_path: str, output_path: str):
122+
def export_te_checkpoint(te_checkpoint_path: str, output_path: Path):
123123
"""Export a Transformer Engine checkpoint back to the original HuggingFace Facebook ESM-2 format.
124124
125125
This function converts from the NVIDIA Transformer Engine (TE) format back to the
@@ -129,7 +129,7 @@ def export_te_checkpoint(te_checkpoint_path: str, output_path: str):
129129
130130
Args:
131131
te_checkpoint_path (str): Path to the TE checkpoint
132-
output_path (str): Output path for the converted Facebook ESM-2 format model
132+
output_path (Path): Output path for the converted Facebook ESM-2 format model
133133
"""
134134
if not Path(te_checkpoint_path).exists():
135135
raise FileNotFoundError(f"TE checkpoint {te_checkpoint_path} not found")
@@ -143,15 +143,15 @@ def export_te_checkpoint(te_checkpoint_path: str, output_path: str):
143143

144144
tokenizer_config_path = Path(te_checkpoint_path) / "tokenizer_config.json"
145145
if tokenizer_config_path.exists():
146-
shutil.copy(tokenizer_config_path, Path(output_path) / "tokenizer_config.json")
146+
shutil.copy(tokenizer_config_path, output_path / "tokenizer_config.json")
147147

148148
vocab_path = Path(te_checkpoint_path) / "vocab.txt"
149149
if vocab_path.exists():
150-
shutil.copy(vocab_path, Path(output_path) / "vocab.txt")
150+
shutil.copy(vocab_path, output_path / "vocab.txt")
151151

152152
special_tokens_path = Path(te_checkpoint_path) / "special_tokens_map.json"
153153
if special_tokens_path.exists():
154-
shutil.copy(special_tokens_path, Path(output_path) / "special_tokens_map.json")
154+
shutil.copy(special_tokens_path, output_path / "special_tokens_map.json")
155155

156156
model_hf = AutoModelForMaskedLM.from_pretrained(
157157
output_path,

bionemo-recipes/models/esm2/tests/test_export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def test_export_te_checkpoint_to_hf(model_name):
8787
te_model_path = te_checkpoint_path / model_name
8888

8989
hf_export_path = temp_path / "hf_export"
90-
export_te_checkpoint(str(te_model_path), str(hf_export_path))
90+
export_te_checkpoint(str(te_model_path), hf_export_path)
9191

9292
model_hf_exported = AutoModelForMaskedLM.from_pretrained(str(hf_export_path))
9393

0 commit comments

Comments
 (0)