Add expanded THD unit tests #1214
Conversation
Signed-off-by: Peter St. John <[email protected]>
Important: Review skipped (draft detected). Please check the settings in the CodeRabbit UI; this status message can be disabled there.

Walkthrough

Adds a new pytest fixture to toggle attention backends and replaces a single hard-coded test with multiple parameterized tests comparing THD and BSHD across backends. Tests now validate losses, logits, and backward/gradient behavior, with guarded xfails for known backend/hardware constraints.

Sequence Diagram(s)

sequenceDiagram
autonumber
participant T as Tester
participant P as pytest
participant F as attn_impl fixture
participant E as Env/Backend Selector
participant THD as THD Model
participant BSHD as BSHD Model
participant C as Comparators
T->>P: run tests (param: test_proteins, backend)
P->>F: request attn_impl(backend)
F->>E: set env vars + internal backend flag
Note over E: Select "flash_attn" or "fused_attn"
P->>THD: build inputs, forward (loss/logits)
P->>BSHD: build inputs, forward (loss/logits)
THD-->>P: outputs/gradients
BSHD-->>P: outputs/gradients
P->>C: compare losses/logits
alt backward tests
P->>THD: backward()
P->>BSHD: backward()
C-->>P: compare gradients (with tolerance rules)
else known limitation
Note over P: xfail on specific hw/backend
end
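For orientation, here is a minimal sketch (not code from this PR) of the forward-comparison flow the diagram describes. The fixture names (attn_impl, model_thd, model_bshd, input_data_thd, input_data_bshd), the model construction, and the tolerances are assumptions based on the fixtures discussed in this review:

import torch

def test_thd_matches_bshd(attn_impl, model_thd, model_bshd, input_data_thd, input_data_bshd):
    # attn_impl toggles the TE attention backend (flash vs fused) via env vars.
    thd_outputs = model_thd(**input_data_thd)     # packed THD layout
    bshd_outputs = model_bshd(**input_data_bshd)  # padded BSHD layout

    # Losses from the two layouts should agree regardless of backend.
    torch.testing.assert_close(thd_outputs.loss, bshd_outputs.loss, atol=1e-3, rtol=1e-5)

    # Backward pass; gradient comparisons would follow the same pattern.
    thd_outputs.loss.backward()
    bshd_outputs.loss.backward()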
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

Pre-merge checks and finishing touches
❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
Actionable comments posted: 1
🧹 Nitpick comments (1)
bionemo-recipes/models/esm2/tests/test_thd.py (1)
Lines 177-182: Drop the duplicated forward call before backward.

The second model_bshd(**input_data_bshd) / model_thd(**input_data_thd) invocation (lines 180-181) repeats the same work, doubling runtime and building an extra autograd graph you immediately discard. Please remove the redundant pair to keep the test lean.

-    bshd_outputs = model_bshd(**input_data_bshd)
-    thd_outputs = model_thd(**input_data_thd)
-
     thd_outputs.loss.backward()
     bshd_outputs.loss.backward()
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
bionemo-recipes/models/esm2/tests/test_thd.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
bionemo-recipes/models/esm2/tests/test_thd.py (3)
bionemo-recipes/models/esm2/tests/conftest.py (5)
  te_model_checkpoint (133-137)
  tokenizer (45-46)
  test_proteins (90-91)
  input_data_thd (128-129)
  input_data (123-124)
bionemo-recipes/models/esm2/src/esm/collator.py (1)
  DataCollatorWithFlattening (229-278)
bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py (1)
  NVEsmForMaskedLM (415-508)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: unit-tests (models/esm2)
- GitHub Check: pre-commit
if request.param == "flash_attn":
    os.environ["NVTE_FUSED_ATTN"] = "0"
    os.environ["NVTE_FLASH_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True
else:
    os.environ["NVTE_FUSED_ATTN"] = "1"
    os.environ["NVTE_FLASH_ATTN"] = "0"
    _attention_backends["backend_selection_requires_update"] = True
Use monkeypatch to isolate backend env changes.

Line 44 currently writes directly to os.environ, so whichever parametrized run executes last leaves its NVTE_* settings behind for the rest of the test session. That leaks state into unrelated tests and already tripped hardware-specific failures on L4. Please rely on the provided monkeypatch fixture and patch both the environment and _attention_backends via monkeypatch so pytest restores them automatically after each test.
@pytest.fixture(params=["flash_attn", "fused_attn"])
def attn_impl(request, monkeypatch):
- if request.param == "flash_attn":
- os.environ["NVTE_FUSED_ATTN"] = "0"
- os.environ["NVTE_FLASH_ATTN"] = "1"
- _attention_backends["backend_selection_requires_update"] = True
- else:
- os.environ["NVTE_FUSED_ATTN"] = "1"
- os.environ["NVTE_FLASH_ATTN"] = "0"
- _attention_backends["backend_selection_requires_update"] = True
-
- return request.param
+ if request.param == "flash_attn":
+ monkeypatch.setenv("NVTE_FUSED_ATTN", "0")
+ monkeypatch.setenv("NVTE_FLASH_ATTN", "1")
+ else:
+ monkeypatch.setenv("NVTE_FUSED_ATTN", "1")
+ monkeypatch.setenv("NVTE_FLASH_ATTN", "0")
+
+ monkeypatch.setitem(_attention_backends, "backend_selection_requires_update", True)
+ return request.param
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
@pytest.fixture(params=["flash_attn", "fused_attn"])
def attn_impl(request, monkeypatch):
    if request.param == "flash_attn":
        monkeypatch.setenv("NVTE_FUSED_ATTN", "0")
        monkeypatch.setenv("NVTE_FLASH_ATTN", "1")
    else:
        monkeypatch.setenv("NVTE_FUSED_ATTN", "1")
        monkeypatch.setenv("NVTE_FLASH_ATTN", "0")

    monkeypatch.setitem(_attention_backends, "backend_selection_requires_update", True)
    return request.param
🤖 Prompt for AI Agents
bionemo-recipes/models/esm2/tests/test_thd.py around lines 44 to 53: the test
writes directly to os.environ and mutates the module-level _attention_backends
dict, leaking state across parametrized runs; replace direct os.environ
assignments with monkeypatch.setenv for each NVTE_* var and replace the direct
dict mutation with monkeypatch.setitem(_attention_backends,
"backend_selection_requires_update", True) so pytest will restore environment
and dict state automatically after each test.
# sus
torch.testing.assert_close(thd_word_embeddings_grad, bshd_word_embeddings_grad, atol=1e-3, rtol=1e-5)
well, these tests all pass on H100s. Probably need to dig into this line a bit more
Is it possible we co-train THD vs non-THD and literally run this test every iteration?
yeah we could try something like that; run it for a few iterations
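A rough sketch of what that might look like (hypothetical; the model and input fixture names mirror those used elsewhere in this thread, both models are assumed to start from the same checkpoint, and the optimizer choice, learning rate, and step count are assumptions):

import torch

def test_thd_bshd_stay_in_sync(model_thd, model_bshd, input_data_thd, input_data_bshd):
    opt_thd = torch.optim.SGD(model_thd.parameters(), lr=1e-3)
    opt_bshd = torch.optim.SGD(model_bshd.parameters(), lr=1e-3)

    for _ in range(3):  # "a few iterations", per the suggestion above
        opt_thd.zero_grad()
        opt_bshd.zero_grad()

        thd_outputs = model_thd(**input_data_thd)
        bshd_outputs = model_bshd(**input_data_bshd)

        # Losses should still match after each optimizer step.
        torch.testing.assert_close(thd_outputs.loss, bshd_outputs.loss, atol=1e-3, rtol=1e-5)

        thd_outputs.loss.backward()
        bshd_outputs.loss.backward()

        # Compare a representative gradient, e.g. the word embeddings
        # (get_input_embeddings is the standard HF accessor; assumed to apply here).
        torch.testing.assert_close(
            model_thd.get_input_embeddings().weight.grad,
            model_bshd.get_input_embeddings().weight.grad,
            atol=1e-3,
            rtol=1e-5,
        )

        opt_thd.step()
        opt_bshd.step()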
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here.
Signed-off-by: Peter St. John <[email protected]>
Will need to carefully xfail these on L4 test runners.
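One possible way to do that, sketched below, is to key the xfail off the reported CUDA compute capability (L4 GPUs report 8.9); treating that as the L4 detection signal is an assumption, and the marker name is illustrative.

import pytest
import torch

def is_l4_gpu() -> bool:
    # L4 (Ada Lovelace) devices report compute capability (8, 9).
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() == (8, 9)

# Reusable marker for the gradient-comparison tests known to fail on L4 runners.
l4_xfail = pytest.mark.xfail(
    is_l4_gpu(),
    reason="Known THD/BSHD gradient mismatch on L4 runners",
    strict=False,
)

The affected tests could then be decorated with @l4_xfail without affecting runs on H100s.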