
Commit 22c7739

Fix review comments
1 parent f2f2247 commit 22c7739

5 files changed: 63 additions & 102 deletions

models/esm2/README.md

Lines changed: 8 additions & 13 deletions
@@ -37,6 +37,7 @@ This loads the pre-trained ESM2 model that will serve as our reference for compa
 Convert the Hugging Face model to Transformer Engine format using the high-level export API:
 
 ```python
+from pathlib import Path
 from esm.export import export_hf_checkpoint
 
 te_checkpoint_path = Path("te_checkpoint")
@@ -64,6 +65,7 @@ This step creates a new Hugging Face model that should be functionally equivalen
 Load the exported model and perform validation:
 
 ```python
+from transformers import AutoTokenizer
 model_hf_exported = AutoModelForMaskedLM.from_pretrained(str(hf_export_path))
 tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
 ```
@@ -73,21 +75,14 @@ tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
 Test the exported model against the original using masked language modeling:
 
 ```python
+import torch
+from transformers import DataCollatorForLanguageModeling
+
 # Prepare test sequence
 sequence = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"
-inputs = tokenizer(sequence, return_tensors="pt")
-
-# Create masked inputs (15% masking)
-input_ids = inputs["input_ids"].clone()
-labels = inputs["input_ids"].clone()
-mask_token_id = tokenizer.mask_token_id
-
-for i in range(input_ids.shape[1]):
-    if torch.rand(1).item() < 0.15:
-        input_ids[0, i] = mask_token_id
-
-inputs["input_ids"] = input_ids
-inputs["labels"] = labels
+batch = tokenizer([sequence], return_tensors="pt")
+collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
+inputs = collator([{"input_ids": batch["input_ids"][0]}])
 
 # Compare outputs
 with torch.no_grad():
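A note on the README change above: the manual masking loop could be dropped because `DataCollatorForLanguageModeling` both applies the random masking and builds the matching `labels` tensor (positions not selected for masking are set to -100 so the loss ignores them). A minimal, self-contained sketch of that behavior, using the checkpoint named in the README and an illustrative short sequence:

```python
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

batch = tokenizer(["MKTVRQERLKSIVRILERSK"], return_tensors="pt")
inputs = collator([{"input_ids": batch["input_ids"][0]}])

# The collator returns the corrupted input_ids plus the matching MLM labels in one step.
print(inputs["input_ids"].shape, inputs["labels"].shape)
print((inputs["input_ids"] == tokenizer.mask_token_id).sum().item(), "tokens replaced with the mask token")
print((inputs["labels"] != -100).sum().item(), "positions selected for the MLM loss")
```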

models/esm2/export_te_checkpoint_to_hf.py

Lines changed: 2 additions & 1 deletion
@@ -44,13 +44,14 @@ def main(te_checkpoint_dir: str, output_dir: str):
     parser.add_argument(
         "--model",
         type=str,
+        required=True,
         help="Path to the TE checkpoint.",
     )
     parser.add_argument(
         "--output_path",
         type=str,
         default="./hf_checkpoints",
-        help="Base output directory for the converted models. Each checkpoint will be saved in a subdirectory named after the checkpoint. If not provided, uses './hf_checkpoints'.",
+        help="Output directory for the converted model. The model will be saved directly to this directory. If not provided, uses './hf_checkpoints'.",
     )
     args = parser.parse_args()
 
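A quick note on the `required=True` change: argparse now fails fast when `--model` is omitted instead of passing `None` through to the conversion. A small standalone sketch of that behavior (the parser and flag names mirror the script above; this is not the script itself):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True, help="Path to the TE checkpoint.")
parser.add_argument("--output_path", type=str, default="./hf_checkpoints")

# Omitting --model now raises SystemExit with
# "the following arguments are required: --model"
# instead of silently yielding args.model == None.
args = parser.parse_args(["--model", "te_checkpoint"])
print(args.model, args.output_path)
```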

models/esm2/src/esm/convert.py

Lines changed: 34 additions & 43 deletions
@@ -103,10 +103,13 @@ def convert_esm_te_to_hf(model_te: nn.Module, **config_kwargs) -> nn.Module:
         "micro_batch_size",
         "max_seq_length",
         "model_type",
+        "auto_map",
     ]
     for key in te_specific_keys:
         hf_config_dict.pop(key, None)
 
+    hf_config_dict["model_type"] = "esm"
+
     hf_config = EsmConfig(**hf_config_dict, **config_kwargs)
 
     with init_empty_weights():
@@ -149,11 +152,11 @@ def _pack_qkv_weight(ctx: io.TransformCTX, query, key, value):
     """Pad the embedding layer to the new input dimension."""
     concat_weights = torch.cat((query, key, value), dim=0)
     input_shape = concat_weights.size()
-    np = ctx.target.config.num_attention_heads
+    num_heads = ctx.target.config.num_attention_heads
     # transpose weights
     # [sequence length, batch size, num_splits_model_parallel * attention head size * #attention heads]
     # --> [sequence length, batch size, attention head size * num_splits_model_parallel * #attention heads]
-    concat_weights = concat_weights.view(3, np, -1, query.size()[-1])
+    concat_weights = concat_weights.view(3, num_heads, -1, query.size()[-1])
     concat_weights = concat_weights.transpose(0, 1).contiguous()
     concat_weights = concat_weights.view(*input_shape)
     return concat_weights
@@ -171,11 +174,11 @@ def _pack_qkv_bias(ctx: io.TransformCTX, query, key, value):
     """Pad the embedding layer to the new input dimension."""
     concat_biases = torch.cat((query, key, value), dim=0)
     input_shape = concat_biases.size()
-    np = ctx.target.config.num_attention_heads
+    num_heads = ctx.target.config.num_attention_heads
     # transpose biases
     # [num_splits_model_parallel * attention head size * #attention heads]
     # --> [attention head size * num_splits_model_parallel * #attention heads]
-    concat_biases = concat_biases.view(3, np, -1)
+    concat_biases = concat_biases.view(3, num_heads, -1)
     concat_biases = concat_biases.transpose(0, 1).contiguous()
     concat_biases = concat_biases.view(*input_shape)
     return concat_biases
@@ -190,26 +193,20 @@ def _pack_qkv_bias(ctx: io.TransformCTX, query, key, value):
     ),
 )
 def _unpack_qkv_weight(ctx: io.TransformCTX, qkv_weight):
-    """Unpack the fused QKV weight into separate query, key, and value weights."""
-    np = ctx.source.config.num_attention_heads
-
-    # Reverse the packing transformation
-    # First, reshape to separate the interleaved Q, K, V
-    # [attention head size * num_splits_model_parallel * #attention heads]
-    # --> [num_splits_model_parallel * attention head size * #attention heads]
-    qkv_weight = qkv_weight.view(np, 3, -1, qkv_weight.size()[-1])  # Output: [num_heads, 3, head_dim, vocab_size]
-    qkv_weight = qkv_weight.transpose(0, 1).contiguous()  # Output: [3, num_heads, head_dim, vocab_size]
-
-    # Split into Q, K, V directly from the transposed tensor
-    # qkv_weight shape: [3, num_heads, head_dim, input_dim]
-    query = qkv_weight[0]  # [num_heads, head_dim, input_dim]
-    key = qkv_weight[1]  # [num_heads, head_dim, input_dim]
-    value = qkv_weight[2]  # [num_heads, head_dim, input_dim]
-
-    # Reshape to match HF format: [total_head_dim, input_dim]
-    query = query.view(-1, query.size()[-1])  # [num_heads * head_dim, input_dim]
-    key = key.view(-1, key.size()[-1])  # [num_heads * head_dim, input_dim]
-    value = value.view(-1, value.size()[-1])  # [num_heads * head_dim, input_dim]
+    """Unpack fused QKV weights into separate [hidden_size, input_dim] tensors for query/key/value."""
+    num_heads = ctx.source.config.num_attention_heads
+    total_rows, input_dim = qkv_weight.size()  # size: [num_heads * 3 * head_dim, input_dim]
+    assert total_rows % (3 * num_heads) == 0, (
+        f"QKV weight rows {total_rows} not divisible by 3*num_heads {3*num_heads}"
+    )
+    head_dim = total_rows // (3 * num_heads)
+
+    qkv_weight = qkv_weight.view(num_heads, 3, head_dim, input_dim).transpose(0, 1).contiguous()  # size: [3, num_heads, head_dim, input_dim]
+    query, key, value = qkv_weight[0], qkv_weight[1], qkv_weight[2]  # size: [num_heads, head_dim, input_dim]
+
+    query = query.reshape(-1, input_dim)  # size: [num_heads * head_dim, input_dim]
+    key = key.reshape(-1, input_dim)  # size: [num_heads * head_dim, input_dim]
+    value = value.reshape(-1, input_dim)  # size: [num_heads * head_dim, input_dim]
 
     return query, key, value
 
@@ -223,25 +220,19 @@ def _unpack_qkv_weight(ctx: io.TransformCTX, qkv_weight):
     ),
 )
 def _unpack_qkv_bias(ctx: io.TransformCTX, qkv_bias):
-    """Unpack the fused QKV bias into separate query, key, and value biases."""
-    np = ctx.source.config.num_attention_heads
+    """Unpack fused QKV biases into separate [hidden_size] tensors for query/key/value."""
+    num_heads = ctx.source.config.num_attention_heads
+    total_size = qkv_bias.size(0)  # size: [num_heads * 3 * head_dim]
+    assert total_size % (3 * num_heads) == 0, (
+        f"QKV bias size {total_size} not divisible by 3*num_heads {3*num_heads}"
+    )
+    head_dim = total_size // (3 * num_heads)
 
-    # Reverse the packing transformation
-    # First, reshape to separate the interleaved Q, K, V
-    # [num_splits_model_parallel * attention head size * #attention heads]
-    # --> [attention head size * num_splits_model_parallel * #attention heads]
-    qkv_bias = qkv_bias.view(np, 3, -1)
-    qkv_bias = qkv_bias.transpose(0, 1).contiguous()
-
-    # Split into Q, K, V directly from the transposed tensor
-    # qkv_bias shape: [3, num_heads, head_dim]
-    query = qkv_bias[0]  # [num_heads, head_dim]
-    key = qkv_bias[1]  # [num_heads, head_dim]
-    value = qkv_bias[2]  # [num_heads, head_dim]
-
-    # Reshape to match HF format: [total_head_dim]
-    query = query.view(-1)  # [num_heads * head_dim]
-    key = key.view(-1)  # [num_heads * head_dim]
-    value = value.view(-1)  # [num_heads * head_dim]
+    qkv_bias = qkv_bias.view(num_heads, 3, head_dim).transpose(0, 1).contiguous()  # size: [3, num_heads, head_dim]
+    query, key, value = qkv_bias[0], qkv_bias[1], qkv_bias[2]  # size: [num_heads, head_dim]
+
+    query = query.reshape(-1)  # size: [num_heads * head_dim]
+    key = key.reshape(-1)  # size: [num_heads * head_dim]
+    value = value.reshape(-1)  # size: [num_heads * head_dim]
 
     return query, key, value
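To make the intent of the rewritten unpack helpers concrete: `_unpack_qkv_weight` reverses exactly the per-head interleaving that `_pack_qkv_weight` applies. A small self-contained sketch of the round trip on random tensors (toy dimensions chosen for illustration; the real code reads `num_attention_heads` from the model config):

```python
import torch

# Toy dimensions for illustration only.
num_heads, head_dim, input_dim = 4, 8, 16
hidden = num_heads * head_dim

query = torch.randn(hidden, input_dim)
key = torch.randn(hidden, input_dim)
value = torch.randn(hidden, input_dim)

# Pack (mirrors _pack_qkv_weight): interleave per-head Q/K/V blocks into the fused layout.
packed = torch.cat((query, key, value), dim=0)       # [3 * hidden, input_dim]
packed = packed.view(3, num_heads, -1, input_dim)    # [3, num_heads, head_dim, input_dim]
packed = packed.transpose(0, 1).contiguous()         # [num_heads, 3, head_dim, input_dim]
packed = packed.view(3 * hidden, input_dim)          # fused: [num_heads * 3 * head_dim, input_dim]

# Unpack (mirrors _unpack_qkv_weight): undo the interleaving and split back into Q/K/V.
unpacked = packed.view(num_heads, 3, head_dim, input_dim).transpose(0, 1).contiguous()
q2, k2, v2 = (t.reshape(-1, input_dim) for t in unpacked)

assert torch.equal(query, q2) and torch.equal(key, k2) and torch.equal(value, v2)
```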

models/esm2/src/esm/export.py

Lines changed: 1 addition & 13 deletions
@@ -90,10 +90,8 @@ def export_te_checkpoint(te_checkpoint_path: str, output_path: str):
 
     print(f"Converting {te_checkpoint_path} from TE format back to original HuggingFace Facebook ESM-2 format")
 
-    # Load the TE model
+    # Load the TE model and convert to HF format
     model_te = NVEsmForMaskedLM.from_pretrained(te_checkpoint_path)
-
-    # Convert TE model to HF format
     model_hf = convert_esm_te_to_hf(model_te)
     model_hf.save_pretrained(output_path)
 
@@ -110,16 +108,6 @@ def export_te_checkpoint(te_checkpoint_path: str, output_path: str):
     if vocab_path.exists():
         shutil.copy(vocab_path, Path(output_path) / "vocab.txt")
 
-    # Update config to remove TE-specific settings
-    config_path = Path(output_path) / "config.json"
-    if config_path.exists():
-        with open(config_path, "r") as f:
-            config = json.load(f)
-        config.pop("auto_map", None)
-        config["model_type"] = "esm"
-        with open(config_path, "w") as f:
-            json.dump(config, f, indent=2, sort_keys=True)
-
     model_hf = AutoModelForMaskedLM.from_pretrained(
         output_path,
         torch_dtype=torch.bfloat16,
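The deleted config.json post-processing is now redundant because `convert_esm_te_to_hf` sets `model_type = "esm"` and drops `auto_map` before the config is ever saved (see the convert.py hunk above). A minimal sketch of how one might verify this on an exported checkpoint; `hf_export_path` is a hypothetical output directory:

```python
import torch
from transformers import AutoModelForMaskedLM

hf_export_path = "./hf_checkpoints"  # hypothetical path produced by export_te_checkpoint

model = AutoModelForMaskedLM.from_pretrained(hf_export_path, torch_dtype=torch.bfloat16)

# The cleanup happens during conversion, so these hold without editing config.json by hand:
assert model.config.model_type == "esm"
assert not hasattr(model.config, "auto_map")
```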

models/esm2/tests/test_convert_reverse.py

Lines changed: 18 additions & 32 deletions
@@ -18,29 +18,9 @@
 
 import pytest
 import torch
-from torch import nn
 from transformers import AutoModelForMaskedLM
 
 
-def test_esm_model_has_all_te_layers():
-    """Test that the converted TE model doesn't contain vanilla PyTorch layers."""
-    from esm.convert import convert_esm_hf_to_te
-
-    model_hf = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
-    model_te = convert_esm_hf_to_te(model_hf)
-    vanilla_layers_found = []
-    for name, module in model_te.named_modules():
-        if isinstance(module, nn.Linear):
-            vanilla_layers_found.append(f"Linear layer found in {name}")
-        if isinstance(module, nn.LayerNorm):
-            vanilla_layers_found.append(f"LayerNorm layer found in {name}")
-    if vanilla_layers_found:
-        print("ERROR: Found vanilla PyTorch layers in converted TE model:")
-        for error in vanilla_layers_found:
-            print(f"WARNING: {error}")
-    assert not vanilla_layers_found, f"Found {len(vanilla_layers_found)} vanilla layers in converted model"
-
-
 def test_convert_te_to_hf_roundtrip():
     """Test that converting HF -> TE -> HF produces the same model."""
     from esm.convert import convert_esm_hf_to_te, convert_esm_te_to_hf
@@ -124,16 +104,22 @@ def test_config_conversion():
     model_te = convert_esm_hf_to_te(model_hf)
     model_hf_converted = convert_esm_te_to_hf(model_te)
 
+    original_config_dict = model_hf.config.to_dict()
+    converted_config_dict = model_hf_converted.config.to_dict()
+
+    for key, value in original_config_dict.items():
+        assert key in converted_config_dict, f"Config field '{key}' missing in converted model"
+        assert converted_config_dict[key] == value, f"Config field '{key}' differs: original={value}, converted={converted_config_dict[key]}"
+
     assert model_hf_converted.config.model_type == "esm"
-    assert model_hf_converted.config.hidden_size == model_hf.config.hidden_size
-    assert model_hf_converted.config.num_hidden_layers == model_hf.config.num_hidden_layers
-    assert model_hf_converted.config.num_attention_heads == model_hf.config.num_attention_heads
-    assert model_hf_converted.config.intermediate_size == model_hf.config.intermediate_size
-    assert model_hf_converted.config.vocab_size == model_hf.config.vocab_size
-
-    # assert not hasattr(model_hf_converted.config, 'qkv_weight_interleaved')
-    # assert not hasattr(model_hf_converted.config, 'encoder_activation')
-    # assert not hasattr(model_hf_converted.config, 'attn_input_format')
-    # assert not hasattr(model_hf_converted.config, 'fuse_qkv_params')
-    # assert not hasattr(model_hf_converted.config, 'micro_batch_size')
-    # assert not hasattr(model_hf_converted.config, 'max_seq_length')
+
+    te_specific_fields = [
+        'qkv_weight_interleaved',
+        'encoder_activation',
+        'attn_input_format',
+        'fuse_qkv_params',
+        'micro_batch_size',
+        'auto_map'
+    ]
+    for field in te_specific_fields:
+        assert not hasattr(model_hf_converted.config, field), f"TE-specific field '{field}' should not be present in converted model"
