Conversation

@ohadmo ohadmo commented Sep 5, 2025

Adding TransformerEngine to HF checkpoint conversion for ESM2

Summary by CodeRabbit

  • New Features

    • Two-way ESM-2 checkpoint conversion via a CLI (hf-to-te, te-to-hf) with selectable models/checkpoints and output paths; adds TE→HF round‑trip export support.
  • Documentation

    • README updated to mark HF export as supported and include a full end-to-end developer conversion and validation workflow with examples.
  • Tests

    • New tests covering HF↔TE roundtrips, weight parity, QKV unpacking, padding/unpadding, and config consistency.
  • Chores

    • Docker ignores split into separate HF→TE and TE→HF export directories.

copy-pr-bot bot commented Sep 5, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

coderabbitai bot commented Sep 5, 2025

Walkthrough

Adds a CLI with hf-to-te and te-to-hf subcommands, implements TE→HF conversion (reverse mapping and tensor transforms), adds export_te_checkpoint, updates README and .dockerignore, and adds tests validating HF↔TE roundtrips, QKV handling, padding, and config fidelity.
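
For orientation, a minimal sketch of the two directions described above, using the entry points this PR adds (paths below are placeholders, not paths from the PR):

```python
# Hedged sketch: assumes the esm package from this PR is importable.
from pathlib import Path

from esm.export import export_hf_checkpoint, export_te_checkpoint

# HF -> TE: fetches the facebook/<tag> weights and writes a TE checkpoint
# under the output path (per-tag subdirectory, as the README examples assume).
export_hf_checkpoint("esm2_t6_8M_UR50D", Path("hf_to_te_checkpoint_export"))

# TE -> HF: loads the TE checkpoint and writes a checkpoint in the original
# Facebook ESM-2 format, copying tokenizer assets alongside it.
export_te_checkpoint(
    "hf_to_te_checkpoint_export/esm2_t6_8M_UR50D",
    Path("te_to_hf_checkpoint_export/esm2_t6_8M_UR50D"),
)
```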

Changes

Cohort / File(s) Summary
CLI entrypoint
bionemo-recipes/models/esm2/export.py
Replaces prior iterative export with an argparse CLI exposing hf-to-te and te-to-hf subcommands; validates args and dispatches to export_hf_checkpoint or export_te_checkpoint, printing progress.
TE↔HF conversion core
bionemo-recipes/models/esm2/src/esm/convert.py
Adds convert_esm_te_to_hf and reverse_mapping; constructs HF config from TE config (removing TE-only keys), reverses parameter mapping, and applies transforms to unpack QKV, unpad embeddings/decoder/bias, tie weights, and return an HF model.
Export utilities
bionemo-recipes/models/esm2/src/esm/export.py
Adds export_te_checkpoint(te_checkpoint_path: str, output_path: Path) to load an NVEsm TE checkpoint, convert to HF via convert_esm_te_to_hf, save HF model and tokenizer/vocab assets, and perform a load/teardown smoke check.
Tests
bionemo-recipes/models/esm2/tests/test_convert.py, bionemo-recipes/models/esm2/tests/test_export.py
Adds tests for HF→TE→HF roundtrip, QKV unpacking, config preservation/removal, padding/unpadding consistency, and TE→HF export parity versus original HF weights.
Docs
bionemo-recipes/models/esm2/README.md
Updates export status, renames export paths, documents CLI usage, adds a Developer Conversion Workflow with HF↔TE roundtrip examples and validation, and adjusts Docker volume examples.
Build ignores
bionemo-recipes/models/esm2/.dockerignore
Replaces checkpoint_export/ with hf_to_te_checkpoint_export/ and te_to_hf_checkpoint_export/ ignore entries.

Sequence Diagram(s)

sequenceDiagram
  actor User
  participant CLI as export.py (CLI)
  participant HF as HuggingFace
  participant TE as TransformerEngine
  participant FS as Filesystem

  rect #F0F8FF
    note over CLI,HF: hf-to-te flow
    User->>CLI: hf-to-te [--model|-m]* --output-path
    CLI->>HF: load HF model(s)
    CLI->>TE: export_hf_checkpoint (HF → TE)
    TE->>FS: save TE checkpoint(s)
    CLI-->>User: status messages
  end

  rect #F7FFF0
    note over CLI,TE: te-to-hf flow
    User->>CLI: te-to-hf --checkpoint-path --output-path
    CLI->>FS: validate checkpoint path
    CLI->>TE: load TE checkpoint (NVEsmForMaskedLM)
    TE->>CLI: provide TE state & config
    CLI->>HF: convert_esm_te_to_hf (TE → HF)
    HF->>FS: save HF model + tokenizer/vocab
    CLI-->>User: status messages
  end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

In my burrow I hop through code and log,
I stitch QKV, unpad rows from fog.
CLI opens both paths day and night,
Roundtrips checked, weights snug and tight.
Hopping happy — export done right. 🐇✨

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
Check name | Status | Explanation | Resolution
Description Check | ⚠️ Warning | The pull request description is a single brief sentence and does not follow the repository’s prescribed template; it is missing required sections, including a detailed description, usage examples, type of changes, CI pipeline configuration labels, and the pre-submit checklist. | Complete the repository’s pull request description template: add a comprehensive description of the changes and usage code snippets, mark the type of change, specify CI labels, and fill out the pre-submit checklist to meet the contribution guidelines.
Docstring Coverage | ⚠️ Warning | Docstring coverage is 75.00%, which is below the required threshold of 80.00%. | Run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (1 passed)
Check name | Status | Explanation
Title Check | ✅ Passed | The title succinctly indicates the main addition of TransformerEngine-to-HuggingFace checkpoint conversion support for the ESM2 model, which is the primary feature introduced in this pull request. It is concise, specific, and fully aligned with the changeset’s core objective.
✨ Finishing touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch omosafi/add-esm2-TE-to-HF

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f4abfba and 2cf12c6.

📒 Files selected for processing (3)
  • bionemo-recipes/models/esm2/README.md (2 hunks)
  • bionemo-recipes/models/esm2/src/esm/export.py (2 hunks)
  • bionemo-recipes/models/esm2/tests/test_export.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
bionemo-recipes/models/esm2/tests/test_export.py (1)
bionemo-recipes/models/esm2/src/esm/export.py (2)
  • export_hf_checkpoint (54-119)
  • export_te_checkpoint (122-163)
bionemo-recipes/models/esm2/src/esm/export.py (2)
bionemo-recipes/models/esm2/src/esm/convert.py (2)
  • convert_esm_hf_to_te (48-73)
  • convert_esm_te_to_hf (76-137)
bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py (1)
  • NVEsmForMaskedLM (415-508)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Analyze (rust)


Comment @coderabbitai help to get the list of available commands and usage tips.

@ohadmo ohadmo force-pushed the omosafi/add-esm2-TE-to-HF branch from cfb5cae to f2f2247 on September 5, 2025 20:00

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 7

🧹 Nitpick comments (16)
models/esm2/export.py (3)

34-41: Drop redundant validation; argparse choices already enforce valid tags

choices=ESM_TAGS guarantees validity, so the manual check below is unreachable and can be removed.

-    if args.model:
-        if args.model not in ESM_TAGS:
-            print(f"Error: '{args.model}' is not a valid model tag.\nAvailable models: {', '.join(ESM_TAGS)}")
-            return
-        export_hf_checkpoint(args.model, Path(args.output_path))
+    if args.model:
+        export_hf_checkpoint(args.model, Path(args.output_path))

42-47: Create the output directory up front for clearer failures

Fail fast on permission/path issues rather than during export; harmless if it already exists.

 args = parser.parse_args()

+    # Ensure base output path exists
+    Path(args.output_path).mkdir(parents=True, exist_ok=True)

Also applies to: 49-49


58-59: Prefer logging over prints

Consider logging.info("Converting %s...", tag) for consistency and easier diagnostics. No functional change required.

models/esm2/export_te_checkpoint_to_hf.py (2)

41-43: Tighten CLI description

Minor copy edit (“HuggingFace Facebook … hosted on Hugging Face” is redundant).

-        description="Convert ESM2 models from Transformer Engine format back to HuggingFace Facebook ESM-2 format hosted on Hugging Face"
+        description="Convert ESM2 models from Transformer Engine (TE) format back to the original Hugging Face Facebook ESM-2 format"

22-38: Consider using logging instead of print and returning non-zero on failure

Replace prints with logging and sys.exit(1) in the except path to signal failure in scripts/CI.
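
Concretely, the suggested pattern might look like this (a sketch, not code from the PR; the wrapper name is ours):

```python
# Hedged sketch: log instead of print, and exit non-zero so scripts/CI see failures.
import logging
import sys

from esm.export import export_te_checkpoint  # from this PR's package

logger = logging.getLogger(__name__)


def run(te_checkpoint_dir: str, output_dir: str) -> int:
    try:
        export_te_checkpoint(te_checkpoint_dir, output_dir)
    except Exception:
        logger.exception("Error converting %s", te_checkpoint_dir)
        return 1
    logger.info("Exported %s to %s", te_checkpoint_dir, output_dir)
    return 0


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    sys.exit(run(sys.argv[1], sys.argv[2]))
```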

models/esm2/src/esm/export.py (3)

113-121: Preserve architectures in config for better HF UX

Setting "architectures": ["EsmForMaskedLM"] helps downstream loading tools and Hub metadata.

     if config_path.exists():
         with open(config_path, "r") as f:
             config = json.load(f)
         config.pop("auto_map", None)
         config["model_type"] = "esm"
+        config.setdefault("architectures", ["EsmForMaskedLM"])
         with open(config_path, "w") as f:
             json.dump(config, f, indent=2, sort_keys=True)

123-130: Make the smoke test optional to avoid OOMs on large checkpoints (e.g., 3B/15B)

Loading the full model in bf16 can be prohibitive in limited environments. Gate behind a flag with a safe default.

-def export_te_checkpoint(te_checkpoint_path: str, output_path: str):
+def export_te_checkpoint(te_checkpoint_path: str, output_path: str, *, smoke_test: bool = False):
@@
-    model_hf = AutoModelForMaskedLM.from_pretrained(
-        output_path,
-        torch_dtype=torch.bfloat16,
-        trust_remote_code=False,
-    )
-    del model_hf
-    gc.collect()
-    torch.cuda.empty_cache()
+    if smoke_test:
+        model_hf = AutoModelForMaskedLM.from_pretrained(
+            output_path,
+            torch_dtype=torch.bfloat16,
+            trust_remote_code=False,
+        )
+        del model_hf
+        gc.collect()
+        torch.cuda.empty_cache()

Callers (e.g., the new CLI) can pass smoke_test=True when resources allow.


100-111: Minor DRY: copy optional tokenizer artifacts via a loop

Reduces repetition; behavior unchanged.

-    tokenizer_path = Path(te_checkpoint_path) / "tokenizer.json"
-    if tokenizer_path.exists():
-        shutil.copy(tokenizer_path, Path(output_path) / "tokenizer.json")
-
-    tokenizer_config_path = Path(te_checkpoint_path) / "tokenizer_config.json"
-    if tokenizer_config_path.exists():
-        shutil.copy(tokenizer_config_path, Path(output_path) / "tokenizer_config.json")
-
-    vocab_path = Path(te_checkpoint_path) / "vocab.txt"
-    if vocab_path.exists():
-        shutil.copy(vocab_path, Path(output_path) / "vocab.txt")
+    for fname in ("tokenizer.json", "tokenizer_config.json", "vocab.txt"):
+        src = Path(te_checkpoint_path) / fname
+        if src.exists():
+            shutil.copy(src, Path(output_path) / fname)
models/esm2/README.md (1)

53-58: Clarify exported path handling.

Consider documenting that export_hf_checkpoint writes to te_checkpoint_path / <model tag>, so users can substitute different tags without editing the path literal. A short note avoids confusion.
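
If that is indeed the behavior, a short snippet along these lines would make the point (hypothetical, hedged on the per-tag subdirectory assumption):

```python
# Hedged illustration: assuming export_hf_checkpoint writes into <output>/<tag>,
# callers only change the tag, never the path literal.
from pathlib import Path

from esm.export import export_hf_checkpoint

output_root = Path("hf_to_te_checkpoint_export")
for tag in ("esm2_t6_8M_UR50D", "esm2_t12_35M_UR50D"):
    export_hf_checkpoint(tag, output_root)  # checkpoint lands under output_root / tag
```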

models/esm2/src/esm/convert.py (3)

115-118: Avoid mutating model structure unless necessary.

Deleting contact_head is safe for state_dict parity, but it’s not required for EsmForMaskedLM functionality. Prefer leaving the module intact and excluding its params via state_dict_ignored_entries to reduce surprises for downstream users introspecting model_hf.esm.

-        # Remove contact_head since it's not present in TE models
-        if hasattr(model_hf.esm, "contact_head"):
-            delattr(model_hf.esm, "contact_head")
+        # Keep contact_head module; its params are excluded below.

119-129: Consider ignoring non-parameter buffers as well.

If present, buffers like position_ids or inv_freq can be safely ignored to keep the transform strict on learnable weights only.

         state_dict_ignored_entries=[
             "lm_head.decoder.weight",
             "esm.contact_head.regression.weight",
             "esm.contact_head.regression.bias",
+            "esm.embeddings.position_ids",
+            "esm.encoder.layer.*.attention.self.rotary_emb.inv_freq",
         ],

148-160: Docstrings mention padding; update to describe QKV packing.

The functions pack QKV and do not pad. Update the docstrings to avoid confusion.

Would you like me to open a follow-up PR to update the docstrings and inline comments in _pack_qkv_weight/_pack_qkv_bias for accuracy?

Also applies to: 170-181

models/esm2/tests/test_convert_reverse.py (4)

25-42: Narrow the TE-layer check to encoder modules to avoid false positives.

Linear/LayerNorm may legitimately exist outside the fused TE encoder (e.g., embeddings/lm_head). Scope to ".encoder.layers." to reduce noise.

-    for name, module in model_te.named_modules():
+    for name, module in model_te.named_modules():
+        if ".encoder.layers." not in name:
+            continue
         if isinstance(module, nn.Linear):
             vanilla_layers_found.append(f"Linear layer found in {name}")
         if isinstance(module, nn.LayerNorm):
             vanilla_layers_found.append(f"LayerNorm layer found in {name}")

64-81: Mark export round-trip as slow/internet and guard TE import.

This test hits disk and remote hub; mark and guard to play well in CI.

-@pytest.mark.parametrize("model_name", ["esm2_t6_8M_UR50D"])
+@requires_internet
+@pytest.mark.slow
+@pytest.mark.parametrize("model_name", ["esm2_t6_8M_UR50D"])
 def test_export_te_checkpoint_to_hf(model_name):
@@
-    from esm.export import export_hf_checkpoint, export_te_checkpoint
+    from esm.export import export_hf_checkpoint, export_te_checkpoint
+    pytest.importorskip("transformer_engine", reason="Transformer Engine required for TE checkpoint export")

119-132: Strengthen config assertions; ensure TE-only keys are stripped.

Enable negative assertions to lock in expected behavior from convert_esm_te_to_hf.

     assert model_hf_converted.config.vocab_size == model_hf.config.vocab_size
-
-    # assert not hasattr(model_hf_converted.config, 'qkv_weight_interleaved')
-    # assert not hasattr(model_hf_converted.config, 'encoder_activation')
-    # assert not hasattr(model_hf_converted.config, 'attn_input_format')
-    # assert not hasattr(model_hf_converted.config, 'fuse_qkv_params')
-    # assert not hasattr(model_hf_converted.config, 'micro_batch_size')
-    # assert not hasattr(model_hf_converted.config, 'max_seq_length')
+    assert not hasattr(model_hf_converted.config, "qkv_weight_interleaved")
+    assert not hasattr(model_hf_converted.config, "encoder_activation")
+    assert not hasattr(model_hf_converted.config, "attn_input_format")
+    assert not hasattr(model_hf_converted.config, "fuse_qkv_params")
+    assert not hasattr(model_hf_converted.config, "micro_batch_size")
+    assert not hasattr(model_hf_converted.config, "max_seq_length")

59-62: Tolerance is tight but fine for fp32; consider dtype harmonization to avoid precision flakes.

If CI flakes occur, cast both state_dicts to float32 before comparisons to avoid bfloat16/float16 noise.

I can add a helper to clone-and-cast tensors before torch.testing.assert_close; want me to send a small patch?

Also applies to: 92-95, 114-116
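
One possible shape for that helper (the name and thresholds are ours, not from the PR):

```python
# Hedged sketch: clone-and-cast floating-point tensors to float32 before comparing.
import torch


def cast_state_dict_fp32(state_dict):
    """Return a copy of a state_dict with floating-point tensors cast to float32."""
    return {
        name: t.detach().clone().to(torch.float32)
        if torch.is_tensor(t) and t.is_floating_point()
        else t
        for name, t in state_dict.items()
    }


# Usage: compare per-key after casting, e.g.
#   torch.testing.assert_close(cast_state_dict_fp32(a)[k], cast_state_dict_fp32(b)[k],
#                              atol=1e-5, rtol=1e-5)
```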

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 3ffd142 and f2f2247.

📒 Files selected for processing (6)
  • models/esm2/README.md (1 hunks)
  • models/esm2/export.py (2 hunks)
  • models/esm2/export_te_checkpoint_to_hf.py (1 hunks)
  • models/esm2/src/esm/convert.py (4 hunks)
  • models/esm2/src/esm/export.py (2 hunks)
  • models/esm2/tests/test_convert_reverse.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
models/esm2/export_te_checkpoint_to_hf.py (2)
models/esm2/src/esm/export.py (1)
  • export_te_checkpoint (76-130)
models/esm2/export.py (1)
  • main (32-59)
models/esm2/src/esm/export.py (1)
models/esm2/src/esm/convert.py (2)
  • convert_esm_hf_to_te (51-76)
  • convert_esm_te_to_hf (79-137)
models/esm2/tests/test_convert_reverse.py (2)
models/esm2/src/esm/convert.py (2)
  • convert_esm_hf_to_te (51-76)
  • convert_esm_te_to_hf (79-137)
models/esm2/src/esm/export.py (2)
  • export_hf_checkpoint (28-73)
  • export_te_checkpoint (76-130)
models/esm2/export.py (1)
models/esm2/src/esm/export.py (1)
  • export_hf_checkpoint (28-73)
🔇 Additional comments (3)
models/esm2/src/esm/export.py (1)

24-26: Imports for reverse conversion look correct

Brings in convert_esm_te_to_hf and NVEsmForMaskedLM needed for TE→HF path.

models/esm2/src/esm/convert.py (2)

20-20: Good addition: explicit HF types.

Importing EsmConfig/EsmForMaskedLM locally removes reliance on AutoModel classes here.


47-49: Reverse mapping is fine.

The simple inversion matches the one-to-one entries; QKV is handled via transforms.

Comment on lines 16 to 22
import tempfile
from pathlib import Path

import pytest
import torch
from torch import nn
from transformers import AutoModelForMaskedLM


🛠️ Refactor suggestion

Add imports/markers for optional dependencies and offline CI.

Guard tests that require internet/TE, and avoid surprises in offline CI.

 import tempfile
 from pathlib import Path
 
 import pytest
 import torch
 from torch import nn
 from transformers import AutoModelForMaskedLM
+import os
+
+# Skip if offline (HF env flags)
+OFFLINE = os.environ.get("TRANSFORMERS_OFFLINE") or os.environ.get("HF_HUB_OFFLINE")
+requires_internet = pytest.mark.skipif(OFFLINE, reason="Requires internet to download HF models")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

Before:

import tempfile
from pathlib import Path

import pytest
import torch
from torch import nn
from transformers import AutoModelForMaskedLM

After:

import tempfile
from pathlib import Path

import pytest
import torch
from torch import nn
from transformers import AutoModelForMaskedLM
import os

# Skip if offline (HF env flags)
OFFLINE = os.environ.get("TRANSFORMERS_OFFLINE") or os.environ.get("HF_HUB_OFFLINE")
requires_internet = pytest.mark.skipif(
    OFFLINE, reason="Requires internet to download HF models"
)
🤖 Prompt for AI Agents
In models/esm2/tests/test_convert_reverse.py around lines 16 to 23, the test
imports assume optional dependencies (internet access, TorchEngine/TE) are
available; add guards so the test is skipped in offline CI or when optional
packages are missing. Use pytest.importorskip for optional libraries (or wrap
imports in try/except and call pytest.skip), and add appropriate
pytest.mark.skipif conditions (e.g., skip when an environment variable indicates
offline CI or when required network access is unavailable) so the test only runs
when dependencies and network are present.


@pstjohn pstjohn left a comment


could you make sure you rebase and run CI?

Comment on lines 25 to 41
def test_esm_model_has_all_te_layers():
    """Test that the converted TE model doesn't contain vanilla PyTorch layers."""
    from esm.convert import convert_esm_hf_to_te

    model_hf = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
    model_te = convert_esm_hf_to_te(model_hf)
    vanilla_layers_found = []
    for name, module in model_te.named_modules():
        if isinstance(module, nn.Linear):
            vanilla_layers_found.append(f"Linear layer found in {name}")
        if isinstance(module, nn.LayerNorm):
            vanilla_layers_found.append(f"LayerNorm layer found in {name}")
    if vanilla_layers_found:
        print("ERROR: Found vanilla PyTorch layers in converted TE model:")
        for error in vanilla_layers_found:
            print(f"WARNING: {error}")
    assert not vanilla_layers_found, f"Found {len(vanilla_layers_found)} vanilla layers in converted model"
Collaborator

this is a duplicate test, right? we have test_esm_model_has_all_te_layers in test_modeling_esm_te? do we need this in a test that exports from the TE model to HF?

Member Author

Yes, this test is redundant, I removed it

Comment on lines 134 to 139
# assert not hasattr(model_hf_converted.config, 'qkv_weight_interleaved')
# assert not hasattr(model_hf_converted.config, 'encoder_activation')
# assert not hasattr(model_hf_converted.config, 'attn_input_format')
# assert not hasattr(model_hf_converted.config, 'fuse_qkv_params')
# assert not hasattr(model_hf_converted.config, 'micro_batch_size')
# assert not hasattr(model_hf_converted.config, 'max_seq_length')
Collaborator

any reason these are commented out?

Member Author

Thanks, forgot to uncomment that. I rewrote this section

@ohadmo ohadmo force-pushed the omosafi/add-esm2-TE-to-HF branch from 22c7739 to 27365b0 on September 10, 2025 23:34
Signed-off-by: Ohad Mosafi <[email protected]>
@ohadmo ohadmo force-pushed the omosafi/add-esm2-TE-to-HF branch from 27365b0 to 3af0c7d on September 10, 2025 23:40

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (1)
models/esm2/tests/test_convert_reverse.py (1)

16-23: Skip gracefully when offline or TE is unavailable.

These tests download HF models and rely on Transformer Engine; add guards to avoid CI flakes.

 import tempfile
 from pathlib import Path
 
 import pytest
 import torch
 from transformers import AutoModelForMaskedLM
+import os
+from importlib.util import find_spec
+
+# Skip if offline (HF env flags)
+OFFLINE = os.environ.get("TRANSFORMERS_OFFLINE") or os.environ.get("HF_HUB_OFFLINE")
+requires_internet = pytest.mark.skipif(OFFLINE, reason="Requires internet to download HF models")
+
+# Skip if Transformer Engine is missing
+TE_MISSING = find_spec("transformer_engine") is None
+requires_te = pytest.mark.skipif(TE_MISSING, reason="Requires transformer_engine")
🧹 Nitpick comments (7)
models/esm2/export_te_checkpoint_to_hf.py (3)

26-27: Validate the checkpoint path is a directory (not just exists).

Avoid confusing errors when a file path is passed.

-    if not Path(te_checkpoint_dir).exists():
-        raise FileNotFoundError(f"TE checkpoint {te_checkpoint_dir} not found")
+    te_path = Path(te_checkpoint_dir)
+    if not te_path.exists():
+        raise FileNotFoundError(f"TE checkpoint {te_checkpoint_dir} not found")
+    if not te_path.is_dir():
+        raise NotADirectoryError(f"TE checkpoint path must be a directory: {te_checkpoint_dir}")

33-35: Preserve original traceback when re-raising.

Use a bare raise to keep the original stack.

-    except Exception as e:
-        print(f"Error converting {te_checkpoint_dir}: {e}")
-        raise e
+    except Exception as e:
+        print(f"Error converting {te_checkpoint_dir}: {e}")
+        raise

45-49: Clarify that --model expects a directory.

Small wording tweak improves UX.

-        help="Path to the TE checkpoint.",
+        help="Path to the TE checkpoint directory.",
models/esm2/tests/test_convert_reverse.py (4)

24-27: Mark test as requiring internet and TE.

- def test_convert_te_to_hf_roundtrip():
+@requires_internet
+@requires_te
+def test_convert_te_to_hf_roundtrip():

44-47: Mark test as requiring internet and TE.

-@pytest.mark.parametrize("model_name", ["esm2_t6_8M_UR50D"])
-def test_export_te_checkpoint_to_hf(model_name):
+@pytest.mark.parametrize("model_name", ["esm2_t6_8M_UR50D"])
+@requires_internet
+@requires_te
+def test_export_te_checkpoint_to_hf(model_name):

77-80: Mark test as requiring internet and TE.

-def test_qkv_unpacking():
+@requires_internet
+@requires_te
+def test_qkv_unpacking():

99-103: Mark test as requiring internet and TE.

-def test_config_conversion():
+@requires_internet
+@requires_te
+def test_config_conversion():
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f2f2247 and 27365b0.

📒 Files selected for processing (6)
  • models/esm2/README.md (1 hunks)
  • models/esm2/export.py (2 hunks)
  • models/esm2/export_te_checkpoint_to_hf.py (1 hunks)
  • models/esm2/src/esm/convert.py (5 hunks)
  • models/esm2/src/esm/export.py (2 hunks)
  • models/esm2/tests/test_convert_reverse.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (4)
  • models/esm2/export.py
  • models/esm2/src/esm/export.py
  • models/esm2/README.md
  • models/esm2/src/esm/convert.py
🧰 Additional context used
🧬 Code graph analysis (2)
models/esm2/tests/test_convert_reverse.py (2)
models/esm2/src/esm/convert.py (2)
  • convert_esm_hf_to_te (51-76)
  • convert_esm_te_to_hf (79-140)
models/esm2/src/esm/export.py (2)
  • export_hf_checkpoint (28-73)
  • export_te_checkpoint (76-118)
models/esm2/export_te_checkpoint_to_hf.py (1)
models/esm2/src/esm/export.py (1)
  • export_te_checkpoint (76-118)
🔇 Additional comments (1)
models/esm2/export_te_checkpoint_to_hf.py (1)

16-19: Ensure 'esm' import works when running the script directly

File: models/esm2/export_te_checkpoint_to_hf.py — Repo search found no esm/export.py or convert.py and no PYTHONPATH/CI hints beyond pyproject addopts. Running python models/esm2/export_te_checkpoint_to_hf.py may raise ModuleNotFoundError: esm unless models/esm2/src (or the package root) is on PYTHONPATH or the package is installed. Options: export PYTHONPATH in CI/docs, make the script runnable via python -m or an installed entrypoint, or add a short sys.path fallback in the script.
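
For example, the sys.path fallback floated above could look like this (a sketch that assumes the models/esm2/src layout; not code from the PR):

```python
# Hedged sketch: make `import esm` work when the script is run directly from models/esm2/.
import sys
from pathlib import Path

_SRC_DIR = Path(__file__).resolve().parent / "src"
if _SRC_DIR.is_dir() and str(_SRC_DIR) not in sys.path:
    sys.path.insert(0, str(_SRC_DIR))

from esm.export import export_te_checkpoint  # noqa: E402
```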


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (1)
models/esm2/src/esm/convert.py (1)

97-113: Map max_seq_length→max_position_embeddings and drop export-only dtype; safer EsmConfig.

Right now max_seq_length is popped but not mapped; if TE configs don't already carry max_position_embeddings, HF will fall back to defaults. Also strip torch_dtype to avoid passing runtime dtype hints into config construction.

Apply:

-    # Remove TE-specific config options
-    te_specific_keys = [
+    # Map TE naming to HF before cleanup
+    if "max_position_embeddings" not in hf_config_dict and "max_seq_length" in hf_config_dict:
+        hf_config_dict["max_position_embeddings"] = hf_config_dict["max_seq_length"]
+
+    # Remove TE-specific and export-only config options
+    te_specific_keys = [
         "qkv_weight_interleaved",
         "encoder_activation",
         "attn_input_format",
         "fuse_qkv_params",
         "micro_batch_size",
         "max_seq_length",
         "model_type",
         "auto_map",
+        "torch_dtype",
     ]
     for key in te_specific_keys:
         hf_config_dict.pop(key, None)
 
     hf_config_dict["model_type"] = "esm"
🧹 Nitpick comments (6)
models/esm2/src/esm/convert.py (4)

151-163: Make QKV weight packing shape-safe and fix the docstring.

The current code uses -1 and a “Pad …” docstring, which is misleading for a pack operation. Add explicit shape checks to mirror the safety of the unpack path.

Apply:

 @io.state_transform(
@@
 def _pack_qkv_weight(ctx: io.TransformCTX, query, key, value):
-    """Pad the embedding layer to the new input dimension."""
-    concat_weights = torch.cat((query, key, value), dim=0)
-    input_shape = concat_weights.size()
-    num_heads = ctx.target.config.num_attention_heads
-    # transpose weights
-    # [sequence length, batch size, num_splits_model_parallel * attention head size * #attention heads]
-    # --> [sequence length, batch size, attention head size * num_splits_model_parallel * #attention heads]
-    concat_weights = concat_weights.view(3, num_heads, -1, query.size()[-1])
-    concat_weights = concat_weights.transpose(0, 1).contiguous()
-    concat_weights = concat_weights.view(*input_shape)
-    return concat_weights
+    """Pack separate Q,K,V weights into TE's fused [hidden_size*3, input_dim] layout."""
+    concat = torch.cat((query, key, value), dim=0)  # [3*hidden_size, input_dim]
+    num_heads = ctx.target.config.num_attention_heads
+    total_rows, input_dim = concat.size()
+    if total_rows % (3 * num_heads) != 0:
+        raise ValueError(f"QKV weight rows {total_rows} not divisible by 3*num_heads {3*num_heads}")
+    head_dim = total_rows // (3 * num_heads)
+    fused = concat.view(3, num_heads, head_dim, input_dim).transpose(0, 1).contiguous()
+    return fused.reshape(total_rows, input_dim)

173-185: Mirror the packing safety for biases and correct the docstring.

Same rationale as weights.

Apply:

 def _pack_qkv_bias(ctx: io.TransformCTX, query, key, value):
-    """Pad the embedding layer to the new input dimension."""
-    concat_biases = torch.cat((query, key, value), dim=0)
-    input_shape = concat_biases.size()
-    num_heads = ctx.target.config.num_attention_heads
-    # transpose biases
-    # [num_splits_model_parallel * attention head size * #attention heads]
-    # --> [attention head size * num_splits_model_parallel * #attention heads]
-    concat_biases = concat_biases.view(3, num_heads, -1)
-    concat_biases = concat_biases.transpose(0, 1).contiguous()
-    concat_biases = concat_biases.view(*input_shape)
-    return concat_biases
+    """Pack separate Q,K,V biases into TE's fused [hidden_size*3] layout."""
+    concat = torch.cat((query, key, value), dim=0)  # [3*hidden_size]
+    num_heads = ctx.target.config.num_attention_heads
+    total = concat.size(0)
+    if total % (3 * num_heads) != 0:
+        raise ValueError(f"QKV bias size {total} not divisible by 3*num_heads {3*num_heads}")
+    head_dim = total // (3 * num_heads)
+    fused = concat.view(3, num_heads, head_dim).transpose(0, 1).contiguous()
+    return fused.reshape(total)

195-211: Prefer explicit ValueError over assert for shape validation.

Asserts can be stripped with Python -O. Use explicit checks for reliability.

Apply:

-    assert total_rows % (3 * num_heads) == 0, (
-        f"QKV weight rows {total_rows} not divisible by 3*num_heads {3*num_heads}"
-    )
+    if total_rows % (3 * num_heads) != 0:
+        raise ValueError(
+            f"QKV weight rows {total_rows} not divisible by 3*num_heads {3*num_heads}"
+        )

222-238: Same: raise ValueError instead of assert for bias shape checks.

Apply:

-    assert total_size % (3 * num_heads) == 0, (
-        f"QKV bias size {total_size} not divisible by 3*num_heads {3*num_heads}"
-    )
+    if total_size % (3 * num_heads) != 0:
+        raise ValueError(
+            f"QKV bias size {total_size} not divisible by 3*num_heads {3*num_heads}"
+        )
models/esm2/export_te_checkpoint_to_hf.py (1)

22-38: Avoid noisy try/except or re-raise plainly; optionally guard against same input/output path.

  • Printing and re-raising adds noise without better diagnostics. Either let exceptions bubble or raise without binding.
  • Prevent accidental mixing of TE and HF artifacts by forbidding identical input/output directories.

Apply:

 def main(te_checkpoint_dir: str, output_dir: str):
@@
-    try:
-        export_te_checkpoint(te_checkpoint_dir, output_dir)
-        print(f"Successfully exported {te_checkpoint_dir} to HuggingFace Facebook ESM-2 format at {output_dir}")
-    except Exception as e:
-        print(f"Error converting {te_checkpoint_dir}: {e}")
-        raise e
+    if Path(te_checkpoint_dir).resolve() == Path(output_dir).resolve():
+        raise ValueError("output_dir must differ from te_checkpoint_dir to avoid mixing artifacts.")
+
+    export_te_checkpoint(te_checkpoint_dir, output_dir)
+    print(f"Successfully exported {te_checkpoint_dir} to HuggingFace Facebook ESM-2 format at {output_dir}")
models/esm2/src/esm/export.py (1)

76-118: Reduce memory during smoke test load.

Loading the converted HF model just to validate can be made lighter to avoid transient OOMs in CI.

Apply:

-    model_hf = AutoModelForMaskedLM.from_pretrained(
-        output_path,
-        torch_dtype=torch.bfloat16,
-        trust_remote_code=False,
-    )
+    model_hf = AutoModelForMaskedLM.from_pretrained(
+        output_path,
+        torch_dtype=torch.bfloat16,
+        trust_remote_code=False,
+        low_cpu_mem_usage=True,
+        device_map="cpu",
+    )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 27365b0 and 3af0c7d.

📒 Files selected for processing (5)
  • models/esm2/README.md (1 hunks)
  • models/esm2/export_te_checkpoint_to_hf.py (1 hunks)
  • models/esm2/src/esm/convert.py (5 hunks)
  • models/esm2/src/esm/export.py (2 hunks)
  • models/esm2/tests/test_convert_reverse.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • models/esm2/tests/test_convert_reverse.py
  • models/esm2/README.md
🧰 Additional context used
🧬 Code graph analysis (2)
models/esm2/export_te_checkpoint_to_hf.py (2)
models/esm2/src/esm/export.py (1)
  • export_te_checkpoint (76-117)
models/esm2/export.py (1)
  • main (32-59)
models/esm2/src/esm/export.py (1)
models/esm2/src/esm/convert.py (2)
  • convert_esm_hf_to_te (51-76)
  • convert_esm_te_to_hf (79-140)
🔇 Additional comments (2)
models/esm2/src/esm/convert.py (1)

47-49: Reverse mapping looks correct and scoped well.

Simple and reliable way to derive TE→HF key mapping; avoids drift.
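
For readers skimming the thread, the inversion pattern being approved is essentially a dict comprehension over the forward mapping (the keys below are made up for illustration; the real mapping lives in src/esm/convert.py):

```python
# Illustrative only: derive the TE->HF name mapping by inverting the HF->TE dict,
# which works because the non-QKV entries are one-to-one.
hf_to_te_mapping = {
    "esm.embeddings.word_embeddings.weight": "esm.encoder.word_embeddings.weight",  # hypothetical
    "lm_head.dense.weight": "lm_head.dense.weight",  # hypothetical
}
te_to_hf_mapping = {te_name: hf_name for hf_name, te_name in hf_to_te_mapping.items()}
```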

models/esm2/export_te_checkpoint_to_hf.py (1)

45-55: CLI help is clear; required model arg prevents Path(None) crash.

No further changes needed.

@pstjohn pstjohn added this to the LHA-ready milestone Sep 18, 2025
@ohadmo ohadmo force-pushed the omosafi/add-esm2-TE-to-HF branch from 2073094 to c59c7ba on September 30, 2025 23:49

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (4)
bionemo-recipes/models/esm2/src/esm/export.py (1)

140-140: Consider explicit dtype for consistency.

The TE model is loaded without specifying a dtype, while the smoke test at line 158 uses torch.bfloat16. For consistency with export_hf_checkpoint (line 114) and to avoid potential dtype mismatches, consider specifying the dtype explicitly.

Apply this diff:

-    model_te = NVEsmForMaskedLM.from_pretrained(te_checkpoint_path)
+    model_te = NVEsmForMaskedLM.from_pretrained(te_checkpoint_path, torch_dtype=torch.bfloat16)
bionemo-recipes/models/esm2/export.py (2)

67-69: Remove redundant validation.

This validation is unnecessary because argparse with choices=ESM_TAGS already ensures that args.model is valid. The code at line 66 only executes when args.model is not None and has passed the choices validation.

Apply this diff:

         if args.model:
-            if args.model not in ESM_TAGS:
-                print(f"Error: '{args.model}' is not a valid model tag.\nAvailable models: {', '.join(ESM_TAGS)}")
-                return
             print(f"Converting {args.model} from HuggingFace to Transformer Engine format")
             export_hf_checkpoint(args.model, Path(args.output_path))

78-79: Remove duplicate path validation.

This path existence check duplicates the validation already performed in export_te_checkpoint at lines 134-135 of bionemo-recipes/models/esm2/src/esm/export.py. The function will raise a clear FileNotFoundError if the path doesn't exist.

Apply this diff:

         print(f"Converting {args.checkpoint_path} from Transformer Engine to HuggingFace format")
-        if not Path(args.checkpoint_path).exists():
-            raise FileNotFoundError(f"TE checkpoint {args.checkpoint_path} not found")
         export_te_checkpoint(args.checkpoint_path, Path(args.output_path))
bionemo-recipes/models/esm2/tests/test_convert_reverse.py (1)

67-67: Remove commented code.

The commented assertion at line 67 is replaced by the logic at lines 68-70. Remove the dead code.

Apply this diff:

-        # assert original_state_dict.keys() == exported_state_dict.keys()
         original_keys = {k for k in original_state_dict.keys() if "contact_head" not in k}
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3af0c7d and c59c7ba.

📒 Files selected for processing (4)
  • bionemo-recipes/models/esm2/export.py (2 hunks)
  • bionemo-recipes/models/esm2/src/esm/convert.py (7 hunks)
  • bionemo-recipes/models/esm2/src/esm/export.py (2 hunks)
  • bionemo-recipes/models/esm2/tests/test_convert_reverse.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
bionemo-recipes/models/esm2/export.py (1)
bionemo-recipes/models/esm2/src/esm/export.py (1)
  • export_te_checkpoint (122-163)
bionemo-recipes/models/esm2/src/esm/export.py (2)
bionemo-recipes/models/esm2/src/esm/convert.py (1)
  • convert_esm_te_to_hf (76-137)
bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py (1)
  • NVEsmForMaskedLM (415-508)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Analyze (rust)
🔇 Additional comments (22)
bionemo-recipes/models/esm2/src/esm/export.py (5)

25-26: LGTM!

The new imports support the TE→HF conversion functionality.


134-135: LGTM!

Appropriate validation of input path existence.


144-154: LGTM!

Optional tokenizer file copying is appropriate and handles missing files gracefully.


156-163: Smoke test validates model loading.

The smoke test successfully verifies that the exported HF model can be loaded with the standard Transformers API. Memory cleanup is properly performed.


142-142: Inspect save_pretrained via Python

Run the following to dump its source and verify directory‐creation behavior:

#!/bin/bash
python - << 'EOF'
import inspect
from transformers import PreTrainedModel
print(inspect.getsource(PreTrainedModel.save_pretrained))
EOF

This will reveal whether missing output directories are auto‐created; if not, add an explicit os.makedirs(output_path, exist_ok=True) before save_pretrained.

bionemo-recipes/models/esm2/export.py (2)

16-19: LGTM!

Imports support the new bidirectional CLI functionality.


65-80: Verify argument name consistency after fix.

If the argument naming is standardized to use hyphens (per the previous comment), ensure attribute access uses underscores: --output-path is accessed as args.output_path (argparse converts hyphens to underscores in attribute names).
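
A quick check of that argparse behavior:

```python
# argparse maps hyphenated long options to underscore attribute names on the Namespace.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--output-path")
args = parser.parse_args(["--output-path", "exported/"])
assert args.output_path == "exported/"
```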

bionemo-recipes/models/esm2/tests/test_convert_reverse.py (6)

24-42: LGTM!

Comprehensive roundtrip test with appropriate exclusions for contact_head (not preserved in TE), _extra_state, and inv_freq buffers. Tolerance settings are reasonable for float precision.


44-75: Comprehensive export roundtrip validation.

The test properly validates the full export workflow including checkpoint persistence and reloading. Temp directory usage ensures test isolation.


77-97: LGTM!

Thorough validation of QKV unpacking across all transformer layers with appropriate numerical tolerance.


99-126: LGTM!

Comprehensive config conversion validation ensuring field preservation and proper cleanup of TE-specific fields.


160-163: Padding validation is thorough.

The test correctly validates that padded embedding rows are zero-initialized. Note that the comment at line 160 mentions "min values (for bias)" but the test only validates embedding padding; bias padding validation could be added if needed.


128-164: LGTM!

Comprehensive validation of padding/unpadding operations for embeddings, decoder weights, and biases with proper shape and value checks.

bionemo-recipes/models/esm2/src/esm/convert.py (9)

20-20: LGTM!

Required imports for HF model reconstruction in the TE→HF conversion path.


44-45: LGTM!

Clean reversal of the mapping dictionary to support TE→HF parameter name translation.


76-137: Well-structured TE→HF conversion.

The function properly handles config conversion, model initialization, and parameter transformation. The removal of TE-specific config keys and contact_head is appropriate since contact_head is not preserved in TE models.


148-159: LGTM!

Properly retrieves num_heads from the target config for QKV weight packing.


170-181: LGTM!

Consistent with QKV weight packing, properly uses config for num_heads.


184-208: LGTM!

Correct unpacking logic that inverts the QKV fusion performed in _pack_qkv_weight. The shape assertion provides good validation.


211-235: LGTM!

Correct bias unpacking that mirrors the weight unpacking logic and inverts _pack_qkv_bias.


238-273: LGTM!

Clean implementation of unpadding logic that correctly inverts the padding operations. Appropriate application to both embeddings and decoder weights.


291-298: LGTM!

Correct unpadding of bias tensor, inverting the padding operation from _pad_bias.

Signed-off-by: Ohad Mosafi <[email protected]>
@ohadmo ohadmo force-pushed the omosafi/add-esm2-TE-to-HF branch from 1ae584e to b6af4e8 on October 2, 2025 02:21
Signed-off-by: Ohad Mosafi <[email protected]>

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c59c7ba and ea890a2.

📒 Files selected for processing (5)
  • bionemo-recipes/models/esm2/.dockerignore (1 hunks)
  • bionemo-recipes/models/esm2/README.md (2 hunks)
  • bionemo-recipes/models/esm2/src/esm/export.py (2 hunks)
  • bionemo-recipes/models/esm2/tests/test_convert.py (1 hunks)
  • bionemo-recipes/models/esm2/tests/test_export.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
bionemo-recipes/models/esm2/tests/test_convert.py (1)
bionemo-recipes/models/esm2/src/esm/convert.py (2)
  • convert_esm_hf_to_te (48-73)
  • convert_esm_te_to_hf (76-137)
bionemo-recipes/models/esm2/src/esm/export.py (2)
bionemo-recipes/models/esm2/src/esm/convert.py (1)
  • convert_esm_te_to_hf (76-137)
bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py (1)
  • NVEsmForMaskedLM (415-508)
bionemo-recipes/models/esm2/tests/test_export.py (1)
bionemo-recipes/models/esm2/src/esm/export.py (2)
  • export_hf_checkpoint (54-119)
  • export_te_checkpoint (122-163)
🔇 Additional comments (1)
bionemo-recipes/models/esm2/.dockerignore (1)

3-4: Ignore rules look good

Keeping the TE↔HF export outputs out of build context avoids Docker bloat.

Comment on lines 144 to 154
    tokenizer_config_path = Path(te_checkpoint_path) / "tokenizer_config.json"
    if tokenizer_config_path.exists():
        shutil.copy(tokenizer_config_path, Path(output_path) / "tokenizer_config.json")

    vocab_path = Path(te_checkpoint_path) / "vocab.txt"
    if vocab_path.exists():
        shutil.copy(vocab_path, Path(output_path) / "vocab.txt")

    special_tokens_path = Path(te_checkpoint_path) / "special_tokens_map.json"
    if special_tokens_path.exists():
        shutil.copy(special_tokens_path, Path(output_path) / "special_tokens_map.json")

⚠️ Potential issue | 🟠 Major

Copy all tokenizer artifacts during TE→HF export

Only copying tokenizer_config.json, vocab.txt, and special_tokens_map.json drops other tokenizer assets (tokenizer.json, added_tokens.json, merges.txt, etc.). If the TE checkpoint was trained with added tokens or relies on the fast tokenizer, the exported HF checkpoint will be broken or silently lose vocabulary entries. Please copy every tokenizer file alongside the model weights.

-    tokenizer_config_path = Path(te_checkpoint_path) / "tokenizer_config.json"
-    if tokenizer_config_path.exists():
-        shutil.copy(tokenizer_config_path, Path(output_path) / "tokenizer_config.json")
-
-    vocab_path = Path(te_checkpoint_path) / "vocab.txt"
-    if vocab_path.exists():
-        shutil.copy(vocab_path, Path(output_path) / "vocab.txt")
-
-    special_tokens_path = Path(te_checkpoint_path) / "special_tokens_map.json"
-    if special_tokens_path.exists():
-        shutil.copy(special_tokens_path, Path(output_path) / "special_tokens_map.json")
+    tokenizer_files = [
+        "tokenizer.json",
+        "tokenizer_config.json",
+        "special_tokens_map.json",
+        "added_tokens.json",
+        "vocab.txt",
+        "merges.txt",
+    ]
+    for filename in tokenizer_files:
+        src = Path(te_checkpoint_path) / filename
+        if src.exists():
+            shutil.copy(src, Path(output_path) / filename)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

Before:

    tokenizer_config_path = Path(te_checkpoint_path) / "tokenizer_config.json"
    if tokenizer_config_path.exists():
        shutil.copy(tokenizer_config_path, Path(output_path) / "tokenizer_config.json")

    vocab_path = Path(te_checkpoint_path) / "vocab.txt"
    if vocab_path.exists():
        shutil.copy(vocab_path, Path(output_path) / "vocab.txt")

    special_tokens_path = Path(te_checkpoint_path) / "special_tokens_map.json"
    if special_tokens_path.exists():
        shutil.copy(special_tokens_path, Path(output_path) / "special_tokens_map.json")

After:

    tokenizer_files = [
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "added_tokens.json",
        "vocab.txt",
        "merges.txt",
    ]
    for filename in tokenizer_files:
        src = Path(te_checkpoint_path) / filename
        if src.exists():
            shutil.copy(src, Path(output_path) / filename)

@ohadmo ohadmo requested a review from pstjohn October 2, 2025 02:49

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
bionemo-recipes/models/esm2/export.py (1)

62-62: Remove trailing whitespace.

Line 62 has trailing whitespace after the closing parenthesis.

Apply this diff:

-    )
+    )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ea890a2 and 05fbec7.

📒 Files selected for processing (2)
  • bionemo-recipes/models/esm2/export.py (2 hunks)
  • bionemo-recipes/models/esm2/tests/test_export.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • bionemo-recipes/models/esm2/tests/test_export.py
🧰 Additional context used
🧬 Code graph analysis (1)
bionemo-recipes/models/esm2/export.py (1)
bionemo-recipes/models/esm2/src/esm/export.py (1)
  • export_te_checkpoint (122-163)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Analyze (rust)
🔇 Additional comments (2)
bionemo-recipes/models/esm2/export.py (2)

16-19: LGTM! Imports are correct.

The imports are properly organized and align with the module structure. Both export_hf_checkpoint and export_te_checkpoint are correctly imported from esm.export.


38-61: LGTM! Previous issues resolved.

The argument naming inconsistency and incorrect help text from the previous review have been properly addressed:

  • Both subparsers now consistently use --output-path (hyphen style)
  • The te-to-hf help text correctly identifies the checkpoint as "Transformer Engine"

@ohadmo ohadmo force-pushed the omosafi/add-esm2-TE-to-HF branch from 7fb6169 to f4abfba on October 2, 2025 16:37

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
bionemo-recipes/models/esm2/src/esm/export.py (1)

144-154: Copy all tokenizer artifacts during TE→HF export.

Only copying tokenizer_config.json, vocab.txt, and special_tokens_map.json drops other tokenizer assets (tokenizer.json, added_tokens.json, merges.txt, etc.). If the TE checkpoint includes added tokens or relies on the fast tokenizer, the exported HF checkpoint will be broken or silently lose vocabulary entries.

Apply this diff to copy all tokenizer files:

-    tokenizer_config_path = Path(te_checkpoint_path) / "tokenizer_config.json"
-    if tokenizer_config_path.exists():
-        shutil.copy(tokenizer_config_path, output_path / "tokenizer_config.json")
-
-    vocab_path = Path(te_checkpoint_path) / "vocab.txt"
-    if vocab_path.exists():
-        shutil.copy(vocab_path, output_path / "vocab.txt")
-
-    special_tokens_path = Path(te_checkpoint_path) / "special_tokens_map.json"
-    if special_tokens_path.exists():
-        shutil.copy(special_tokens_path, output_path / "special_tokens_map.json")
+    tokenizer_files = [
+        "tokenizer.json",
+        "tokenizer_config.json",
+        "special_tokens_map.json",
+        "added_tokens.json",
+        "vocab.txt",
+        "merges.txt",
+    ]
+    for filename in tokenizer_files:
+        src = Path(te_checkpoint_path) / filename
+        if src.exists():
+            shutil.copy(src, output_path / filename)
🧹 Nitpick comments (1)
bionemo-recipes/models/esm2/tests/test_export.py (1)

74-103: Expand test coverage for TE-to-HF export.

The test validates basic roundtrip conversion but lacks coverage for:

  • Tokenizer file preservation (the past review comment flagged missing tokenizer files)
  • Config attribute validation beyond state_dict comparison
  • Error handling when TE checkpoint is missing or corrupted
  • Multiple model sizes (currently only tests esm2_t6_8M_UR50D)

Consider adding these test cases:

@pytest.mark.parametrize("model_name", [
    "esm2_t6_8M_UR50D",
    "esm2_t12_35M_UR50D",  # Add more model sizes
])
def test_export_te_checkpoint_to_hf(model_name):
    # ... existing test code ...
    
    # Add tokenizer validation
    original_tokenizer = AutoTokenizer.from_pretrained(f"facebook/{model_name}")
    exported_tokenizer = AutoTokenizer.from_pretrained(str(hf_export_path))
    assert original_tokenizer.vocab == exported_tokenizer.vocab
    
def test_export_te_checkpoint_missing_path():
    """Test error handling for missing TE checkpoint."""
    from esm.export import export_te_checkpoint
    
    with pytest.raises(FileNotFoundError):
        export_te_checkpoint("/nonexistent/path", Path("/tmp/output"))
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 05fbec7 and f4abfba.

📒 Files selected for processing (3)
  • bionemo-recipes/models/esm2/README.md (2 hunks)
  • bionemo-recipes/models/esm2/src/esm/export.py (2 hunks)
  • bionemo-recipes/models/esm2/tests/test_export.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
bionemo-recipes/models/esm2/src/esm/export.py (2)
bionemo-recipes/models/esm2/src/esm/convert.py (2)
  • convert_esm_hf_to_te (48-73)
  • convert_esm_te_to_hf (76-137)
bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py (1)
  • NVEsmForMaskedLM (415-508)
bionemo-recipes/models/esm2/tests/test_export.py (1)
bionemo-recipes/models/esm2/src/esm/export.py (2)
  • export_hf_checkpoint (54-119)
  • export_te_checkpoint (122-163)
🔇 Additional comments (6)
bionemo-recipes/models/esm2/README.md (2)

80-99: LGTM! Clear conversion commands.

The updated conversion commands correctly reflect the new CLI structure with hf-to-te and te-to-hf subcommands, and the volume mount paths are properly separated for each conversion direction.


101-181: LGTM! Comprehensive developer workflow.

The developer workflow section provides a complete end-to-end example with proper validation steps. The code examples correctly demonstrate:

  • Loading original HF models
  • Exporting to TE format
  • Converting back to HF format
  • Validating equivalence with numerical comparison
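
A rough, self-contained sketch of that roundtrip validation (ours, not copied from the README; tolerances are illustrative):

```python
# Hedged sketch of the HF -> TE -> HF validation summarized above.
import torch
from transformers import AutoModelForMaskedLM

from esm.convert import convert_esm_hf_to_te, convert_esm_te_to_hf

reference = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
model_te = convert_esm_hf_to_te(AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D"))
roundtrip = convert_esm_te_to_hf(model_te)

ref_state = reference.state_dict()
for name, tensor in roundtrip.state_dict().items():
    if "contact_head" in name:  # the contact head is not preserved through TE, per this PR
        continue
    torch.testing.assert_close(
        tensor.float(), ref_state[name].float(), atol=1e-5, rtol=1e-5
    )
print("Roundtrip weights match the original HF checkpoint.")
```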
bionemo-recipes/models/esm2/src/esm/export.py (4)

25-26: LGTM! Correct imports for TE-to-HF conversion.

The imports correctly bring in the conversion function and TE model class needed for the export functionality.


122-133: LGTM! Clear function signature and documentation.

The docstring clearly explains the function's purpose and the distinction between TE format (also HF-compatible) and the original Facebook ESM-2 format.


134-142: LGTM! Proper error handling and conversion flow.

The function correctly checks for checkpoint existence, loads the TE model, converts it, and saves the result.


156-163: LGTM! Proper smoke test and cleanup.

The function correctly validates that the exported model can be loaded and properly cleans up resources with garbage collection and CUDA cache clearing.

Signed-off-by: Ohad Mosafi <[email protected]>
@ohadmo ohadmo force-pushed the omosafi/add-esm2-TE-to-HF branch from f4abfba to 2cf12c6 on October 2, 2025 17:02
Comment on lines +80 to +85
 mkdir -p hf_to_te_checkpoint_export
 docker build -t esm2 .
 docker run --rm -it --gpus all \
-  -v $PWD/checkpoint_export/:/workspace/bionemo/checkpoint_export \
+  -v $PWD/hf_to_te_checkpoint_export/:/workspace/bionemo/hf_to_te_checkpoint_export \
   -v $HOME/.cache/huggingface/:/root/.cache/huggingface \
-  esm2 python export.py
+  esm2 python export.py hf-to-te
Collaborator

i don't think we need to make folders for the te-to-hf path, we just need to show how to take an existing TE model and convert it to HF (e.g., through a python code demo). Folks likely won't convert our TE models back to HF

Comment on lines +90 to +99
```bash
MODEL_TAG=esm2_t6_8M_UR50D # specify which model to convert
mkdir -p te_to_hf_checkpoint_export
docker build -t esm2 .
docker run --rm -it --gpus all \
-v $PWD/te_to_hf_checkpoint_export/:/workspace/bionemo/te_to_hf_checkpoint_export \
-v $PWD/hf_to_te_checkpoint_export/$MODEL_TAG:/workspace/bionemo/hf_to_te_checkpoint_export/$MODEL_TAG \
-v $HOME/.cache/huggingface/:/root/.cache/huggingface \
esm2 python export.py te-to-hf --checkpoint-path /workspace/bionemo/hf_to_te_checkpoint_export/$MODEL_TAG
```
Collaborator

i think this could be just something like

from esm.convert import convert_esm_te_to_hf

te_model = AutoModel.from_pretrained("/path/to/te/checkpoint", trust_remote_code=True)
hf_model = convert_esm_te_to_hf(te_model)
hf_model.save_pretrained("/path/to/exported/model")

from esm.export import export_hf_checkpoint

te_checkpoint_path = Path("te_checkpoint")
export_hf_checkpoint("esm2_t6_8M_UR50D", te_checkpoint_path)
Collaborator

this isn't as obvious a function name as convert_esm_te_to_hf, let's just keep the existing API and add the reverse path?

pstjohn added a commit that referenced this pull request Oct 9, 2025
…version (#1218)

Adds convert.py and test_convert.py from @ohadmo's PR and makes some
readme updates

Takes many changes from #1121

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Added bidirectional conversion between Transformer Engine and
HuggingFace ESM-2 formats, including TE→HF roundtrip support.

* **Documentation**
* Expanded ESM-2 guide with Python-centric conversion examples, "Load
and Test" and "Validating Converted Models" sections, and step-by-step
deployment/upload workflows.

* **Tests**
* Added end-to-end and unit tests validating round-trip conversions,
parameter parity, config alignment, and padding/unpadding behavior.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Peter St. John <[email protected]>