diff --git a/bionemo-recipes/models/esm2/.dockerignore b/bionemo-recipes/models/esm2/.dockerignore
index 7766311b5e..6d488a2d0c 100644
--- a/bionemo-recipes/models/esm2/.dockerignore
+++ b/bionemo-recipes/models/esm2/.dockerignore
@@ -1,3 +1,4 @@
 Dockerfile
 README.md
-checkpoint_export/
+hf_to_te_checkpoint_export/
+te_to_hf_checkpoint_export/
diff --git a/bionemo-recipes/models/esm2/README.md b/bionemo-recipes/models/esm2/README.md
index 07a2bf4efe..003a8ae94c 100644
--- a/bionemo-recipes/models/esm2/README.md
+++ b/bionemo-recipes/models/esm2/README.md
@@ -16,7 +16,7 @@ The ESM-2 implementation natively supports the following TransformerEngine-provi
 | **Sequence Packing / THD input format** | ✅ Supported |
 | **FP8 with THD input format** | ✅ Supported where FP8 is supported |
 | **Import from HuggingFace checkpoints** | ✅ Supported |
-| **Export to HuggingFace checkpoints** | 🚧 Under development |
+| **Export to HuggingFace checkpoints** | ✅ Supported |
 
 See [BioNemo Recipes](../../recipes/README.md) for more details on how to use these features to accelerate model training and inference.
 
@@ -77,17 +77,108 @@ Training recipes are available in the `bionemo-recipes/recipes/` directory:
 Generate converted ESM-2 checkpoints from existing HuggingFace transformers checkpoints:
 
 ```bash
-mkdir -p checkpoint_export
+mkdir -p hf_to_te_checkpoint_export
 docker build -t esm2 .
 docker run --rm -it --gpus all \
-  -v $PWD/checkpoint_export/:/workspace/bionemo/checkpoint_export \
+  -v $PWD/hf_to_te_checkpoint_export/:/workspace/bionemo/hf_to_te_checkpoint_export \
   -v $HOME/.cache/huggingface/:/root/.cache/huggingface \
-  esm2 python export.py
+  esm2 python export.py hf-to-te
 ```
 
 ### TE to HF Transformers conversion
 
-(Coming soon)
+```bash
+MODEL_TAG=esm2_t6_8M_UR50D # specify which model to convert
+mkdir -p te_to_hf_checkpoint_export
+docker build -t esm2 .
+docker run --rm -it --gpus all \
+  -v $PWD/te_to_hf_checkpoint_export/:/workspace/bionemo/te_to_hf_checkpoint_export \
+  -v $PWD/hf_to_te_checkpoint_export/$MODEL_TAG:/workspace/bionemo/hf_to_te_checkpoint_export/$MODEL_TAG \
+  -v $HOME/.cache/huggingface/:/root/.cache/huggingface \
+  esm2 python export.py te-to-hf --checkpoint-path /workspace/bionemo/hf_to_te_checkpoint_export/$MODEL_TAG
+```
+
+## Developer Conversion Workflow
+
+This section explains how to convert between Hugging Face and Transformer Engine (TE) ESM2 model formats. The process demonstrates bidirectional conversion: from Hugging Face to TE format for optimized inference, and back to Hugging Face format for sharing and deployment. The workflow involves several key steps:
+
+### Step 1: Load Original Hugging Face Model
+
+First, load the original ESM2 model from Hugging Face:
+
+```python
+from transformers import AutoModelForMaskedLM
+
+model_hf_original = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
+```
+
+This loads the pre-trained ESM2 model that will serve as our reference for comparison.
+
+### Step 2: Export to Transformer Engine Format
+
+Convert the Hugging Face model to Transformer Engine format using the high-level export API:
+
+```python
+from pathlib import Path
+from esm.export import export_hf_checkpoint
+
+te_checkpoint_path = Path("te_checkpoint")
+export_hf_checkpoint("esm2_t6_8M_UR50D", te_checkpoint_path)
+```
+
+This creates a Transformer Engine checkpoint that can be used for optimized inference.
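+
+As a quick sanity check (a minimal sketch, not part of the export script), the exported TE checkpoint can be loaded directly with `transformers`; `trust_remote_code=True` is assumed to be needed because the checkpoint references its custom modeling code via `auto_map`:
+
+```python
+from transformers import AutoModelForMaskedLM
+
+model_te = AutoModelForMaskedLM.from_pretrained(
+    "te_checkpoint/esm2_t6_8M_UR50D", trust_remote_code=True
+)
+```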
+
+### Step 3: Export Back to Hugging Face Format
+
+Convert the Transformer Engine checkpoint back to Hugging Face format:
+
+```python
+from esm.export import export_te_checkpoint
+
+hf_export_path = Path("hf_export")
+exported_model_path = te_checkpoint_path / "esm2_t6_8M_UR50D"
+export_te_checkpoint(str(exported_model_path), hf_export_path)
+```
+
+This step creates a new Hugging Face model that should be functionally equivalent to the original.
+
+### Step 4: Load and Test the Exported Model
+
+Load the exported model and perform validation:
+
+```python
+from transformers import AutoTokenizer
+
+model_hf_exported = AutoModelForMaskedLM.from_pretrained(str(hf_export_path))
+tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
+```
+
+### Step 5: Validate Model Equivalence
+
+Test the exported model against the original using masked language modeling:
+
+```python
+import torch
+from transformers import DataCollatorForLanguageModeling
+
+# Prepare test sequence
+sequence = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"
+batch = tokenizer([sequence], return_tensors="pt")
+collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
+inputs = collator([{"input_ids": batch["input_ids"][0]}])
+
+# Compare outputs
+with torch.no_grad():
+    outputs_original = model_hf_original(**inputs)
+    outputs_exported = model_hf_exported(**inputs)
+
+# Check differences
+logits_diff = torch.abs(outputs_original.logits - outputs_exported.logits).max()
+print(f"Max logits difference: {logits_diff:.2e}")
+
+if outputs_original.loss is not None and outputs_exported.loss is not None:
+    loss_diff = abs(outputs_original.loss - outputs_exported.loss)
+    print(f"Loss difference: {loss_diff:.2e}")
+```
 
 ## Developer Guide
 
diff --git a/bionemo-recipes/models/esm2/export.py b/bionemo-recipes/models/esm2/export.py
index 9dec188821..c4336bd9d6 100644
--- a/bionemo-recipes/models/esm2/export.py
+++ b/bionemo-recipes/models/esm2/export.py
@@ -13,9 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import argparse
 from pathlib import Path
 
-from esm.export import export_hf_checkpoint
+from esm.export import export_hf_checkpoint, export_te_checkpoint
 
 
 ESM_TAGS = [
@@ -30,10 +31,48 @@
 
 def main():
     """Export the ESM2 models from Hugging Face to the Transformer Engine format."""
-    # TODO (peter): maybe add a way to specify the model to export or option to export all models?
-    for tag in ESM_TAGS:
-        print(f"Converting {tag}...")
-        export_hf_checkpoint(tag, Path("./checkpoint_export"))
+    parser = argparse.ArgumentParser(description="Convert ESM2 models between the Hugging Face and Transformer Engine formats")
+
+    subparsers = parser.add_subparsers(dest="conversion_type", required=True, help="Type of conversion to perform")
+
+    hf_to_te_parser = subparsers.add_parser("hf-to-te", help="Convert from HuggingFace to Transformer Engine format")
+    hf_to_te_parser.add_argument(
+        "--model",
+        type=str,
+        choices=ESM_TAGS,
+        help="Specific model tag to convert. If not provided, all models will be converted.",
+    )
+    hf_to_te_parser.add_argument(
+        "--output-path",
+        type=str,
+        default="./hf_to_te_checkpoint_export",
+        help="Output directory path for the converted model. Defaults to './hf_to_te_checkpoint_export'",
+    )
+
+    te_to_hf_parser = subparsers.add_parser("te-to-hf", help="Convert from Transformer Engine to HuggingFace format")
+    te_to_hf_parser.add_argument(
+        "--checkpoint-path", type=str, required=True, help="Path to the Transformer Engine checkpoint to convert"
+    )
+    te_to_hf_parser.add_argument(
+        "--output-path",
+        type=str,
+        default="./te_to_hf_checkpoint_export",
+        help="Output directory path for the converted model. Defaults to './te_to_hf_checkpoint_export'",
+    )
+
+    args = parser.parse_args()
+
+    if args.conversion_type == "hf-to-te":
+        if args.model:
+            print(f"Converting {args.model} from HuggingFace to Transformer Engine format")
+            export_hf_checkpoint(args.model, Path(args.output_path))
+        else:
+            for tag in ESM_TAGS:
+                print(f"Converting {tag} from HuggingFace to Transformer Engine format")
+                export_hf_checkpoint(tag, Path(args.output_path))
+    else:
+        print(f"Converting {args.checkpoint_path} from Transformer Engine to HuggingFace format")
+        export_te_checkpoint(args.checkpoint_path, Path(args.output_path))
 
 
 if __name__ == "__main__":
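Outside the Docker workflow shown in the README, the two subcommands defined above can be invoked directly. Illustrative invocations (the default output paths and the `<output>/<model-tag>` checkpoint layout are assumptions based on this diff):

```bash
# HF -> TE for a single model tag
python export.py hf-to-te --model esm2_t6_8M_UR50D --output-path ./hf_to_te_checkpoint_export

# TE -> HF, pointing at the directory produced by the previous step
python export.py te-to-hf \
    --checkpoint-path ./hf_to_te_checkpoint_export/esm2_t6_8M_UR50D \
    --output-path ./te_to_hf_checkpoint_export
```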
+ """ + # Convert TE config to HF config + hf_config_dict = model_te.config.to_dict() + + # Remove TE-specific config options + te_specific_keys = [ + "qkv_weight_interleaved", + "encoder_activation", + "attn_input_format", + "fuse_qkv_params", + "micro_batch_size", + "max_seq_length", + "model_type", + "auto_map", + ] + for key in te_specific_keys: + hf_config_dict.pop(key, None) + + hf_config_dict["model_type"] = "esm" + + hf_config = EsmConfig(**hf_config_dict, **config_kwargs) + + with init_empty_weights(): + model_hf = EsmForMaskedLM(hf_config) + + # Remove contact_head since it's not present in TE models + if hasattr(model_hf.esm, "contact_head"): + delattr(model_hf.esm, "contact_head") + + output_model = io.apply_transforms( + model_te, + model_hf, + reverse_mapping, + [_unpack_qkv_weight, _unpack_qkv_bias, _unpad_embeddings, _unpad_decoder_weights, _unpad_bias], + state_dict_ignored_entries=[ + "lm_head.decoder.weight", + "esm.contact_head.regression.weight", + "esm.contact_head.regression.bias", + ], + ) + + output_model.tie_weights() + + # Note: contact_head parameters are not preserved in TE models + # They are lost during HF -> TE conversion and cannot be recovered + # The converted model will not have the original contact_head weights + + return output_model + + @io.state_transform( source_key=( "esm.encoder.layer.*.attention.self.query.weight", @@ -81,11 +149,11 @@ def _pack_qkv_weight(ctx: io.TransformCTX, query, key, value): """Pad the embedding layer to the new input dimension.""" concat_weights = torch.cat((query, key, value), dim=0) input_shape = concat_weights.size() - np = ctx.target.config.num_attention_heads + num_heads = ctx.target.config.num_attention_heads # transpose weights # [sequence length, batch size, num_splits_model_parallel * attention head size * #attention heads] # --> [sequence length, batch size, attention head size * num_splits_model_parallel * #attention heads] - concat_weights = concat_weights.view(3, np, -1, query.size()[-1]) + concat_weights = concat_weights.view(3, num_heads, -1, query.size()[-1]) concat_weights = concat_weights.transpose(0, 1).contiguous() concat_weights = concat_weights.view(*input_shape) return concat_weights @@ -103,16 +171,76 @@ def _pack_qkv_bias(ctx: io.TransformCTX, query, key, value): """Pad the embedding layer to the new input dimension.""" concat_biases = torch.cat((query, key, value), dim=0) input_shape = concat_biases.size() - np = ctx.target.config.num_attention_heads + num_heads = ctx.target.config.num_attention_heads # transpose biases # [num_splits_model_parallel * attention head size * #attention heads] # --> [attention head size * num_splits_model_parallel * #attention heads] - concat_biases = concat_biases.view(3, np, -1) + concat_biases = concat_biases.view(3, num_heads, -1) concat_biases = concat_biases.transpose(0, 1).contiguous() concat_biases = concat_biases.view(*input_shape) return concat_biases +@io.state_transform( + source_key="esm.encoder.layers.*.self_attention.layernorm_qkv.weight", + target_key=( + "esm.encoder.layer.*.attention.self.query.weight", + "esm.encoder.layer.*.attention.self.key.weight", + "esm.encoder.layer.*.attention.self.value.weight", + ), +) +def _unpack_qkv_weight(ctx: io.TransformCTX, qkv_weight): + """Unpack fused QKV weights into separate [hidden_size, input_dim] tensors for query/key/value.""" + num_heads = ctx.source.config.num_attention_heads + total_rows, input_dim = qkv_weight.size() # size: [num_heads * 3 *head_dim, input_dim] + assert total_rows % (3 * num_heads) == 
+        f"QKV weight rows {total_rows} not divisible by 3*num_heads {3*num_heads}"
+    )
+    head_dim = total_rows // (3 * num_heads)
+
+    qkv_weight = qkv_weight.view(num_heads, 3, head_dim, input_dim).transpose(0, 1).contiguous()  # size: [3, num_heads, head_dim, input_dim]
+    query, key, value = qkv_weight[0], qkv_weight[1], qkv_weight[2]  # size: [num_heads, head_dim, input_dim]
+
+    query = query.reshape(-1, input_dim)  # size: [num_heads * head_dim, input_dim]
+    key = key.reshape(-1, input_dim)  # size: [num_heads * head_dim, input_dim]
+    value = value.reshape(-1, input_dim)  # size: [num_heads * head_dim, input_dim]
+
+    return query, key, value
+
+
+@io.state_transform(
+    source_key="esm.encoder.layers.*.self_attention.layernorm_qkv.bias",
+    target_key=(
+        "esm.encoder.layer.*.attention.self.query.bias",
+        "esm.encoder.layer.*.attention.self.key.bias",
+        "esm.encoder.layer.*.attention.self.value.bias",
+    ),
+)
+def _unpack_qkv_bias(ctx: io.TransformCTX, qkv_bias):
+    """Unpack fused QKV biases into separate [hidden_size] tensors for query/key/value."""
+    num_heads = ctx.source.config.num_attention_heads
+    total_size = qkv_bias.size(0)  # size: [num_heads * 3 * head_dim]
+    assert total_size % (3 * num_heads) == 0, (
+        f"QKV bias size {total_size} not divisible by 3*num_heads {3*num_heads}"
+    )
+    head_dim = total_size // (3 * num_heads)
+
+    qkv_bias = qkv_bias.view(num_heads, 3, head_dim).transpose(0, 1).contiguous()  # size: [3, num_heads, head_dim]
+    query, key, value = qkv_bias[0], qkv_bias[1], qkv_bias[2]  # size: [num_heads, head_dim]
+
+    query = query.reshape(-1)  # size: [num_heads * head_dim]
+    key = key.reshape(-1)  # size: [num_heads * head_dim]
+    value = value.reshape(-1)  # size: [num_heads * head_dim]
+
+    return query, key, value
+
+
+def _unpad_weights(ctx: io.TransformCTX, padded_embed):
+    """Remove padding from the embedding layer to get back to the original dimension."""
+    target_embedding_dimension = ctx.target.config.vocab_size
+    return padded_embed[:target_embedding_dimension]
 
 
 def _pad_weights(ctx: io.TransformCTX, source_embed):
     """Pad the embedding layer to the new input dimension."""
     target_embedding_dimension = ctx.target.config.padded_vocab_size
@@ -134,6 +262,16 @@
     target_key="lm_head.decoder.weight",
 )(_pad_weights)
 
+_unpad_embeddings = io.state_transform(
+    source_key="esm.embeddings.word_embeddings.weight",
+    target_key="esm.embeddings.word_embeddings.weight",
+)(_unpad_weights)
+
+_unpad_decoder_weights = io.state_transform(
+    source_key="lm_head.decoder.weight",
+    target_key="lm_head.decoder.weight",
+)(_unpad_weights)
+
 
 @io.state_transform(
     source_key="lm_head.bias",
@@ -148,3 +286,13 @@
     )
     output_bias[:hf_embedding_dimension] = source_bias
     return output_bias
+
+
+@io.state_transform(
+    source_key="lm_head.decoder.bias",
+    target_key="lm_head.bias",
+)
+def _unpad_bias(ctx: io.TransformCTX, padded_bias):
+    """Remove padding from the bias to get back to the original dimension."""
+    target_embedding_dimension = ctx.target.config.vocab_size
+    return padded_bias[:target_embedding_dimension]
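For intuition, the fused TE layout stores a per-head block of query, key, and value rows, while the HF checkpoint stores all query rows, then all key rows, then all value rows. A standalone sketch in plain PyTorch (sizes are illustrative and chosen to match the 8M config: 20 heads of dimension 16) shows that the pack/unpack reshapes used by `_pack_qkv_weight` and `_unpack_qkv_weight` are exact inverses:

```python
import torch

num_heads, head_dim, in_dim = 20, 16, 320  # illustrative sizes

q = torch.randn(num_heads * head_dim, in_dim)
k = torch.randn(num_heads * head_dim, in_dim)
v = torch.randn(num_heads * head_dim, in_dim)

# Pack (HF -> TE): regroup the concatenated q/k/v rows into per-head [q, k, v] blocks.
packed = torch.cat((q, k, v), dim=0)
packed = packed.view(3, num_heads, -1, in_dim).transpose(0, 1).contiguous().view(-1, in_dim)

# Unpack (TE -> HF): split the per-head blocks back into separate q/k/v tensors.
unpacked = packed.view(num_heads, 3, -1, in_dim).transpose(0, 1).contiguous()
q2, k2, v2 = (t.reshape(-1, in_dim) for t in unpacked)

assert torch.equal(q, q2) and torch.equal(k, k2) and torch.equal(v, v2)
```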
diff --git a/bionemo-recipes/models/esm2/src/esm/export.py b/bionemo-recipes/models/esm2/src/esm/export.py
index 44f6ad4ea5..b078cbce55 100644
--- a/bionemo-recipes/models/esm2/src/esm/export.py
+++ b/bionemo-recipes/models/esm2/src/esm/export.py
@@ -22,7 +22,8 @@
 from jinja2 import Template
 from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer
 
-from esm.convert import convert_esm_hf_to_te
+from esm.convert import convert_esm_hf_to_te, convert_esm_te_to_hf
+from esm.modeling_esm_te import NVEsmForMaskedLM
 
 
 BENCHMARK_RESULTS = {
@@ -116,3 +117,47 @@
     del model_te
     gc.collect()
     torch.cuda.empty_cache()
+
+
+def export_te_checkpoint(te_checkpoint_path: str, output_path: Path):
+    """Export a Transformer Engine checkpoint back to the original HuggingFace Facebook ESM-2 format.
+
+    This function converts from the NVIDIA Transformer Engine (TE) format back to the
+    weight format compatible with the original facebook/esm2_* series of checkpoints.
+    The TE model is also a HuggingFace model (you can load it with AutoModel.from_pretrained),
+    but this conversion ensures compatibility with the original Facebook ESM-2 model format hosted on Hugging Face.
+
+    Args:
+        te_checkpoint_path (str): Path to the TE checkpoint
+        output_path (Path): Output path for the converted Facebook ESM-2 format model
+    """
+    if not Path(te_checkpoint_path).exists():
+        raise FileNotFoundError(f"TE checkpoint {te_checkpoint_path} not found")
+
+    print(f"Converting {te_checkpoint_path} from TE format back to original HuggingFace Facebook ESM-2 format")
+
+    # Load the TE model and convert to HF format
+    model_te = NVEsmForMaskedLM.from_pretrained(te_checkpoint_path)
+    model_hf = convert_esm_te_to_hf(model_te)
+    model_hf.save_pretrained(output_path)
+
+    tokenizer_config_path = Path(te_checkpoint_path) / "tokenizer_config.json"
+    if tokenizer_config_path.exists():
+        shutil.copy(tokenizer_config_path, output_path / "tokenizer_config.json")
+
+    vocab_path = Path(te_checkpoint_path) / "vocab.txt"
+    if vocab_path.exists():
+        shutil.copy(vocab_path, output_path / "vocab.txt")
+
+    special_tokens_path = Path(te_checkpoint_path) / "special_tokens_map.json"
+    if special_tokens_path.exists():
+        shutil.copy(special_tokens_path, output_path / "special_tokens_map.json")
+
+    model_hf = AutoModelForMaskedLM.from_pretrained(
+        output_path,
+        dtype=torch.bfloat16,
+        trust_remote_code=False,
+    )
+    del model_hf
+    gc.collect()
+    torch.cuda.empty_cache()
diff --git a/bionemo-recipes/models/esm2/tests/test_convert.py b/bionemo-recipes/models/esm2/tests/test_convert.py
new file mode 100644
index 0000000000..35236d5d97
--- /dev/null
+++ b/bionemo-recipes/models/esm2/tests/test_convert.py
@@ -0,0 +1,128 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+from transformers import AutoModelForMaskedLM
+
+
+def test_convert_te_to_hf_roundtrip():
+    """Test that converting HF -> TE -> HF produces the same model."""
+    from esm.convert import convert_esm_hf_to_te, convert_esm_te_to_hf
+
+    model_hf_original = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
+
+    model_te = convert_esm_hf_to_te(model_hf_original)
+    model_hf_converted = convert_esm_te_to_hf(model_te)
+
+    original_state_dict = model_hf_original.state_dict()
+    converted_state_dict = model_hf_converted.state_dict()
+    original_keys = {k for k in original_state_dict.keys() if "contact_head" not in k}
+    converted_keys = set(converted_state_dict.keys())
+    assert original_keys == converted_keys
+
+    for key in original_state_dict.keys():
+        if not key.endswith("_extra_state") and not key.endswith("inv_freq") and "contact_head" not in key:
+            torch.testing.assert_close(original_state_dict[key], converted_state_dict[key], atol=1e-5, rtol=1e-5)
+
+
+def test_qkv_unpacking():
+    """Test that QKV unpacking works correctly."""
+    from esm.convert import convert_esm_hf_to_te, convert_esm_te_to_hf
+
+    model_hf = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
+    model_te = convert_esm_hf_to_te(model_hf)
+    model_hf_converted = convert_esm_te_to_hf(model_te)
+
+    for i in range(model_hf.config.num_hidden_layers):
+        hf_query = model_hf.state_dict()[f"esm.encoder.layer.{i}.attention.self.query.weight"]
+        hf_key = model_hf.state_dict()[f"esm.encoder.layer.{i}.attention.self.key.weight"]
+        hf_value = model_hf.state_dict()[f"esm.encoder.layer.{i}.attention.self.value.weight"]
+
+        converted_query = model_hf_converted.state_dict()[f"esm.encoder.layer.{i}.attention.self.query.weight"]
+        converted_key = model_hf_converted.state_dict()[f"esm.encoder.layer.{i}.attention.self.key.weight"]
+        converted_value = model_hf_converted.state_dict()[f"esm.encoder.layer.{i}.attention.self.value.weight"]
+
+        torch.testing.assert_close(hf_query, converted_query, atol=1e-5, rtol=1e-5)
+        torch.testing.assert_close(hf_key, converted_key, atol=1e-5, rtol=1e-5)
+        torch.testing.assert_close(hf_value, converted_value, atol=1e-5, rtol=1e-5)
+
+
+def test_config_conversion():
+    """Test that config conversion works correctly."""
+    from esm.convert import convert_esm_hf_to_te, convert_esm_te_to_hf
+
+    model_hf = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
+    model_te = convert_esm_hf_to_te(model_hf)
+    model_hf_converted = convert_esm_te_to_hf(model_te)
+
+    original_config_dict = model_hf.config.to_dict()
+    converted_config_dict = model_hf_converted.config.to_dict()
+
+    for key, value in original_config_dict.items():
+        assert key in converted_config_dict, f"Config field '{key}' missing in converted model"
+        assert converted_config_dict[key] == value, (
+            f"Config field '{key}' differs: original={value}, converted={converted_config_dict[key]}"
+        )
+
+    assert model_hf_converted.config.model_type == "esm"
+
+    te_specific_fields = [
+        "qkv_weight_interleaved",
+        "encoder_activation",
+        "attn_input_format",
+        "fuse_qkv_params",
+        "micro_batch_size",
+        "auto_map",
+    ]
+    for field in te_specific_fields:
+        assert not hasattr(model_hf_converted.config, field), (
+            f"TE-specific field '{field}' should not be present in converted model"
+        )
+
+
+def test_padding_unpadding_operations():
+    """Test that padding and unpadding operations work correctly for embeddings and decoder weights."""
+    from esm.convert import convert_esm_hf_to_te, convert_esm_te_to_hf
+
+    model_hf = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
+    model_te = convert_esm_hf_to_te(model_hf)
+    model_hf_converted = convert_esm_te_to_hf(model_te)
+
+    # Test word embeddings
+    original_embeddings = model_hf.state_dict()["esm.embeddings.word_embeddings.weight"]
+    converted_embeddings = model_hf_converted.state_dict()["esm.embeddings.word_embeddings.weight"]
+    assert original_embeddings.shape == converted_embeddings.shape, (
+        f"Embedding shapes don't match: {original_embeddings.shape} vs {converted_embeddings.shape}"
+    )
+    torch.testing.assert_close(original_embeddings, converted_embeddings, atol=1e-5, rtol=1e-5)
+
+    # Test decoder weights
+    original_decoder = model_hf.state_dict()["lm_head.decoder.weight"]
+    converted_decoder = model_hf_converted.state_dict()["lm_head.decoder.weight"]
+    assert original_decoder.shape == converted_decoder.shape, (
+        f"Decoder shapes don't match: {original_decoder.shape} vs {converted_decoder.shape}"
+    )
+    torch.testing.assert_close(original_decoder, converted_decoder, atol=1e-5, rtol=1e-5)
+
+    # Test bias
+    original_bias = model_hf.state_dict()["lm_head.bias"]
+    converted_bias = model_hf_converted.state_dict()["lm_head.bias"]
+    assert original_bias.shape == converted_bias.shape, (
+        f"Bias shapes don't match: {original_bias.shape} vs {converted_bias.shape}"
+    )
+    torch.testing.assert_close(original_bias, converted_bias, atol=1e-5, rtol=1e-5)
+
+    # Test that TE model has padded dimensions
+    te_embeddings = model_te.state_dict()["esm.embeddings.word_embeddings.weight"]
+    te_decoder = model_te.state_dict()["lm_head.decoder.weight"]
+    assert te_embeddings.shape[0] >= original_embeddings.shape[0], "TE embeddings should be padded"
+    assert te_decoder.shape[0] >= original_decoder.shape[0], "TE decoder should be padded"
+
+    # The padded parts should be zeros (for embeddings) or min values (for bias)
+    if te_embeddings.shape[0] > original_embeddings.shape[0]:
+        padding_rows = te_embeddings[original_embeddings.shape[0]:]
+        torch.testing.assert_close(padding_rows, torch.zeros_like(padding_rows), atol=1e-6, rtol=1e-6)
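The padding assertions above boil down to a simple pad-then-slice round trip on the vocabulary dimension. A minimal standalone sketch of that behavior (the sizes here are illustrative; the real padded size comes from `padded_vocab_size` in the TE config):

```python
import torch

vocab_size, padded_vocab_size, hidden_size = 33, 64, 320  # illustrative sizes

weight = torch.randn(vocab_size, hidden_size)

# Pad (HF -> TE): extra embedding rows are zero-filled up to the padded vocab size.
padded = torch.zeros(padded_vocab_size, hidden_size)
padded[:vocab_size] = weight

# Unpad (TE -> HF): slicing off the padding recovers the original tensor exactly.
unpadded = padded[:vocab_size]

assert torch.equal(unpadded, weight)
assert torch.equal(padded[vocab_size:], torch.zeros(padded_vocab_size - vocab_size, hidden_size))
```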
diff --git a/bionemo-recipes/models/esm2/tests/test_export.py b/bionemo-recipes/models/esm2/tests/test_export.py
index dc773666f4..9af5cc8d6d 100644
--- a/bionemo-recipes/models/esm2/tests/test_export.py
+++ b/bionemo-recipes/models/esm2/tests/test_export.py
@@ -14,6 +14,13 @@
 # limitations under the License.
 
+from pathlib import Path
+
+import pytest
+import tempfile
+import torch
+from transformers import AutoModelForMaskedLM
+
+
 def test_export_hf_checkpoint(tmp_path):
     from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer
 
@@ -63,3 +70,34 @@
     assert "**Benchmark Score:** 0.37" in readme_contents, (
         f"README.md does not contain the expected CASP14 score line: {readme_contents}"
     )
+
+@pytest.mark.parametrize("model_name", ["esm2_t6_8M_UR50D"])
+def test_export_te_checkpoint_to_hf(model_name):
+    """Test the export function that saves a TE checkpoint in HF format."""
+    from esm.export import export_hf_checkpoint, export_te_checkpoint
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        temp_path = Path(temp_dir)
+
+        model_hf_original = AutoModelForMaskedLM.from_pretrained(f"facebook/{model_name}")
+
+        # Use export_hf_checkpoint to create TE checkpoint
+        te_checkpoint_path = temp_path / "te_checkpoint"
+        export_hf_checkpoint(model_name, te_checkpoint_path)
+        te_model_path = te_checkpoint_path / model_name
+
+        hf_export_path = temp_path / "hf_export"
+        export_te_checkpoint(str(te_model_path), hf_export_path)
+
+        model_hf_exported = AutoModelForMaskedLM.from_pretrained(str(hf_export_path))
+
+        original_state_dict = model_hf_original.state_dict()
+        exported_state_dict = model_hf_exported.state_dict()
+
+        original_keys = {k for k in original_state_dict.keys() if "contact_head" not in k}
+        exported_keys = {k for k in exported_state_dict.keys() if "contact_head" not in k}
+        assert original_keys == exported_keys
+
+        for key in original_state_dict.keys():
+            if not key.endswith("_extra_state") and not key.endswith("inv_freq") and "contact_head" not in key:
+                torch.testing.assert_close(original_state_dict[key], exported_state_dict[key], atol=1e-5, rtol=1e-5)
\ No newline at end of file
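To exercise the new conversion tests locally, a pytest invocation along these lines should work (a sketch; it assumes an environment with the recipe's dependencies installed, including TransformerEngine and NeMo, and a CUDA-capable GPU for the TE model):

```bash
cd bionemo-recipes/models/esm2
pytest -v tests/test_convert.py tests/test_export.py
```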