156 changes: 129 additions & 27 deletions bionemo-recipes/models/esm2/tests/test_thd.py
@@ -13,7 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest
import torch
from torch.optim import AdamW
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from transformers import DataCollatorForTokenClassification

from esm.collator import DataCollatorWithFlattening
@@ -29,29 +34,31 @@ def test_thd_from_collator_output(te_model_checkpoint, input_data_thd):

assert outputs.loss < 3.0

# if torch.cuda.get_device_capability() == (12, 0):
# # TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
# # but it's missing this THD implementation.
# monkeypatch.setenv("NVTE_FUSED_ATTN", "0")

def test_thd_values_match(te_model_checkpoint, tokenizer):
# Manually masked input tokens so that both BSHD and THD models have the same mask pattern

proteins = [
"MLSATEKLSDYISSLFASVSIINSISTEDLFFLKLTCQTFSKDSEEYKAAYRILRGVQRGKVQIIEEALVS",
"MFVFFAGTLVNQDTLNFRDQLNINVVGTVRGIAQDASKYLEYAIDSV",
"MAATGSLILSDEEQAELIALAVRIVLACAGGSQNKELAAQLGVIETTVGEWRRRFAQNRVEGLRDEARPGAPSDDQ",
"MSAVLSAVASDDWTAFAKLVHPYVHWTADGITTRGRTRVMARLSGHDGVKPASSYELRDGQVYRWTS",
"MSDPAAEPPADTSGIAWRKSSYSGPNGNCVELAQISGDHVGIRNSRDLHGSVLTCTRAEFAALLCDIKAGRFDSLIL",
"MRRPKLRRSGVLMSHPARGQPIKDASTEAAAERRPHVTSSERQDVSDQDTR",
"MQTITVAGGNLFQIAAQYLGDATQWIRIAQLNGLADPVLSGVVTLTIPQPNPLAGGGVVGQ",
"MVFSLEQFVRGQGWQSITSNSDNEVPKPRQVYEVKAVCHPGAWRVKARVFGTSQGIPFDYSQASMERRVAQDECDRRPQ",
"AGDGTGCNPTLSKAAGVELDNSDSGEVFVIYLHIIIAIIVLISINLIGFLYF",
"MKVGVDPSVCEAHGACMSILPEVFDLDDDEVLQIRDGELAPSEEESAERAVASCPMGALRLSR",
"MWISERPPSRMALGSQSQMSLPGIPARCLHS",
"MIDNSIRLFDADDSELFSLAEVPLDNKPIQRDTDSLSQWGDTWLREIQHS",
"MVKNLFFNKIKNATLKVANISRCYLPFPPPPCPPPEPLEPPEPPAPLEPAPDPPPLPPFPVPDILPAI",
"MSYINDITQSNSSILNVNVKINDHNSDEMYRNETKWYGEQFRYQSNPRFSRSSTSKNEKGFVQKKT",
"MQILILPIPDQLQNPNKISQHLICITFVSEQTLPI",
]
@pytest.fixture(params=["flash_attn", "fused_attn"])
def attn_impl(request, monkeypatch):
if request.param == "flash_attn":
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_FLASH_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True

else:
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FLASH_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True

Comment on lines +45 to +54 (Contributor)

⚠️ Potential issue | 🟠 Major

Use monkeypatch to isolate backend env changes.

Line 44 currently writes directly to os.environ, so whichever parametrized run executes last leaves its NVTE_* settings behind for the rest of the test session. That leaks state into unrelated tests and already tripped hardware-specific failures on L4. Please rely on the provided fixture and patch both the environment and _attention_backends via monkeypatch so pytest restores them automatically after each test.

 @pytest.fixture(params=["flash_attn", "fused_attn"])
 def attn_impl(request, monkeypatch):
-    if request.param == "flash_attn":
-        os.environ["NVTE_FUSED_ATTN"] = "0"
-        os.environ["NVTE_FLASH_ATTN"] = "1"
-        _attention_backends["backend_selection_requires_update"] = True
-    else:
-        os.environ["NVTE_FUSED_ATTN"] = "1"
-        os.environ["NVTE_FLASH_ATTN"] = "0"
-        _attention_backends["backend_selection_requires_update"] = True
-
-    return request.param
+    if request.param == "flash_attn":
+        monkeypatch.setenv("NVTE_FUSED_ATTN", "0")
+        monkeypatch.setenv("NVTE_FLASH_ATTN", "1")
+    else:
+        monkeypatch.setenv("NVTE_FUSED_ATTN", "1")
+        monkeypatch.setenv("NVTE_FLASH_ATTN", "0")
+
+    monkeypatch.setitem(_attention_backends, "backend_selection_requires_update", True)
+    return request.param
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if request.param == "flash_attn":
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_FLASH_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
else:
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FLASH_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
@pytest.fixture(params=["flash_attn", "fused_attn"])
def attn_impl(request, monkeypatch):
if request.param == "flash_attn":
monkeypatch.setenv("NVTE_FUSED_ATTN", "0")
monkeypatch.setenv("NVTE_FLASH_ATTN", "1")
else:
monkeypatch.setenv("NVTE_FUSED_ATTN", "1")
monkeypatch.setenv("NVTE_FLASH_ATTN"], "0")
monkeypatch.setitem(_attention_backends, "backend_selection_requires_update", True)
return request.param
🤖 Prompt for AI Agents
bionemo-recipes/models/esm2/tests/test_thd.py around lines 44 to 53: the test
writes directly to os.environ and mutates the module-level _attention_backends
dict, leaking state across parametrized runs; replace direct os.environ
assignments with monkeypatch.setenv for each NVTE_* var and replace the direct
dict mutation with monkeypatch.setitem(_attention_backends,
"backend_selection_requires_update", True) so pytest will restore environment
and dict state automatically after each test.
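
For readers unfamiliar with the fixture semantics the reviewer is relying on, the following is a minimal, standalone sketch of the monkeypatch behavior, not code from this PR; the `_backends` dict and the test names are illustrative stand-ins for TE's `_attention_backends` so the example runs without transformer_engine installed.

import os

import pytest

_backends = {"backend_selection_requires_update": False}


@pytest.fixture(params=["flash_attn", "fused_attn"])
def attn_impl(request, monkeypatch):
    # monkeypatch records the pre-test values and restores them automatically
    # when the test that requested this fixture finishes.
    flash = request.param == "flash_attn"
    monkeypatch.setenv("NVTE_FLASH_ATTN", "1" if flash else "0")
    monkeypatch.setenv("NVTE_FUSED_ATTN", "0" if flash else "1")
    monkeypatch.setitem(_backends, "backend_selection_requires_update", True)
    return request.param


def test_backend_is_patched(attn_impl):
    # Inside the test, the overrides are visible.
    assert os.environ["NVTE_FLASH_ATTN"] in {"0", "1"}
    assert _backends["backend_selection_requires_update"] is True


def test_state_is_restored():
    # Runs after the parametrized tests above (pytest executes tests in file order):
    # the dict entry is back to its original value, and the NVTE_* variables have
    # been reset to whatever they were before the fixture ran.
    assert _backends["backend_selection_requires_update"] is False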

return request.param

sequences = [tokenizer(protein) for protein in proteins]

def test_thd_losses_match(te_model_checkpoint, tokenizer, test_proteins, attn_impl):
# Manually masked input tokens so that both BSHD and THD models have the same mask pattern

sequences = [tokenizer(protein) for protein in test_proteins]
sequences = [
{
"input_ids": seq["input_ids"],
@@ -90,19 +97,114 @@ def test_thd_values_match(te_model_checkpoint, tokenizer):

torch.testing.assert_close(bshd_outputs.loss, thd_outputs.loss)

# bshd_logits = bshd_outputs.logits[input_data_bshd["attention_mask"].to(bool)]
# TODO(BIONEMO-2801): Investigate why these are not close on sm89 but pass on sm120.
# torch.testing.assert_close(bshd_logits, thd_outputs.logits)
bshd_logits = bshd_outputs.logits[input_data_bshd["attention_mask"].to(bool)]
torch.testing.assert_close(bshd_logits, thd_outputs.logits)


def test_thd_logits_match(te_model_checkpoint, tokenizer, test_proteins, attn_impl):
sequences = [tokenizer(protein) for protein in test_proteins]
sequences = [
{
"input_ids": seq["input_ids"],
"attention_mask": seq["attention_mask"],
"labels": seq["input_ids"],
}
for seq in sequences
]

bshd_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, padding=True)
thd_collator = DataCollatorWithFlattening()

input_data_bshd = bshd_collator(sequences)
input_data_thd = thd_collator(sequences)

model_bshd = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16)
model_thd = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, attn_input_format="thd", dtype=torch.bfloat16)

model_bshd.to("cuda")
model_thd.to("cuda")

input_data_bshd = {k: v.to("cuda") for k, v in input_data_bshd.items()}
input_data_thd = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in input_data_thd.items()}

bshd_outputs = model_bshd(**input_data_bshd)
thd_outputs = model_thd(**input_data_thd)

# TODO(BIONEMO-2801): Investigate why these are not close on sm89 but pass on sm120.
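# Indexing with the boolean attention mask drops padding and flattens the [batch, seq, vocab]
# BSHD logits into [total_tokens, vocab], which is the packed layout the THD model emits.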
bshd_logits = bshd_outputs.logits[input_data_bshd["attention_mask"].to(bool)]
torch.testing.assert_close(bshd_logits, thd_outputs.logits)


def test_thd_backwards(te_model_checkpoint, input_data_thd, monkeypatch):
if torch.cuda.get_device_capability() == (12, 0):
# TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
# but it's missing this THD implementation.
monkeypatch.setenv("NVTE_FUSED_ATTN", "0")
def test_thd_backwards_works(te_model_checkpoint, input_data_thd, attn_impl):
if attn_impl == "fused_attn" and torch.cuda.get_device_capability() == (12, 0):
pytest.xfail("BIONEMO-2840: On sm120 the THD backwards implementation is not available for fused attn.")

model_thd = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, attn_input_format="thd", dtype=torch.bfloat16)
model_thd.to("cuda")
input_data = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in input_data_thd.items()}
outputs = model_thd(**input_data)
outputs.loss.backward()


def test_thd_backwards_passes_match(te_model_checkpoint, tokenizer, test_proteins, attn_impl):
if attn_impl == "fused_attn" and torch.cuda.get_device_capability() == (12, 0):
pytest.xfail("BIONEMO-2840: On sm120 the THD backwards implementation is not available for fused attn.")

sequences = [tokenizer(protein) for protein in test_proteins]
sequences = [
{
"input_ids": seq["input_ids"],
"attention_mask": seq["attention_mask"],
"labels": seq["input_ids"],
}
for seq in sequences
]

bshd_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, padding=True)
thd_collator = DataCollatorWithFlattening()

input_data_bshd = bshd_collator(sequences)
input_data_thd = thd_collator(sequences)

model_bshd = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16)
model_thd = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, attn_input_format="thd", dtype=torch.bfloat16)
model_bshd.to("cuda")
model_thd.to("cuda")

input_data_bshd = {k: v.to("cuda") for k, v in input_data_bshd.items()}
input_data_thd = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in input_data_thd.items()}

optimizer_thd = AdamW(model_thd.parameters(), lr=1e-4)
optimizer_bshd = AdamW(model_bshd.parameters(), lr=1e-4)

for i in range(2):
print(f"Iteration {i}")

thd_outputs = model_thd(**input_data_thd)
bshd_outputs = model_bshd(**input_data_bshd)

thd_outputs.loss.backward()
bshd_outputs.loss.backward()

# Clip only after backward so the norms reflect this iteration's gradients.
total_norm_thd = torch.nn.utils.clip_grad_norm_(model_thd.parameters(), max_norm=1.0).item()
total_norm_bshd = torch.nn.utils.clip_grad_norm_(model_bshd.parameters(), max_norm=1.0).item()

torch.testing.assert_close(total_norm_thd, total_norm_bshd)

thd_grads = {name: p.grad for name, p in model_thd.named_parameters() if p.grad is not None}
bshd_grads = {name: p.grad for name, p in model_bshd.named_parameters() if p.grad is not None}

max_abs_diff = {name: (bshd_grads[name] - thd_grads[name]).abs().max().item() for name in thd_grads.keys()}

# For some reason, the word embeddings grads have a slightly higher numerical error.
thd_word_embeddings_grad = thd_grads.pop("esm.embeddings.word_embeddings.weight")
bshd_word_embeddings_grad = bshd_grads.pop("esm.embeddings.word_embeddings.weight")

torch.testing.assert_close(thd_grads, bshd_grads)
# TODO: the word-embedding grads need a looser tolerance; re-enable once the mismatch is understood.
# torch.testing.assert_close(thd_word_embeddings_grad, bshd_word_embeddings_grad, atol=1e-3, rtol=1e-5)

optimizer_thd.step()
optimizer_bshd.step()

optimizer_thd.zero_grad()
optimizer_bshd.zero_grad()