
Conversation

pstjohn
Collaborator

@pstjohn pstjohn commented Oct 1, 2025

Will need to carefully xfail these on L4 test runners.
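One possible shape for such a guard is sketched below; the running_on_l4 helper and the specific xfail condition are illustrative assumptions, not code from this PR (the real tests select the backend through the attn_impl fixture rather than a parametrize mark).

import pytest
import torch


def running_on_l4() -> bool:
    # Hypothetical helper: L4 (Ada Lovelace) GPUs report compute capability (8, 9).
    return torch.cuda.is_available() and torch.cuda.get_device_capability() == (8, 9)


@pytest.mark.parametrize("attn_impl", ["flash_attn", "fused_attn"])
def test_backward_grads_match(attn_impl):
    if running_on_l4() and attn_impl == "fused_attn":
        pytest.xfail("Known THD/BSHD gradient mismatch with fused attention on L4 runners")
    # ... actual THD vs. BSHD gradient comparison goes here ...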

Summary by CodeRabbit

  • Tests
    • Expanded test coverage to validate behavior across multiple attention backends.
    • Added checks comparing losses, logits, and gradients between implementations for consistency.
    • Replaced hard-coded samples with parameterized inputs to broaden scenario coverage.
    • Added safeguards for known hardware-specific limitations to prevent false failures.
    • Improved reliability and maintainability of model validation through unified input preparation and backend toggling.

Signed-off-by: Peter St. John <[email protected]>
Contributor

coderabbitai bot commented Oct 1, 2025

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting reviews.review_status to false in the CodeRabbit configuration file.

Walkthrough

Adds a new pytest fixture to toggle attention backends and replaces a single hard-coded test with multiple parameterized tests comparing THD and BSHD across backends. Tests now validate losses, logits, and backward/gradient behavior, with guarded xfails for known backend/hardware constraints.

Changes

Cohort / File(s): ESM2 THD tests — bionemo-recipes/models/esm2/tests/test_thd.py
Summary: Introduces an attn_impl fixture to switch between "flash_attn" and "fused_attn" via environment variables and an internal backend flag; replaces the prior single hard-coded test/data block with parameterized tests checking that losses match, logits match, backward works, and backward gradients match; adds conditional xfails for known backend/hardware limitations; removes the hard-coded protein list in favor of the test_proteins fixture.
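In outline, the parameterized comparison tests follow a pattern like the sketch below. It is assembled from the fixture and class names visible in this review (the conftest fixtures and NVEsmForMaskedLM); the from_pretrained/.cuda() setup, fixture signatures, and tolerances are assumptions, not the file's actual contents.

import torch
from esm.modeling_esm_te import NVEsmForMaskedLM


def test_thd_and_bshd_losses_match(attn_impl, te_model_checkpoint, input_data, input_data_thd):
    # Load the same TE checkpoint twice; only the input layout (padded BSHD vs. packed THD) differs.
    model_bshd = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint).cuda()
    model_thd = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint).cuda()

    bshd_outputs = model_bshd(**input_data)
    thd_outputs = model_thd(**input_data_thd)

    # Both layouts should produce matching masked-LM losses under the selected attention backend.
    torch.testing.assert_close(thd_outputs.loss, bshd_outputs.loss, atol=1e-3, rtol=1e-5)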

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant T as Tester
    participant P as pytest
    participant F as attn_impl fixture
    participant E as Env/Backend Selector
    participant THD as THD Model
    participant BSHD as BSHD Model
    participant C as Comparators

    T->>P: run tests (param: test_proteins, backend)
    P->>F: request attn_impl(backend)
    F->>E: set env vars + internal backend flag
    Note over E: Select "flash_attn" or "fused_attn"

    P->>THD: build inputs, forward (loss/logits)
    P->>BSHD: build inputs, forward (loss/logits)

    THD-->>P: outputs/gradients
    BSHD-->>P: outputs/gradients

    P->>C: compare losses/logits
    alt backward tests
        P->>THD: backward()
        P->>BSHD: backward()
        C-->>P: compare gradients (with tolerance rules)
    else known limitation
        Note over P: xfail on specific hw/backend
    end

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Poem

I twitch my ears at toggled skies,
Flash or fused, the backend tries—
Losses rhyme, logits chime,
Gradients dance in perfect time.
Xfails burrow, known and neat,
Tests hop on with thumping feet.
Carrots for coverage, oh so sweet!

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
  • Description Check — ⚠️ Warning: The description consists of a single note about xfail behavior and does not follow the repository's required template; it lacks a detailed description of the changes, usage examples, a type-of-change categorization, CI configuration instructions, and the pre-submit checklist. Resolution: update the pull request description to follow the repository's template, adding a detailed explanation of the changes and usage examples, selecting the appropriate type of change, specifying any required CI pipeline labels, and completing the pre-submit checklist.
  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (1 passed)
  • Title Check — ✅ Passed: The title succinctly and accurately describes the primary change (expanded unit tests for THD functionality), which aligns with the bulk of the new test code in the changeset. It is clear, specific, and free of unnecessary detail or generic phrasing.


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
bionemo-recipes/models/esm2/tests/test_thd.py (1)

177-182: Drop the duplicated forward call before backward.

The second model_bshd(**input_data_bshd) / model_thd(**input_data_thd) invocation (Lines 180-181) repeats the same work, doubling runtime and building an extra autograd graph you immediately discard. Please remove the redundant pair to keep the test lean.

-    bshd_outputs = model_bshd(**input_data_bshd)
-    thd_outputs = model_thd(**input_data_thd)
-
     thd_outputs.loss.backward()
     bshd_outputs.loss.backward()
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c9dd049 and 4af6698.

📒 Files selected for processing (1)
  • bionemo-recipes/models/esm2/tests/test_thd.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
bionemo-recipes/models/esm2/tests/test_thd.py (3)
bionemo-recipes/models/esm2/tests/conftest.py (5)
  • te_model_checkpoint (133-137)
  • tokenizer (45-46)
  • test_proteins (90-91)
  • input_data_thd (128-129)
  • input_data (123-124)
bionemo-recipes/models/esm2/src/esm/collator.py (1)
  • DataCollatorWithFlattening (229-278)
bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py (1)
  • NVEsmForMaskedLM (415-508)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: unit-tests (models/esm2)
  • GitHub Check: pre-commit

Comment on lines +44 to +53
if request.param == "flash_attn":
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_FLASH_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True

else:
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FLASH_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True

Contributor

⚠️ Potential issue | 🟠 Major

Use monkeypatch to isolate backend env changes.

Line 44 currently writes directly to os.environ, so whichever parametrized run executes last leaves its NVTE_* settings behind for the rest of the test session. That leaks state into unrelated tests and already tripped hardware-specific failures on L4. Please rely on the provided fixture and patch both the environment and _attention_backends via monkeypatch so pytest restores them automatically after each test.

 @pytest.fixture(params=["flash_attn", "fused_attn"])
 def attn_impl(request, monkeypatch):
-    if request.param == "flash_attn":
-        os.environ["NVTE_FUSED_ATTN"] = "0"
-        os.environ["NVTE_FLASH_ATTN"] = "1"
-        _attention_backends["backend_selection_requires_update"] = True
-    else:
-        os.environ["NVTE_FUSED_ATTN"] = "1"
-        os.environ["NVTE_FLASH_ATTN"] = "0"
-        _attention_backends["backend_selection_requires_update"] = True
-
-    return request.param
+    if request.param == "flash_attn":
+        monkeypatch.setenv("NVTE_FUSED_ATTN", "0")
+        monkeypatch.setenv("NVTE_FLASH_ATTN", "1")
+    else:
+        monkeypatch.setenv("NVTE_FUSED_ATTN", "1")
+        monkeypatch.setenv("NVTE_FLASH_ATTN", "0")
+
+    monkeypatch.setitem(_attention_backends, "backend_selection_requires_update", True)
+    return request.param
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

Before:

    if request.param == "flash_attn":
        os.environ["NVTE_FUSED_ATTN"] = "0"
        os.environ["NVTE_FLASH_ATTN"] = "1"
        _attention_backends["backend_selection_requires_update"] = True
    else:
        os.environ["NVTE_FUSED_ATTN"] = "1"
        os.environ["NVTE_FLASH_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True

After:

@pytest.fixture(params=["flash_attn", "fused_attn"])
def attn_impl(request, monkeypatch):
    if request.param == "flash_attn":
        monkeypatch.setenv("NVTE_FUSED_ATTN", "0")
        monkeypatch.setenv("NVTE_FLASH_ATTN", "1")
    else:
        monkeypatch.setenv("NVTE_FUSED_ATTN", "1")
        monkeypatch.setenv("NVTE_FLASH_ATTN", "0")

    monkeypatch.setitem(_attention_backends, "backend_selection_requires_update", True)
    return request.param
🤖 Prompt for AI Agents
bionemo-recipes/models/esm2/tests/test_thd.py around lines 44 to 53: the test
writes directly to os.environ and mutates the module-level _attention_backends
dict, leaking state across parametrized runs; replace direct os.environ
assignments with monkeypatch.setenv for each NVTE_* var and replace the direct
dict mutation with monkeypatch.setitem(_attention_backends,
"backend_selection_requires_update", True) so pytest will restore environment
and dict state automatically after each test.

Comment on lines 195 to 196
# sus
torch.testing.assert_close(thd_word_embeddings_grad, bshd_word_embeddings_grad, atol=1e-3, rtol=1e-5)
Collaborator Author

well, these tests all pass on H100s. Probably need to dig into this line a bit more
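One way to dig into that line on an L4 runner is to report the actual error magnitudes before settling on tolerances. A debugging sketch (the gradient variable names follow the test; nothing here is in the PR itself):

# Quantify the mismatch instead of relying on a single assert_close tolerance.
thd_grad = thd_word_embeddings_grad.float()
bshd_grad = bshd_word_embeddings_grad.float()

abs_err = (thd_grad - bshd_grad).abs()
rel_err = abs_err / bshd_grad.abs().clamp_min(1e-12)
print(f"max abs err: {abs_err.max().item():.3e}  max rel err: {rel_err.max().item():.3e}")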

Collaborator

Is it possible we co-train THD vs non THD and literally run this test every iteration?

Collaborator Author

yeah we could try something like that; run it for a few iterations
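A rough shape for that idea, assuming both model copies start from the same checkpoint and see the same batch every step; the optimizer choice, step count, and tolerances below are placeholders, not part of this PR:

# Co-train the BSHD and THD copies for a few steps and compare losses each iteration.
optimizer_bshd = torch.optim.Adam(model_bshd.parameters(), lr=1e-4)
optimizer_thd = torch.optim.Adam(model_thd.parameters(), lr=1e-4)

for step in range(5):
    loss_bshd = model_bshd(**input_data).loss
    loss_thd = model_thd(**input_data_thd).loss

    # Divergence between the two layouts should stay within tolerance at every step.
    torch.testing.assert_close(loss_thd, loss_bshd, atol=1e-3, rtol=1e-5)

    loss_bshd.backward()
    loss_thd.backward()
    optimizer_bshd.step()
    optimizer_thd.step()
    optimizer_bshd.zero_grad()
    optimizer_thd.zero_grad()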

@pstjohn pstjohn marked this pull request as draft October 2, 2025 00:57

copy-pr-bot bot commented Oct 2, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.


Signed-off-by: Peter St. John <[email protected]>
@pstjohn pstjohn closed this Oct 8, 2025