ESM2: TransformerEngine to HuggingFace checkpoint support #1121
base: main
Conversation
Walkthrough: Adds a CLI with `hf-to-te` and `te-to-hf` subcommands for converting ESM-2 checkpoints between the HuggingFace and Transformer Engine formats.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    actor User
    participant CLI as export.py (CLI)
    participant HF as HuggingFace
    participant TE as TransformerEngine
    participant FS as Filesystem
    rect #F0F8FF
    note over CLI,HF: hf-to-te flow
    User->>CLI: hf-to-te [--model|-m]* --output-path
    CLI->>HF: load HF model(s)
    CLI->>TE: export_hf_checkpoint (HF → TE)
    TE->>FS: save TE checkpoint(s)
    CLI-->>User: status messages
    end
    rect #F7FFF0
    note over CLI,TE: te-to-hf flow
    User->>CLI: te-to-hf --checkpoint-path --output-path
    CLI->>FS: validate checkpoint path
    CLI->>TE: load TE checkpoint (NVEsmForMaskedLM)
    TE->>CLI: provide TE state & config
    CLI->>HF: convert_esm_te_to_hf (TE → HF)
    HF->>FS: save HF model + tokenizer/vocab
    CLI-->>User: status messages
    end
```
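To make the two flows above concrete, here is a minimal, hypothetical sketch of such a two-subcommand CLI. It is not the PR's actual `export.py`; the argument defaults and the exact signatures of `export_hf_checkpoint`/`export_te_checkpoint` are assumptions based on the diagram.

```python
# Hypothetical sketch of the export.py CLI described above; the real
# implementation in this PR may differ in defaults and validation.
import argparse
from pathlib import Path

from esm.export import export_hf_checkpoint, export_te_checkpoint  # provided by this PR


def main():
    parser = argparse.ArgumentParser(description="Convert ESM-2 checkpoints between HF and TE formats")
    subparsers = parser.add_subparsers(dest="command", required=True)

    hf_to_te = subparsers.add_parser("hf-to-te", help="HuggingFace -> Transformer Engine")
    hf_to_te.add_argument("--model", "-m", help="ESM-2 model tag, e.g. esm2_t6_8M_UR50D")
    hf_to_te.add_argument("--output-path", default="hf_to_te_checkpoint_export")

    te_to_hf = subparsers.add_parser("te-to-hf", help="Transformer Engine -> HuggingFace")
    te_to_hf.add_argument("--checkpoint-path", required=True)
    te_to_hf.add_argument("--output-path", default="te_to_hf_checkpoint_export")

    args = parser.parse_args()
    if args.command == "hf-to-te":
        export_hf_checkpoint(args.model, Path(args.output_path))
    else:
        export_te_checkpoint(args.checkpoint_path, Path(args.output_path))


if __name__ == "__main__":
    main()
```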
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (2 warnings), ✅ Passed checks (1 passed)
📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (3)

🧰 Additional context used
🧬 Code graph analysis (2)
- bionemo-recipes/models/esm2/tests/test_export.py (1)
  - bionemo-recipes/models/esm2/src/esm/export.py (2)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
Force-pushed cfb5cae to f2f2247 (compare)
Actionable comments posted: 7
🧹 Nitpick comments (16)
models/esm2/export.py (3)
34-41: Drop redundant validation; argparse choices already enforce valid tags
`choices=ESM_TAGS` guarantees validity, so the manual check below is unreachable and can be removed.

```diff
-    if args.model:
-        if args.model not in ESM_TAGS:
-            print(f"Error: '{args.model}' is not a valid model tag.\nAvailable models: {', '.join(ESM_TAGS)}")
-            return
-        export_hf_checkpoint(args.model, Path(args.output_path))
+    if args.model:
+        export_hf_checkpoint(args.model, Path(args.output_path))
```
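For reference, a small standalone snippet illustrating why the manual check is unreachable: argparse rejects values outside `choices` and exits before `main` ever sees them (a self-contained sketch, not code from the PR; the tags listed are examples).

```python
# Standalone illustration: argparse enforces `choices` itself and exits with an
# error before any hand-written validation could run.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model", choices=["esm2_t6_8M_UR50D", "esm2_t12_35M_UR50D"])

args = parser.parse_args(["--model", "esm2_t6_8M_UR50D"])
print(args.model)  # esm2_t6_8M_UR50D

try:
    parser.parse_args(["--model", "not_a_real_tag"])
except SystemExit:
    print("argparse rejected the invalid tag before main() could check it")
```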
42-47: Create the output directory up front for clearer failures

Fail fast on permission/path issues rather than during export; harmless if it already exists.

```diff
     args = parser.parse_args()
+    # Ensure base output path exists
+    Path(args.output_path).mkdir(parents=True, exist_ok=True)
```

Also applies to: 49-49
58-59: Prefer logging over prints

Consider `logging.info("Converting %s...", tag)` for consistency and easier diagnostics. No functional change required.

models/esm2/export_te_checkpoint_to_hf.py (2)
41-43: Tighten CLI description

Minor copy edit ("HuggingFace Facebook … hosted on Hugging Face" is redundant).

```diff
-        description="Convert ESM2 models from Transformer Engine format back to HuggingFace Facebook ESM-2 format hosted on Hugging Face"
+        description="Convert ESM2 models from Transformer Engine (TE) format back to the original Hugging Face Facebook ESM-2 format"
```
22-38: Consider using logging instead of print and returning non-zero on failure

Replace prints with `logging` and `sys.exit(1)` in the except path to signal failure in scripts/CI.

models/esm2/src/esm/export.py (3)
113-121: Preserve `architectures` in config for better HF UX

Setting `"architectures": ["EsmForMaskedLM"]` helps downstream loading tools and Hub metadata.

```diff
     if config_path.exists():
         with open(config_path, "r") as f:
             config = json.load(f)
         config.pop("auto_map", None)
         config["model_type"] = "esm"
+        config.setdefault("architectures", ["EsmForMaskedLM"])
         with open(config_path, "w") as f:
             json.dump(config, f, indent=2, sort_keys=True)
```
123-130: Make the smoke test optional to avoid OOMs on large checkpoints (e.g., 3B/15B)

Loading the full model in bf16 can be prohibitive in limited environments. Gate behind a flag with a safe default.

```diff
-def export_te_checkpoint(te_checkpoint_path: str, output_path: str):
+def export_te_checkpoint(te_checkpoint_path: str, output_path: str, *, smoke_test: bool = False):
 @@
-    model_hf = AutoModelForMaskedLM.from_pretrained(
-        output_path,
-        torch_dtype=torch.bfloat16,
-        trust_remote_code=False,
-    )
-    del model_hf
-    gc.collect()
-    torch.cuda.empty_cache()
+    if smoke_test:
+        model_hf = AutoModelForMaskedLM.from_pretrained(
+            output_path,
+            torch_dtype=torch.bfloat16,
+            trust_remote_code=False,
+        )
+        del model_hf
+        gc.collect()
+        torch.cuda.empty_cache()
```

Callers (e.g., the new CLI) can pass `smoke_test=True` when resources allow.
100-111: Minor DRY: copy optional tokenizer artifacts via a loop

Reduces repetition; behavior unchanged.

```diff
-    tokenizer_path = Path(te_checkpoint_path) / "tokenizer.json"
-    if tokenizer_path.exists():
-        shutil.copy(tokenizer_path, Path(output_path) / "tokenizer.json")
-
-    tokenizer_config_path = Path(te_checkpoint_path) / "tokenizer_config.json"
-    if tokenizer_config_path.exists():
-        shutil.copy(tokenizer_config_path, Path(output_path) / "tokenizer_config.json")
-
-    vocab_path = Path(te_checkpoint_path) / "vocab.txt"
-    if vocab_path.exists():
-        shutil.copy(vocab_path, Path(output_path) / "vocab.txt")
+    for fname in ("tokenizer.json", "tokenizer_config.json", "vocab.txt"):
+        src = Path(te_checkpoint_path) / fname
+        if src.exists():
+            shutil.copy(src, Path(output_path) / fname)
```

models/esm2/README.md (1)
53-58: Clarify exported path handling.

Consider documenting that export_hf_checkpoint writes to te_checkpoint_path / so users can substitute different tags without editing the path literal. A short note avoids confusion.
models/esm2/src/esm/convert.py (3)
115-118: Avoid mutating model structure unless necessary.

Deleting contact_head is safe for state_dict parity, but it's not required for EsmForMaskedLM functionality. Prefer leaving the module intact and excluding its params via state_dict_ignored_entries to reduce surprises for downstream users introspecting model_hf.esm.

```diff
-    # Remove contact_head since it's not present in TE models
-    if hasattr(model_hf.esm, "contact_head"):
-        delattr(model_hf.esm, "contact_head")
+    # Keep contact_head module; its params are excluded below.
```
119-129: Consider ignoring non-parameter buffers as well.

If present, buffers like position_ids or inv_freq can be safely ignored to keep the transform strict on learnable weights only.

```diff
     state_dict_ignored_entries=[
         "lm_head.decoder.weight",
         "esm.contact_head.regression.weight",
         "esm.contact_head.regression.bias",
+        "esm.embeddings.position_ids",
+        "esm.encoder.layer.*.attention.self.rotary_emb.inv_freq",
     ],
```
148-160: Docstrings mention padding; update to describe QKV packing.

The functions pack QKV and do not pad. Update the docstrings to avoid confusion.
Would you like me to open a follow-up PR to update the docstrings and inline comments in _pack_qkv_weight/_pack_qkv_bias for accuracy?
Also applies to: 170-181
models/esm2/tests/test_convert_reverse.py (4)
25-42: Narrow the TE-layer check to encoder modules to avoid false positives.

Linear/LayerNorm may legitimately exist outside the fused TE encoder (e.g., embeddings/lm_head). Scope to ".encoder.layers." to reduce noise.

```diff
-    for name, module in model_te.named_modules():
+    for name, module in model_te.named_modules():
+        if ".encoder.layers." not in name:
+            continue
         if isinstance(module, nn.Linear):
             vanilla_layers_found.append(f"Linear layer found in {name}")
         if isinstance(module, nn.LayerNorm):
             vanilla_layers_found.append(f"LayerNorm layer found in {name}")
```
64-81: Mark export round-trip as slow/internet and guard TE import.

This test hits disk and remote hub; mark and guard to play well in CI.

```diff
-@pytest.mark.parametrize("model_name", ["esm2_t6_8M_UR50D"])
+@requires_internet
+@pytest.mark.slow
+@pytest.mark.parametrize("model_name", ["esm2_t6_8M_UR50D"])
 def test_export_te_checkpoint_to_hf(model_name):
 @@
-    from esm.export import export_hf_checkpoint, export_te_checkpoint
+    from esm.export import export_hf_checkpoint, export_te_checkpoint
+    pytest.importorskip("transformer_engine", reason="Transformer Engine required for TE checkpoint export")
```
119-132: Strengthen config assertions; ensure TE-only keys are stripped.

Enable negative assertions to lock in expected behavior from convert_esm_te_to_hf.

```diff
     assert model_hf_converted.config.vocab_size == model_hf.config.vocab_size
-
-    # assert not hasattr(model_hf_converted.config, 'qkv_weight_interleaved')
-    # assert not hasattr(model_hf_converted.config, 'encoder_activation')
-    # assert not hasattr(model_hf_converted.config, 'attn_input_format')
-    # assert not hasattr(model_hf_converted.config, 'fuse_qkv_params')
-    # assert not hasattr(model_hf_converted.config, 'micro_batch_size')
-    # assert not hasattr(model_hf_converted.config, 'max_seq_length')
+    assert not hasattr(model_hf_converted.config, "qkv_weight_interleaved")
+    assert not hasattr(model_hf_converted.config, "encoder_activation")
+    assert not hasattr(model_hf_converted.config, "attn_input_format")
+    assert not hasattr(model_hf_converted.config, "fuse_qkv_params")
+    assert not hasattr(model_hf_converted.config, "micro_batch_size")
+    assert not hasattr(model_hf_converted.config, "max_seq_length")
```
59-62: Tolerance is tight but fine for fp32; consider dtype harmonization to avoid precision flakes.

If CI flakes occur, cast both state_dicts to float32 before comparisons to avoid bfloat16/float16 noise.
I can add a helper to clone-and-cast tensors before torch.testing.assert_close; want me to send a small patch?
Also applies to: 92-95, 114-116
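A possible shape for the helper mentioned above (an illustrative sketch, not the patch itself; the function name and tolerances are made up): cast floating-point tensors to float32 before comparing.

```python
# Hedged sketch of a dtype-harmonizing state_dict comparison helper.
import torch


def assert_state_dicts_close(sd_a, sd_b, rtol=1e-5, atol=1e-5):
    """Compare two state dicts after casting floating-point tensors to float32."""
    assert sd_a.keys() == sd_b.keys()
    for key in sd_a:
        a, b = sd_a[key], sd_b[key]
        if torch.is_floating_point(a):
            a, b = a.float(), b.float()
        torch.testing.assert_close(a, b, rtol=rtol, atol=atol, msg=f"Mismatch in {key}")
```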
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (6)
- models/esm2/README.md (1 hunks)
- models/esm2/export.py (2 hunks)
- models/esm2/export_te_checkpoint_to_hf.py (1 hunks)
- models/esm2/src/esm/convert.py (4 hunks)
- models/esm2/src/esm/export.py (2 hunks)
- models/esm2/tests/test_convert_reverse.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
models/esm2/export_te_checkpoint_to_hf.py (2)
- models/esm2/src/esm/export.py (1)
  - export_te_checkpoint (76-130)
- models/esm2/export.py (1)
  - main (32-59)

models/esm2/src/esm/export.py (1)
- models/esm2/src/esm/convert.py (2)
  - convert_esm_hf_to_te (51-76)
  - convert_esm_te_to_hf (79-137)

models/esm2/tests/test_convert_reverse.py (2)
- models/esm2/src/esm/convert.py (2)
  - convert_esm_hf_to_te (51-76)
  - convert_esm_te_to_hf (79-137)
- models/esm2/src/esm/export.py (2)
  - export_hf_checkpoint (28-73)
  - export_te_checkpoint (76-130)

models/esm2/export.py (1)
- models/esm2/src/esm/export.py (1)
  - export_hf_checkpoint (28-73)
🔇 Additional comments (3)
models/esm2/src/esm/export.py (1)
24-26: Imports for reverse conversion look correct

Brings in `convert_esm_te_to_hf` and `NVEsmForMaskedLM` needed for the TE→HF path.

models/esm2/src/esm/convert.py (2)

20-20: Good addition: explicit HF types.

Importing EsmConfig/EsmForMaskedLM locally removes reliance on AutoModel classes here.

47-49: Reverse mapping is fine.

The simple inversion matches the one-to-one entries; QKV is handled via transforms.
```python
import tempfile
from pathlib import Path

import pytest
import torch
from torch import nn
from transformers import AutoModelForMaskedLM
```
🛠️ Refactor suggestion
Add imports/markers for optional dependencies and offline CI.
Guard tests that require internet/TE, and avoid surprises in offline CI.
```diff
 import tempfile
 from pathlib import Path
 import pytest
 import torch
 from torch import nn
 from transformers import AutoModelForMaskedLM
+import os
+
+# Skip if offline (HF env flags)
+OFFLINE = os.environ.get("TRANSFORMERS_OFFLINE") or os.environ.get("HF_HUB_OFFLINE")
+requires_internet = pytest.mark.skipif(OFFLINE, reason="Requires internet to download HF models")
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
import tempfile
from pathlib import Path
import pytest
import torch
from torch import nn
from transformers import AutoModelForMaskedLM
import os

# Skip if offline (HF env flags)
OFFLINE = os.environ.get("TRANSFORMERS_OFFLINE") or os.environ.get("HF_HUB_OFFLINE")
requires_internet = pytest.mark.skipif(
    OFFLINE, reason="Requires internet to download HF models"
)
```
🤖 Prompt for AI Agents
In models/esm2/tests/test_convert_reverse.py around lines 16 to 23, the test
imports assume optional dependencies (internet access, TorchEngine/TE) are
available; add guards so the test is skipped in offline CI or when optional
packages are missing. Use pytest.importorskip for optional libraries (or wrap
imports in try/except and call pytest.skip), and add appropriate
pytest.mark.skipif conditions (e.g., skip when an environment variable indicates
offline CI or when required network access is unavailable) so the test only runs
when dependencies and network are present.
could you make sure you rebase and run CI?
```python
def test_esm_model_has_all_te_layers():
    """Test that the converted TE model doesn't contain vanilla PyTorch layers."""
    from esm.convert import convert_esm_hf_to_te

    model_hf = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
    model_te = convert_esm_hf_to_te(model_hf)
    vanilla_layers_found = []
    for name, module in model_te.named_modules():
        if isinstance(module, nn.Linear):
            vanilla_layers_found.append(f"Linear layer found in {name}")
        if isinstance(module, nn.LayerNorm):
            vanilla_layers_found.append(f"LayerNorm layer found in {name}")
    if vanilla_layers_found:
        print("ERROR: Found vanilla PyTorch layers in converted TE model:")
        for error in vanilla_layers_found:
            print(f"WARNING: {error}")
    assert not vanilla_layers_found, f"Found {len(vanilla_layers_found)} vanilla layers in converted model"
```
this is a duplicate test, right? we have test_esm_model_has_all_te_layers in test_modeling_esm_te? do we need this in a test that exports from the TE model to HF?
Yes, this test is redundant, I removed it
```python
    # assert not hasattr(model_hf_converted.config, 'qkv_weight_interleaved')
    # assert not hasattr(model_hf_converted.config, 'encoder_activation')
    # assert not hasattr(model_hf_converted.config, 'attn_input_format')
    # assert not hasattr(model_hf_converted.config, 'fuse_qkv_params')
    # assert not hasattr(model_hf_converted.config, 'micro_batch_size')
    # assert not hasattr(model_hf_converted.config, 'max_seq_length')
```
any reason these are commented out?
Thanks, forgot to uncomment that. I rewrote this section
Signed-off-by: Ohad Mosafi <[email protected]>
Force-pushed 22c7739 to 27365b0 (compare)
Signed-off-by: Ohad Mosafi <[email protected]>
Force-pushed 27365b0 to 3af0c7d (compare)
Actionable comments posted: 0
♻️ Duplicate comments (1)
models/esm2/tests/test_convert_reverse.py (1)
16-23: Skip gracefully when offline or TE is unavailable.

These tests download HF models and rely on Transformer Engine; add guards to avoid CI flakes.

```diff
 import tempfile
 from pathlib import Path
 import pytest
 import torch
 from transformers import AutoModelForMaskedLM
+import os
+from importlib.util import find_spec
+
+# Skip if offline (HF env flags)
+OFFLINE = os.environ.get("TRANSFORMERS_OFFLINE") or os.environ.get("HF_HUB_OFFLINE")
+requires_internet = pytest.mark.skipif(OFFLINE, reason="Requires internet to download HF models")
+
+# Skip if Transformer Engine is missing
+TE_MISSING = find_spec("transformer_engine") is None
+requires_te = pytest.mark.skipif(TE_MISSING, reason="Requires transformer_engine")
```
🧹 Nitpick comments (7)
models/esm2/export_te_checkpoint_to_hf.py (3)
26-27: Validate the checkpoint path is a directory (not just exists).

Avoid confusing errors when a file path is passed.

```diff
-    if not Path(te_checkpoint_dir).exists():
-        raise FileNotFoundError(f"TE checkpoint {te_checkpoint_dir} not found")
+    te_path = Path(te_checkpoint_dir)
+    if not te_path.exists():
+        raise FileNotFoundError(f"TE checkpoint {te_checkpoint_dir} not found")
+    if not te_path.is_dir():
+        raise NotADirectoryError(f"TE checkpoint path must be a directory: {te_checkpoint_dir}")
```
33-35: Preserve original traceback when re-raising.

Use a bare raise to keep the original stack.

```diff
-    except Exception as e:
-        print(f"Error converting {te_checkpoint_dir}: {e}")
-        raise e
+    except Exception as e:
+        print(f"Error converting {te_checkpoint_dir}: {e}")
+        raise
```
45-49: Clarify that --model expects a directory.

Small wording tweak improves UX.

```diff
-        help="Path to the TE checkpoint.",
+        help="Path to the TE checkpoint directory.",
```

models/esm2/tests/test_convert_reverse.py (4)
24-27: Mark test as requiring internet and TE.

```diff
-def test_convert_te_to_hf_roundtrip():
+@requires_internet
+@requires_te
+def test_convert_te_to_hf_roundtrip():
```

44-47: Mark test as requiring internet and TE.

```diff
-@pytest.mark.parametrize("model_name", ["esm2_t6_8M_UR50D"])
-def test_export_te_checkpoint_to_hf(model_name):
+@pytest.mark.parametrize("model_name", ["esm2_t6_8M_UR50D"])
+@requires_internet
+@requires_te
+def test_export_te_checkpoint_to_hf(model_name):
```

77-80: Mark test as requiring internet and TE.

```diff
-def test_qkv_unpacking():
+@requires_internet
+@requires_te
+def test_qkv_unpacking():
```

99-103: Mark test as requiring internet and TE.

```diff
-def test_config_conversion():
+@requires_internet
+@requires_te
+def test_config_conversion():
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- models/esm2/README.md (1 hunks)
- models/esm2/export.py (2 hunks)
- models/esm2/export_te_checkpoint_to_hf.py (1 hunks)
- models/esm2/src/esm/convert.py (5 hunks)
- models/esm2/src/esm/export.py (2 hunks)
- models/esm2/tests/test_convert_reverse.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (4)
- models/esm2/export.py
- models/esm2/src/esm/export.py
- models/esm2/README.md
- models/esm2/src/esm/convert.py
🧰 Additional context used
🧬 Code graph analysis (2)
models/esm2/tests/test_convert_reverse.py (2)
- models/esm2/src/esm/convert.py (2)
  - convert_esm_hf_to_te (51-76)
  - convert_esm_te_to_hf (79-140)
- models/esm2/src/esm/export.py (2)
  - export_hf_checkpoint (28-73)
  - export_te_checkpoint (76-118)

models/esm2/export_te_checkpoint_to_hf.py (1)
- models/esm2/src/esm/export.py (1)
  - export_te_checkpoint (76-118)
🔇 Additional comments (1)
models/esm2/export_te_checkpoint_to_hf.py (1)
16-19: Ensure 'esm' import works when running the script directly

File: models/esm2/export_te_checkpoint_to_hf.py — Repo search found no esm/export.py or convert.py and no PYTHONPATH/CI hints beyond pyproject addopts. Running `python models/esm2/export_te_checkpoint_to_hf.py` may raise ModuleNotFoundError: esm unless models/esm2/src (or the package root) is on PYTHONPATH or the package is installed. Options: export PYTHONPATH in CI/docs, make the script runnable via `python -m` or an installed entrypoint, or add a short sys.path fallback in the script.
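One possible form of the sys.path fallback mentioned above (a sketch; it assumes the sources live under a sibling `src/` directory, which may not match the repo layout):

```python
# Hypothetical sys.path fallback for running the script directly without
# installing the package; the src/ location is an assumption.
import sys
from pathlib import Path

try:
    from esm.export import export_te_checkpoint
except ModuleNotFoundError:
    sys.path.insert(0, str(Path(__file__).resolve().parent / "src"))
    from esm.export import export_te_checkpoint
```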
Actionable comments posted: 0
♻️ Duplicate comments (1)
models/esm2/src/esm/convert.py (1)
97-113: Map max_seq_length → max_position_embeddings and drop export-only dtype; safer EsmConfig.

Right now `max_seq_length` is popped but not mapped; if TE configs don't already carry `max_position_embeddings`, HF will fall back to defaults. Also strip `torch_dtype` to avoid passing runtime dtype hints into config construction.

Apply:

```diff
-    # Remove TE-specific config options
-    te_specific_keys = [
+    # Map TE naming to HF before cleanup
+    if "max_position_embeddings" not in hf_config_dict and "max_seq_length" in hf_config_dict:
+        hf_config_dict["max_position_embeddings"] = hf_config_dict["max_seq_length"]
+
+    # Remove TE-specific and export-only config options
+    te_specific_keys = [
         "qkv_weight_interleaved",
         "encoder_activation",
         "attn_input_format",
         "fuse_qkv_params",
         "micro_batch_size",
         "max_seq_length",
         "model_type",
         "auto_map",
+        "torch_dtype",
     ]
     for key in te_specific_keys:
         hf_config_dict.pop(key, None)
     hf_config_dict["model_type"] = "esm"
```
🧹 Nitpick comments (6)
models/esm2/src/esm/convert.py (4)
151-163: Make QKV weight packing shape-safe and fix the docstring.

The current code uses `-1` and a "Pad …" docstring, which is misleading for a pack operation. Add explicit shape checks to mirror the safety of the unpack path.

Apply:

```diff
 @io.state_transform(
 @@
 def _pack_qkv_weight(ctx: io.TransformCTX, query, key, value):
-    """Pad the embedding layer to the new input dimension."""
-    concat_weights = torch.cat((query, key, value), dim=0)
-    input_shape = concat_weights.size()
-    num_heads = ctx.target.config.num_attention_heads
-    # transpose weights
-    # [sequence length, batch size, num_splits_model_parallel * attention head size * #attention heads]
-    # --> [sequence length, batch size, attention head size * num_splits_model_parallel * #attention heads]
-    concat_weights = concat_weights.view(3, num_heads, -1, query.size()[-1])
-    concat_weights = concat_weights.transpose(0, 1).contiguous()
-    concat_weights = concat_weights.view(*input_shape)
-    return concat_weights
+    """Pack separate Q,K,V weights into TE's fused [hidden_size*3, input_dim] layout."""
+    concat = torch.cat((query, key, value), dim=0)  # [3*hidden_size, input_dim]
+    num_heads = ctx.target.config.num_attention_heads
+    total_rows, input_dim = concat.size()
+    if total_rows % (3 * num_heads) != 0:
+        raise ValueError(f"QKV weight rows {total_rows} not divisible by 3*num_heads {3*num_heads}")
+    head_dim = total_rows // (3 * num_heads)
+    fused = concat.view(3, num_heads, head_dim, input_dim).transpose(0, 1).contiguous()
+    return fused.reshape(total_rows, input_dim)
```
173-185: Mirror the packing safety for biases and correct the docstring.

Same rationale as weights.

Apply:

```diff
 def _pack_qkv_bias(ctx: io.TransformCTX, query, key, value):
-    """Pad the embedding layer to the new input dimension."""
-    concat_biases = torch.cat((query, key, value), dim=0)
-    input_shape = concat_biases.size()
-    num_heads = ctx.target.config.num_attention_heads
-    # transpose biases
-    # [num_splits_model_parallel * attention head size * #attention heads]
-    # --> [attention head size * num_splits_model_parallel * #attention heads]
-    concat_biases = concat_biases.view(3, num_heads, -1)
-    concat_biases = concat_biases.transpose(0, 1).contiguous()
-    concat_biases = concat_biases.view(*input_shape)
-    return concat_biases
+    """Pack separate Q,K,V biases into TE's fused [hidden_size*3] layout."""
+    concat = torch.cat((query, key, value), dim=0)  # [3*hidden_size]
+    num_heads = ctx.target.config.num_attention_heads
+    total = concat.size(0)
+    if total % (3 * num_heads) != 0:
+        raise ValueError(f"QKV bias size {total} not divisible by 3*num_heads {3*num_heads}")
+    head_dim = total // (3 * num_heads)
+    fused = concat.view(3, num_heads, head_dim).transpose(0, 1).contiguous()
+    return fused.reshape(total)
```
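To see the shape bookkeeping behind these pack/unpack transforms, here is a self-contained sketch on dummy tensors (no TE or ESM code required). It mirrors the head-major interleave described in the diffs above and checks that unpacking inverts packing; it is an illustration, not the PR's actual transform functions.

```python
# Standalone illustration of the QKV pack/unpack shape math on random tensors.
import torch

num_heads, head_dim, input_dim = 4, 8, 32
hidden = num_heads * head_dim

q = torch.randn(hidden, input_dim)
k = torch.randn(hidden, input_dim)
v = torch.randn(hidden, input_dim)

# Pack: concatenate [q; k; v] and interleave so each head's q, k, v rows are adjacent.
concat = torch.cat((q, k, v), dim=0)  # [3*hidden, input_dim]
fused = (
    concat.view(3, num_heads, head_dim, input_dim)
    .transpose(0, 1)
    .contiguous()
    .reshape(3 * hidden, input_dim)
)

# Unpack: invert the interleave and split back into q, k, v.
restored = (
    fused.view(num_heads, 3, head_dim, input_dim)
    .transpose(0, 1)
    .contiguous()
    .reshape(3 * hidden, input_dim)
)
q2, k2, v2 = torch.split(restored, hidden, dim=0)

assert torch.equal(q, q2) and torch.equal(k, k2) and torch.equal(v, v2)
print("pack/unpack roundtrip preserves Q, K, V")
```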
195-211: Prefer explicit ValueError over assert for shape validation.

Asserts can be stripped with Python -O. Use explicit checks for reliability.

Apply:

```diff
-    assert total_rows % (3 * num_heads) == 0, (
-        f"QKV weight rows {total_rows} not divisible by 3*num_heads {3*num_heads}"
-    )
+    if total_rows % (3 * num_heads) != 0:
+        raise ValueError(
+            f"QKV weight rows {total_rows} not divisible by 3*num_heads {3*num_heads}"
+        )
```
222-238: Same: raise ValueError instead of assert for bias shape checks.

Apply:

```diff
-    assert total_size % (3 * num_heads) == 0, (
-        f"QKV bias size {total_size} not divisible by 3*num_heads {3*num_heads}"
-    )
+    if total_size % (3 * num_heads) != 0:
+        raise ValueError(
+            f"QKV bias size {total_size} not divisible by 3*num_heads {3*num_heads}"
+        )
```

models/esm2/export_te_checkpoint_to_hf.py (1)
22-38: Avoid noisy try/except or re-raise plainly; optionally guard against same input/output path.

- Printing and re-raising adds noise without better diagnostics. Either let exceptions bubble or `raise` without binding.
- Prevent accidental mixing of TE and HF artifacts by forbidding identical input/output directories.

Apply:

```diff
 def main(te_checkpoint_dir: str, output_dir: str):
 @@
-    try:
-        export_te_checkpoint(te_checkpoint_dir, output_dir)
-        print(f"Successfully exported {te_checkpoint_dir} to HuggingFace Facebook ESM-2 format at {output_dir}")
-    except Exception as e:
-        print(f"Error converting {te_checkpoint_dir}: {e}")
-        raise e
+    if Path(te_checkpoint_dir).resolve() == Path(output_dir).resolve():
+        raise ValueError("output_dir must differ from te_checkpoint_dir to avoid mixing artifacts.")
+
+    export_te_checkpoint(te_checkpoint_dir, output_dir)
+    print(f"Successfully exported {te_checkpoint_dir} to HuggingFace Facebook ESM-2 format at {output_dir}")
```

models/esm2/src/esm/export.py (1)
76-118: Reduce memory during smoke test load.

Loading the converted HF model just to validate can be made lighter to avoid transient OOMs in CI.

Apply:

```diff
-    model_hf = AutoModelForMaskedLM.from_pretrained(
-        output_path,
-        torch_dtype=torch.bfloat16,
-        trust_remote_code=False,
-    )
+    model_hf = AutoModelForMaskedLM.from_pretrained(
+        output_path,
+        torch_dtype=torch.bfloat16,
+        trust_remote_code=False,
+        low_cpu_mem_usage=True,
+        device_map="cpu",
+    )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- models/esm2/README.md (1 hunks)
- models/esm2/export_te_checkpoint_to_hf.py (1 hunks)
- models/esm2/src/esm/convert.py (5 hunks)
- models/esm2/src/esm/export.py (2 hunks)
- models/esm2/tests/test_convert_reverse.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- models/esm2/tests/test_convert_reverse.py
- models/esm2/README.md
🧰 Additional context used
🧬 Code graph analysis (2)
models/esm2/export_te_checkpoint_to_hf.py (2)
- models/esm2/src/esm/export.py (1)
  - export_te_checkpoint (76-117)
- models/esm2/export.py (1)
  - main (32-59)

models/esm2/src/esm/export.py (1)
- models/esm2/src/esm/convert.py (2)
  - convert_esm_hf_to_te (51-76)
  - convert_esm_te_to_hf (79-140)
🔇 Additional comments (2)
models/esm2/src/esm/convert.py (1)
47-49: Reverse mapping looks correct and scoped well.

Simple and reliable way to derive TE→HF key mapping; avoids drift.

models/esm2/export_te_checkpoint_to_hf.py (1)

45-55: CLI help is clear; required model arg prevents Path(None) crash.

No further changes needed.
Force-pushed 2073094 to c59c7ba (compare)
Actionable comments posted: 1
🧹 Nitpick comments (4)
bionemo-recipes/models/esm2/src/esm/export.py (1)
140-140: Consider explicit dtype for consistency.

The TE model is loaded without specifying a dtype, while the smoke test at line 158 uses `torch.bfloat16`. For consistency with `export_hf_checkpoint` (line 114) and to avoid potential dtype mismatches, consider specifying the dtype explicitly.

Apply this diff:

```diff
-    model_te = NVEsmForMaskedLM.from_pretrained(te_checkpoint_path)
+    model_te = NVEsmForMaskedLM.from_pretrained(te_checkpoint_path, torch_dtype=torch.bfloat16)
```

bionemo-recipes/models/esm2/export.py (2)
67-69: Remove redundant validation.

This validation is unnecessary because `argparse` with `choices=ESM_TAGS` already ensures that `args.model` is valid. The code at line 66 only executes when `args.model` is not None and has passed the choices validation.

Apply this diff:

```diff
     if args.model:
-        if args.model not in ESM_TAGS:
-            print(f"Error: '{args.model}' is not a valid model tag.\nAvailable models: {', '.join(ESM_TAGS)}")
-            return
         print(f"Converting {args.model} from HuggingFace to Transformer Engine format")
         export_hf_checkpoint(args.model, Path(args.output_path))
```
78-79: Remove duplicate path validation.

This path existence check duplicates the validation already performed in `export_te_checkpoint` at lines 134-135 of bionemo-recipes/models/esm2/src/esm/export.py. The function will raise a clear FileNotFoundError if the path doesn't exist.

Apply this diff:

```diff
     print(f"Converting {args.checkpoint_path} from Transformer Engine to HuggingFace format")
-    if not Path(args.checkpoint_path).exists():
-        raise FileNotFoundError(f"TE checkpoint {args.checkpoint_path} not found")
     export_te_checkpoint(args.checkpoint_path, Path(args.output_path))
```

bionemo-recipes/models/esm2/tests/test_convert_reverse.py (1)
67-67: Remove commented code.

The commented assertion at line 67 is replaced by the logic at lines 68-70. Remove the dead code.

Apply this diff:

```diff
-    # assert original_state_dict.keys() == exported_state_dict.keys()
     original_keys = {k for k in original_state_dict.keys() if "contact_head" not in k}
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- bionemo-recipes/models/esm2/export.py (2 hunks)
- bionemo-recipes/models/esm2/src/esm/convert.py (7 hunks)
- bionemo-recipes/models/esm2/src/esm/export.py (2 hunks)
- bionemo-recipes/models/esm2/tests/test_convert_reverse.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
bionemo-recipes/models/esm2/export.py (1)
- bionemo-recipes/models/esm2/src/esm/export.py (1)
  - export_te_checkpoint (122-163)

bionemo-recipes/models/esm2/src/esm/export.py (2)
- bionemo-recipes/models/esm2/src/esm/convert.py (1)
  - convert_esm_te_to_hf (76-137)
- bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py (1)
  - NVEsmForMaskedLM (415-508)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Analyze (rust)
🔇 Additional comments (22)
bionemo-recipes/models/esm2/src/esm/export.py (5)
25-26: LGTM! The new imports support the TE→HF conversion functionality.

134-135: LGTM! Appropriate validation of input path existence.

144-154: LGTM! Optional tokenizer file copying is appropriate and handles missing files gracefully.

156-163: Smoke test validates model loading. The smoke test successfully verifies that the exported HF model can be loaded with the standard Transformers API. Memory cleanup is properly performed.

142-142: Inspect `save_pretrained` via Python. Run the following to dump its source and verify directory-creation behavior:

```bash
#!/bin/bash
python - << 'EOF'
import inspect
from transformers import PreTrainedModel
print(inspect.getsource(PreTrainedModel.save_pretrained))
EOF
```

This will reveal whether missing output directories are auto-created; if not, add an explicit `os.makedirs(output_path, exist_ok=True)` before `save_pretrained`.

bionemo-recipes/models/esm2/export.py (2)

16-19: LGTM! Imports support the new bidirectional CLI functionality.

65-80: Verify argument name consistency after fix. If the argument naming is standardized to use hyphens (per previous comment), ensure attribute access uses underscores: `args.output_path` becomes `args.output_path` (argparse converts hyphens to underscores in attribute names).

bionemo-recipes/models/esm2/tests/test_convert_reverse.py (6)

24-42: LGTM! Comprehensive roundtrip test with appropriate exclusions for contact_head (not preserved in TE), _extra_state, and inv_freq buffers. Tolerance settings are reasonable for float precision.

44-75: Comprehensive export roundtrip validation. The test properly validates the full export workflow including checkpoint persistence and reloading. Temp directory usage ensures test isolation.

77-97: LGTM! Thorough validation of QKV unpacking across all transformer layers with appropriate numerical tolerance.

99-126: LGTM! Comprehensive config conversion validation ensuring field preservation and proper cleanup of TE-specific fields.

160-163: Padding validation is thorough. The test correctly validates that padded embedding rows are zero-initialized. Note that the comment at line 160 mentions "min values (for bias)" but the test only validates embedding padding; bias padding validation could be added if needed.

128-164: LGTM! Comprehensive validation of padding/unpadding operations for embeddings, decoder weights, and biases with proper shape and value checks.

bionemo-recipes/models/esm2/src/esm/convert.py (9)

20-20: LGTM! Required imports for HF model reconstruction in the TE→HF conversion path.

44-45: LGTM! Clean reversal of the mapping dictionary to support TE→HF parameter name translation.

76-137: Well-structured TE→HF conversion. The function properly handles config conversion, model initialization, and parameter transformation. The removal of TE-specific config keys and contact_head is appropriate since contact_head is not preserved in TE models.

148-159: LGTM! Properly retrieves `num_heads` from the target config for QKV weight packing.

170-181: LGTM! Consistent with QKV weight packing, properly uses config for `num_heads`.

184-208: LGTM! Correct unpacking logic that inverts the QKV fusion performed in `_pack_qkv_weight`. The shape assertion provides good validation.

211-235: LGTM! Correct bias unpacking that mirrors the weight unpacking logic and inverts `_pack_qkv_bias`.

238-273: LGTM! Clean implementation of unpadding logic that correctly inverts the padding operations. Appropriate application to both embeddings and decoder weights.

291-298: LGTM! Correct unpadding of bias tensor, inverting the padding operation from `_pad_bias`.
Signed-off-by: Ohad Mosafi <[email protected]>
Force-pushed 1ae584e to b6af4e8 (compare)
Signed-off-by: Ohad Mosafi <[email protected]>
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- bionemo-recipes/models/esm2/.dockerignore (1 hunks)
- bionemo-recipes/models/esm2/README.md (2 hunks)
- bionemo-recipes/models/esm2/src/esm/export.py (2 hunks)
- bionemo-recipes/models/esm2/tests/test_convert.py (1 hunks)
- bionemo-recipes/models/esm2/tests/test_export.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
bionemo-recipes/models/esm2/tests/test_convert.py (1)
- bionemo-recipes/models/esm2/src/esm/convert.py (2)
  - convert_esm_hf_to_te (48-73)
  - convert_esm_te_to_hf (76-137)

bionemo-recipes/models/esm2/src/esm/export.py (2)
- bionemo-recipes/models/esm2/src/esm/convert.py (1)
  - convert_esm_te_to_hf (76-137)
- bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py (1)
  - NVEsmForMaskedLM (415-508)

bionemo-recipes/models/esm2/tests/test_export.py (1)
- bionemo-recipes/models/esm2/src/esm/export.py (2)
  - export_hf_checkpoint (54-119)
  - export_te_checkpoint (122-163)
🔇 Additional comments (1)
bionemo-recipes/models/esm2/.dockerignore (1)
3-4: Ignore rules look good. Keeping the TE↔HF export outputs out of build context avoids Docker bloat.
```python
    tokenizer_config_path = Path(te_checkpoint_path) / "tokenizer_config.json"
    if tokenizer_config_path.exists():
        shutil.copy(tokenizer_config_path, Path(output_path) / "tokenizer_config.json")

    vocab_path = Path(te_checkpoint_path) / "vocab.txt"
    if vocab_path.exists():
        shutil.copy(vocab_path, Path(output_path) / "vocab.txt")

    special_tokens_path = Path(te_checkpoint_path) / "special_tokens_map.json"
    if special_tokens_path.exists():
        shutil.copy(special_tokens_path, Path(output_path) / "special_tokens_map.json")
```
Copy all tokenizer artifacts during TE→HF export
Only copying tokenizer_config.json, vocab.txt, and special_tokens_map.json drops other tokenizer assets (tokenizer.json, added_tokens.json, merges.txt, etc.). If the TE checkpoint was trained with added tokens or relies on the fast tokenizer, the exported HF checkpoint will be broken or silently lose vocabulary entries. Please copy every tokenizer file alongside the model weights.
```diff
-    tokenizer_config_path = Path(te_checkpoint_path) / "tokenizer_config.json"
-    if tokenizer_config_path.exists():
-        shutil.copy(tokenizer_config_path, Path(output_path) / "tokenizer_config.json")
-
-    vocab_path = Path(te_checkpoint_path) / "vocab.txt"
-    if vocab_path.exists():
-        shutil.copy(vocab_path, Path(output_path) / "vocab.txt")
-
-    special_tokens_path = Path(te_checkpoint_path) / "special_tokens_map.json"
-    if special_tokens_path.exists():
-        shutil.copy(special_tokens_path, Path(output_path) / "special_tokens_map.json")
+    tokenizer_files = [
+        "tokenizer.json",
+        "tokenizer_config.json",
+        "special_tokens_map.json",
+        "added_tokens.json",
+        "vocab.txt",
+        "merges.txt",
+    ]
+    for filename in tokenizer_files:
+        src = Path(te_checkpoint_path) / filename
+        if src.exists():
+            shutil.copy(src, Path(output_path) / filename)
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
    tokenizer_files = [
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "added_tokens.json",
        "vocab.txt",
        "merges.txt",
    ]
    for filename in tokenizer_files:
        src = Path(te_checkpoint_path) / filename
        if src.exists():
            shutil.copy(src, Path(output_path) / filename)
```
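An alternative worth considering (a sketch, not part of the PR, and it assumes the TE checkpoint directory contains a tokenizer that `AutoTokenizer` can load): re-save the tokenizer instead of copying files by name, so newly introduced tokenizer artifacts are never missed.

```python
# Sketch: let the tokenizer write its own file inventory instead of
# copying individual files; paths are illustrative.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/te_checkpoint")
tokenizer.save_pretrained("path/to/hf_export")  # writes config, vocab, special tokens, etc.
```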
Signed-off-by: Ohad Mosafi <[email protected]>
Actionable comments posted: 1
🧹 Nitpick comments (1)
bionemo-recipes/models/esm2/export.py (1)
62-62: Remove trailing whitespace.

Line 62 has trailing whitespace after the closing parenthesis.

Apply this diff:

```diff
-    ) 
+    )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- bionemo-recipes/models/esm2/export.py (2 hunks)
- bionemo-recipes/models/esm2/tests/test_export.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- bionemo-recipes/models/esm2/tests/test_export.py
🧰 Additional context used
🧬 Code graph analysis (1)
bionemo-recipes/models/esm2/export.py (1)
- bionemo-recipes/models/esm2/src/esm/export.py (1)
  - export_te_checkpoint (122-163)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Analyze (rust)
🔇 Additional comments (2)
bionemo-recipes/models/esm2/export.py (2)
16-19: LGTM! Imports are correct. The imports are properly organized and align with the module structure. Both `export_hf_checkpoint` and `export_te_checkpoint` are correctly imported from `esm.export`.

38-61: LGTM! Previous issues resolved. The argument naming inconsistency and incorrect help text from the previous review have been properly addressed:

- Both subparsers now consistently use `--output-path` (hyphen style)
- The te-to-hf help text correctly identifies the checkpoint as "Transformer Engine"
Force-pushed 7fb6169 to f4abfba (compare)
Actionable comments posted: 1
♻️ Duplicate comments (1)
bionemo-recipes/models/esm2/src/esm/export.py (1)
144-154: Copy all tokenizer artifacts during TE→HF export.

Only copying `tokenizer_config.json`, `vocab.txt`, and `special_tokens_map.json` drops other tokenizer assets (`tokenizer.json`, `added_tokens.json`, `merges.txt`, etc.). If the TE checkpoint includes added tokens or relies on the fast tokenizer, the exported HF checkpoint will be broken or silently lose vocabulary entries.

Apply this diff to copy all tokenizer files:

```diff
-    tokenizer_config_path = Path(te_checkpoint_path) / "tokenizer_config.json"
-    if tokenizer_config_path.exists():
-        shutil.copy(tokenizer_config_path, output_path / "tokenizer_config.json")
-
-    vocab_path = Path(te_checkpoint_path) / "vocab.txt"
-    if vocab_path.exists():
-        shutil.copy(vocab_path, output_path / "vocab.txt")
-
-    special_tokens_path = Path(te_checkpoint_path) / "special_tokens_map.json"
-    if special_tokens_path.exists():
-        shutil.copy(special_tokens_path, output_path / "special_tokens_map.json")
+    tokenizer_files = [
+        "tokenizer.json",
+        "tokenizer_config.json",
+        "special_tokens_map.json",
+        "added_tokens.json",
+        "vocab.txt",
+        "merges.txt",
+    ]
+    for filename in tokenizer_files:
+        src = Path(te_checkpoint_path) / filename
+        if src.exists():
+            shutil.copy(src, output_path / filename)
```
🧹 Nitpick comments (1)
bionemo-recipes/models/esm2/tests/test_export.py (1)
74-103: Expand test coverage for TE-to-HF export.

The test validates basic roundtrip conversion but lacks coverage for:
- Tokenizer file preservation (the past review comment flagged missing tokenizer files)
- Config attribute validation beyond state_dict comparison
- Error handling when TE checkpoint is missing or corrupted
- Multiple model sizes (currently only tests esm2_t6_8M_UR50D)
Consider adding these test cases:
```python
@pytest.mark.parametrize("model_name", [
    "esm2_t6_8M_UR50D",
    "esm2_t12_35M_UR50D",  # Add more model sizes
])
def test_export_te_checkpoint_to_hf(model_name):
    # ... existing test code ...

    # Add tokenizer validation
    original_tokenizer = AutoTokenizer.from_pretrained(f"facebook/{model_name}")
    exported_tokenizer = AutoTokenizer.from_pretrained(str(hf_export_path))
    assert original_tokenizer.vocab == exported_tokenizer.vocab


def test_export_te_checkpoint_missing_path():
    """Test error handling for missing TE checkpoint."""
    from esm.export import export_te_checkpoint

    with pytest.raises(FileNotFoundError):
        export_te_checkpoint("/nonexistent/path", Path("/tmp/output"))
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- bionemo-recipes/models/esm2/README.md (2 hunks)
- bionemo-recipes/models/esm2/src/esm/export.py (2 hunks)
- bionemo-recipes/models/esm2/tests/test_export.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
bionemo-recipes/models/esm2/src/esm/export.py (2)
- bionemo-recipes/models/esm2/src/esm/convert.py (2)
  - convert_esm_hf_to_te (48-73)
  - convert_esm_te_to_hf (76-137)
- bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py (1)
  - NVEsmForMaskedLM (415-508)

bionemo-recipes/models/esm2/tests/test_export.py (1)
- bionemo-recipes/models/esm2/src/esm/export.py (2)
  - export_hf_checkpoint (54-119)
  - export_te_checkpoint (122-163)
🔇 Additional comments (6)
bionemo-recipes/models/esm2/README.md (2)
80-99: LGTM! Clear conversion commands. The updated conversion commands correctly reflect the new CLI structure with `hf-to-te` and `te-to-hf` subcommands, and the volume mount paths are properly separated for each conversion direction.

101-181: LGTM! Comprehensive developer workflow. The developer workflow section provides a complete end-to-end example with proper validation steps. The code examples correctly demonstrate:
- Loading original HF models
- Exporting to TE format
- Converting back to HF format
- Validating equivalence with numerical comparison
bionemo-recipes/models/esm2/src/esm/export.py (4)
25-26: LGTM! Correct imports for TE-to-HF conversion. The imports correctly bring in the conversion function and TE model class needed for the export functionality.

122-133: LGTM! Clear function signature and documentation. The docstring clearly explains the function's purpose and the distinction between TE format (also HF-compatible) and the original Facebook ESM-2 format.

134-142: LGTM! Proper error handling and conversion flow. The function correctly checks for checkpoint existence, loads the TE model, converts it, and saves the result.

156-163: LGTM! Proper smoke test and cleanup. The function correctly validates that the exported model can be loaded and properly cleans up resources with garbage collection and CUDA cache clearing.
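For reference, a minimal sketch of the kind of numerical equivalence check the README's validation section describes (the exported path, protein sequence, and tolerances here are illustrative, not taken from the PR):

```python
# Hedged sketch: compare masked-LM logits of the original HF model and an
# exported checkpoint; requires internet access and enough memory for both.
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

original = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D").eval()
exported = AutoModelForMaskedLM.from_pretrained("te_to_hf_checkpoint_export").eval()  # illustrative path
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

inputs = tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt")
with torch.no_grad():
    logits_a = original(**inputs).logits
    logits_b = exported(**inputs).logits

torch.testing.assert_close(logits_a, logits_b, rtol=1e-4, atol=1e-4)
print("Exported checkpoint matches the original HF model")
```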
Signed-off-by: Ohad Mosafi <[email protected]>
Force-pushed f4abfba to 2cf12c6 (compare)
```bash
mkdir -p hf_to_te_checkpoint_export
docker build -t esm2 .
docker run --rm -it --gpus all \
  -v $PWD/checkpoint_export/:/workspace/bionemo/checkpoint_export \
  -v $PWD/hf_to_te_checkpoint_export/:/workspace/bionemo/hf_to_te_checkpoint_export \
  -v $HOME/.cache/huggingface/:/root/.cache/huggingface \
  esm2 python export.py
  esm2 python export.py hf-to-te
```
i don't think we need to make folders for the te-to-hf path, we just need to show how to take an existing TE model and convert it to HF (e.g., through a python code demo). Folks likely won't convert our TE models back to HF
```bash
MODEL_TAG=esm2_t6_8M_UR50D # specify which model to convert
mkdir -p te_to_hf_checkpoint_export
docker build -t esm2 .
docker run --rm -it --gpus all \
  -v $PWD/te_to_hf_checkpoint_export/:/workspace/bionemo/te_to_hf_checkpoint_export \
  -v $PWD/hf_to_te_checkpoint_export/$MODEL_TAG:/workspace/bionemo/hf_to_te_checkpoint_export/$MODEL_TAG \
  -v $HOME/.cache/huggingface/:/root/.cache/huggingface \
  esm2 python export.py te-to-hf --checkpoint-path /workspace/bionemo/hf_to_te_checkpoint_export/$MODEL_TAG
```
i think this could be just something like
```python
from esm.convert import convert_esm_te_to_hf

te_model = AutoModel.from_pretrained("/path/to/te/checkpoint", trust_remote_code=True)
hf_model = convert_esm_te_to_hf(te_model)
hf_model.save_pretrained("/path/to/exported/model")
```

```python
from esm.export import export_hf_checkpoint

te_checkpoint_path = Path("te_checkpoint")
export_hf_checkpoint("esm2_t6_8M_UR50D", te_checkpoint_path)
```
this isn't as obvious a function name as convert_esm_te_to_hf, let's just keep the existing API and add the reverse path?
…version (#1218)

Adds convert.py and test_convert.py from @ohadmo's PR and makes some readme updates. Takes many changes from #1121.

## Summary by CodeRabbit

* **New Features**
  * Added bidirectional conversion between Transformer Engine and HuggingFace ESM-2 formats, including TE→HF roundtrip support.
* **Documentation**
  * Expanded ESM-2 guide with Python-centric conversion examples, "Load and Test" and "Validating Converted Models" sections, and step-by-step deployment/upload workflows.
* **Tests**
  * Added end-to-end and unit tests validating round-trip conversions, parameter parity, config alignment, and padding/unpadding behavior.

Signed-off-by: Peter St. John <[email protected]>
Adding TransformerEngine to HF checkpoint conversion for ESM2
Summary by CodeRabbit
New Features
Documentation
Tests
Chores