
Conversation

@yzhang123 (Collaborator) commented Sep 3, 2025

Description

Make the Evo2 GPT model compatible with predict.py and infer.py.

Type of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Refactor
  • Documentation update
  • Other (please describe):

CI Pipeline Configuration

Configure CI behavior by applying the relevant labels:

Note

Notebook validation tests are skipped by default unless explicitly enabled.

Authorizing CI Runs

We use copy-pr-bot to manage authorization of CI
runs on NVIDIA's compute resources.

  • If a pull request is opened by a trusted user and contains only trusted changes, the pull request's code will
    automatically be copied to a pull-request/ prefixed branch in the source repository (e.g. pull-request/123)
  • If a pull request is opened by an untrusted user or contains untrusted changes, an NVIDIA org member must leave an
    /ok to test comment on the pull request to trigger CI. This will need to be done for each new commit.

Usage

# TODO: Add code snippet

Pre-submit Checklist

  • I have tested these changes locally
  • I have updated the documentation accordingly
  • I have added/updated tests as needed
  • All existing tests pass successfully

Summary by CodeRabbit

  • New Features

    • Added GPT support for training, inference and prediction with Evo2-style configs, extended RoPE/context controls, targeted variance loss and loss reweighting.
    • New model config options and selectable GPT model sizes.
    • Introduced Mamba predictor with optional per-sequence log-prob outputs and packed-sequence support.
  • Chores

    • CLI: simpler boolean flags (--fp8, --flash-decode) and dynamic inference sizing based on prompt length + max_new_tokens.
  • Bug Fixes

    • Clearer runtime checks and messaging when Transformer Engine is required.

coderabbitai bot (Contributor) commented Sep 3, 2025

Walkthrough

Adds an Evo2 GPT integration with new model/config classes and forward helpers; updates Mamba forward for Transformer Engine and packed sequences; wires GPT support into predict/train/infer CLIs and introduces MambaPredictor and dynamic inference sizing; adds RoPE and embedding configuration knobs and weight-decay helpers.

Changes

Cohort / File(s) Summary
Evo2 GPT model & config
sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py
Adds evo2_gpt_forward_step, Evo2GPTModel (with inference wrapper/get_inference_wrapper and forward), Evo2StyleMCoreGPTModel (custom loss/reweighting), LLama31ConfigEvoLoss3B (RoPE/context/embedding fields and configure_model), weight-decay helpers, and GPT_MODEL_OPTIONS export.
Mamba forward TE & packed-seq
sub-packages/bionemo-evo2/src/bionemo/evo2/models/mamba.py
Adds safe_import TE detection (HAVE_TE), conditions attention_mask on TE availability, and supports packed_seq_params via get_packed_seq_params(batch) when cu_seqlens present.
Inference CLI/runtime
sub-packages/bionemo-evo2/src/bionemo/evo2/run/infer.py
Changes --fp8 and --flash-decode to store_true flags and computes inference thresholds (inference_batch_times_seqlen_threshold, inference_max_seq_length) from len(prompt) + max_new_tokens instead of fixed 8192.
Predict flow & MambaPredictor
sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py
Adds GPT model-type support and validation via GPT_MODEL_OPTIONS, integrates GPT config path, introduces MambaPredictor class (predict_step, forward, log-prob handling), and enforces invalid-model-type errors.
Train flow, RoPE & embedding options
sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py
Adds GPT training path using GPT_MODEL_OPTIONS and Evo2GPTModel; new CLI args for embedding and RoPE (--embedding-init-std, --old-context-len, --scale-factor, --rope-base); embedding init controls, old_context_len/scale_factor logic, logging/run-name updates, and module-level logger.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor User
  participant CLI as train.py CLI
  participant Builder as Config Builder
  participant Options as GPT_MODEL_OPTIONS
  participant Model as Evo2GPTModel
  participant Core as Evo2StyleMCoreGPTModel

  User->>CLI: run train --model-size <size> [GPT/RoPE/emb args]
  CLI->>Builder: parse args, detect model_type
  alt model_size in GPT options
    Builder->>Options: resolve config class
    Options-->>Builder: LLama31ConfigEvoLoss3B
    Builder->>Builder: compute old_context_len, scale_factor, rotary_base, embedding_init_method_std
    Builder->>Model: instantiate with tokenizer and config
    Model->>Core: configure_model(...)
    Core-->>Model: constructed core GPT model
  else other model types
    Builder->>Builder: follow existing paths
  end
sequenceDiagram
  autonumber
  participant Pred as predict.py
  participant Opts as GPT_MODEL_OPTIONS
  participant MPred as MambaPredictor
  participant HPred as HyenaPredictor
  participant Run as infer.py

  Pred->>Pred: parse args model_type
  alt model_type == "gpt"
    Pred->>Opts: validate model_size
    Pred->>HPred: init with GPT-configured settings
  else model_type == "mamba"
    Pred->>MPred: init MambaPredictor
  else invalid
    Pred-->>Pred: raise error
  end
  Run->>Run: compute thresholds = len(prompt) + max_new_tokens
  Run-->>Pred: inference executes with dynamic thresholds
sequenceDiagram
  autonumber
  participant Step as evo2_gpt_forward_step
  participant Batch as batch
  participant TE as TransformerEngine check
  participant GPT as Evo2GPTModel/Evo2StyleMCoreGPTModel

  Step->>Batch: read tokens, position_ids, labels, loss_mask
  alt attention_mask in batch
    Step->>Step: use batch.attention_mask
  else
    Step->>TE: assert Transformer Engine available
  end
  opt cu_seqlens in batch
    Step->>Step: packed_seq_params = get_packed_seq_params(batch)
  end
  Step->>GPT: forward(input_ids, position_ids, [attention_mask], labels, loss_mask, [packed_seq_params])
  GPT-->>Step: logits or loss

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • jstjohn
  • pstjohn
  • jwilber
  • cspades
  • dorotat-nv
  • trvachov
  • jomitchellnv
  • skothenhill-nv
  • sichu2023
  • broland-hat
  • polinabinder1

Poem

I hop through code with nimble paws,
New GPT paths and careful laws.
RoPE stretched wide, TE checks in place,
Packed sequences race through space.
Carrots of configs, stitched with care—deploy with grace. 🥕


coderabbitai bot (Contributor) left a comment

Actionable comments posted: 9

🧹 Nitpick comments (6)
sub-packages/bionemo-evo2/src/bionemo/evo2/models/mamba.py (1)

60-66: Fix grammatical error in assertion message.

There's a grammatical error in the assertion message.

-        assert HAVE_TE, (
-            "The dataloader did not provide an attention mask, however Transformer Engine was not detected. \
-            This requires Transformer Engine's implementation of fused or flash attention."
-        )
+        assert HAVE_TE, (
+            "The dataloader did not provide an attention mask; however, Transformer Engine was not detected. "
+            "This requires Transformer Engine's implementation of fused or flash attention."
+        )
sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py (2)

521-628: Consider extracting common prediction logic to reduce duplication.

The MambaPredictor class shares significant logic with HyenaPredictor (e.g., log probability calculations, tensor gathering). Consider extracting the common prediction functionality into a base class or utility functions to improve maintainability.

Would you like me to help refactor this to create a shared base class for common predictor functionality?
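
For illustration only, a shared base along these lines could hold the common log-probability handling; the class and method names below are hypothetical, not taken from the PR:

import torch


class BaseLogProbPredictorMixin:
    """Sketch of shared log-probability handling for Hyena/Mamba/GPT predictors."""

    def __init__(self, output_log_prob_seqs: bool = False, log_prob_collapse_option: str = "sum"):
        self.output_log_prob_seqs = output_log_prob_seqs
        self.log_prob_collapse_option = log_prob_collapse_option

    def collapse_log_probs(self, logits: torch.Tensor, tokens: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
        """Gather log-probs of the observed tokens and collapse them per sequence."""
        log_probs = torch.log_softmax(logits.float(), dim=-1)          # [B, S, V]
        token_log_probs = log_probs.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)  # [B, S]
        token_log_probs = token_log_probs * loss_mask
        if self.log_prob_collapse_option == "mean":
            return token_log_probs.sum(-1) / loss_mask.sum(-1).clamp(min=1)
        return token_log_probs.sum(-1)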


473-479: Document why HyenaPredictor is used for GPT models.

Using HyenaPredictor for GPT models might be confusing. Add a comment explaining why this reuse is appropriate, or consider creating a dedicated GPTPredictor class if there are GPT-specific requirements.

     elif model_type == "gpt":
+        # GPT models can reuse HyenaPredictor since they share the same prediction interface
+        # and forward step logic. The model configuration handles the architectural differences.
         model = HyenaPredictor(
             config,
             tokenizer=tokenizer,
             output_log_prob_seqs=output_log_prob_seqs,
             log_prob_collapse_option=log_prob_collapse_option,
         )
sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py (2)

693-706: Improve clarity of old context length determination.

The logic for determining old_context_len is complex with multiple conditionals. Consider refactoring for clarity and ensure the warning message provides actionable guidance.

-        if args.ckpt_dir is None or args.old_context_len:
-            # Set the old context length based on the initial pre-training run seq_length
-            #  when you supply a ckpt_dir, assume that we will use whatever that value was set to previously
-            #  for rope extension.
-            old_context_len = args.old_context_len or args.seq_length  # set to the seq_length if not supplied
-        else:
-            if not args.old_context_len:
-                old_context_len = args.seq_length
-                logger.warning(
-                    "No old context length supplied, using the seq_length as the old context length. "
-                    "This is not recommended and if training at a different context length the RoPE scaling factors "
-                    "will be incorrect. Please supply the old context length when fine-tuning especially if you are "
-                    "extending the context length."
-                )
-            else:
-                old_context_len = args.old_context_len
+        # Determine old context length for RoPE scaling
+        if args.old_context_len:
+            old_context_len = args.old_context_len
+        elif args.ckpt_dir is None:
+            # New training run: use current seq_length as baseline
+            old_context_len = args.seq_length
+        else:
+            # Fine-tuning from checkpoint without explicit old_context_len
+            old_context_len = args.seq_length
+            logger.warning(
+                f"Fine-tuning from checkpoint without --old-context-len specified. "
+                f"Defaulting to current seq_length ({args.seq_length}). "
+                f"If the original model was trained with a different context length, "
+                f"please specify it using --old-context-len for correct RoPE scaling."
+            )

685-687: Add informative error for GPT without required flag.

The assertion provides a clear requirement but could be more helpful by suggesting the fix directly.

-        assert args.no_fp32_residual_connection, (
-            "GPT models do not support fp32 residual connection, please run with --no-fp32-residual-connection."
-        )
+        assert args.no_fp32_residual_connection, (
+            "GPT models require --no-fp32-residual-connection flag to be set. "
+            "Please add this flag to your command."
+        )
sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py (1)

380-381: Consider implementing TE full layer spec support.

The code raises an error for use_transformer_engine_full_layer_spec. If this is a planned feature, consider adding a TODO comment or tracking issue reference.

Consider updating the error message to be more informative:

         if self.use_transformer_engine_full_layer_spec:
-            raise ValueError("use_transformer_engine_full_layer_spec is not supported in this config.")
+            raise NotImplementedError(
+                "use_transformer_engine_full_layer_spec is not yet supported for Evo2 GPT models. "
+                "Please use the standard transformer layer spec instead."
+            )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 4a47964 and 52dacd6.

📒 Files selected for processing (5)
  • sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py (1 hunks)
  • sub-packages/bionemo-evo2/src/bionemo/evo2/models/mamba.py (2 hunks)
  • sub-packages/bionemo-evo2/src/bionemo/evo2/run/infer.py (1 hunks)
  • sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py (5 hunks)
  • sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py (12 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py (3)
sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/lightning.py (1)
  • get_packed_seq_params (227-262)
sub-packages/bionemo-evo2/src/bionemo/evo2/utils/loss/embedding_variance.py (1)
  • SquaredErrorTargetedVarianceLoss (169-192)
sub-packages/bionemo-evo2/src/bionemo/evo2/models/mamba.py (4)
  • get_inference_wrapper (81-113)
  • forward (116-146)
  • forward (171-275)
  • configure_model (368-395)
sub-packages/bionemo-evo2/src/bionemo/evo2/models/mamba.py (1)
sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/lightning.py (1)
  • get_packed_seq_params (227-262)
sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py (2)
sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py (2)
  • Evo2GPTModel (83-153)
  • gpt_no_weight_decay_cond_with_embeddings (224-229)
sub-packages/bionemo-evo2/src/bionemo/evo2/models/mamba.py (2)
  • mamba_no_weight_decay_cond_with_embeddings (301-306)
  • MambaModel (74-146)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-bionemo-image
  • GitHub Check: Analyze (rust)
🔇 Additional comments (10)
sub-packages/bionemo-evo2/src/bionemo/evo2/run/infer.py (2)

88-98: LGTM! CLI improvements enhance usability.

The change from requiring explicit boolean values to using store_true makes the CLI more intuitive and consistent with standard argparse conventions.
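
As a minimal illustration of the store_true style (the flag names match the PR; the parser wiring here is illustrative only):

import argparse

parser = argparse.ArgumentParser()
# With store_true, the flag's presence enables the feature; no explicit True/False value is passed.
parser.add_argument("--fp8", action="store_true", help="Enable FP8 inference.")
parser.add_argument("--flash-decode", action="store_true", help="Enable flash decoding.")

args = parser.parse_args(["--fp8"])
assert args.fp8 is True and args.flash_decode is False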


174-175: Verify that dynamic threshold calculations don't cause memory issues.

The dynamic calculation of inference_batch_times_seqlen_threshold and inference_max_seq_length based on len(prompt) + max_new_tokens is more flexible than the previous fixed value of 8192. However, this could potentially lead to very large memory allocations for long prompts or large max_new_tokens values.

Are there any safeguards in the underlying inference wrapper to prevent OOM errors when these thresholds are set too high? Consider adding a maximum cap or warning when the calculated value exceeds a reasonable threshold (e.g., 32768).
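
As a rough sketch of such a safeguard (the cap constant and helper name are assumptions, not part of the PR), the dynamic sizing could be clamped before it is handed to the inference wrapper:

import logging

logger = logging.getLogger(__name__)

# Assumed upper bound; tune to the target GPU memory budget.
MAX_INFERENCE_SEQ_LENGTH = 32_768


def clamped_inference_length(prompt: str, max_new_tokens: int) -> int:
    """Return len(prompt) + max_new_tokens, clamped to a configurable cap."""
    requested = len(prompt) + max_new_tokens
    if requested > MAX_INFERENCE_SEQ_LENGTH:
        logger.warning(
            "Requested inference length %d exceeds cap %d; clamping to avoid OOM.",
            requested,
            MAX_INFERENCE_SEQ_LENGTH,
        )
        return MAX_INFERENCE_SEQ_LENGTH
    return requested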

sub-packages/bionemo-evo2/src/bionemo/evo2/models/mamba.py (1)

68-69: LGTM! Proper handling of packed sequences.

The addition of packed sequence support through get_packed_seq_params is well-implemented and follows the established pattern from other parts of the codebase.

sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py (1)

630-641: LGTM! Well-structured embedding initialization logic.

The handling of embedding initialization parameters is comprehensive with clear precedence rules and appropriate warnings for conflicting options.

sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py (6)

53-81: LGTM! Clean forward step implementation.

The forward step function properly handles batch processing with optional attention masks and packed sequences. The TE requirement check for missing attention masks is appropriate.


163-176: Good implementation of targeted variance loss initialization.

The conditional initialization of the targeted variance loss module with proper warning for missing config attributes is well implemented. The use of the Spike No More paper's suggested value (1.0) as default is appropriate.


177-205: Well-structured forward pass with proper loss handling.

The forward method correctly handles different execution paths (inference vs training, pipeline parallelism) and properly applies the reweighted cross-entropy loss and optional variance loss only during training.


265-279: Good backward compatibility handling in post-init.

The post-init method properly handles backward compatibility for the embedding_init_method_std attribute with clear warning messages and sensible defaults.


280-391: Comprehensive model configuration with proper RoPE scaling.

The configure_model method is well-implemented with:

  • Proper CUDA graph validation
  • Virtual pipeline parallelism handling
  • Flexible transformer layer spec resolution
  • Correct RoPE scaling application for Llama3.1

393-397: Model registry properly configured.

The GPT_MODEL_OPTIONS dictionary correctly registers the Llama3 8B configuration for use in the model selection system.

Comment on lines +252 to +253
hyena_no_weight_decay_cond_fn: Callable = gpt_no_weight_decay_cond # TODO rename to something more general
forward_step_fn: Callable = evo2_gpt_forward_step

🛠️ Refactor suggestion

Rename the misleadingly named configuration field.

The field hyena_no_weight_decay_cond_fn should be renamed to something more general as indicated by the TODO comment, since this is not specific to Hyena models.

Consider renaming this field to no_weight_decay_cond_fn or weight_decay_cond_fn for clarity:

-    hyena_no_weight_decay_cond_fn: Callable = gpt_no_weight_decay_cond  # TODO rename to something more general
+    no_weight_decay_cond_fn: Callable = gpt_no_weight_decay_cond
🤖 Prompt for AI Agents
In sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py around lines
252-253, the config field named hyena_no_weight_decay_cond_fn is misleadingly
specific to Hyena per the TODO; rename it to a more general name such as
no_weight_decay_cond_fn (or weight_decay_cond_fn) and update all references and
type annotations accordingly to preserve the same Callable assignment
(gpt_no_weight_decay_cond) and behavior; ensure imports, dataclass/constructor
defaults, serialization keys, and any external usages (attribute access, tests)
are updated to the new name to avoid breaking runtime code.

Comment on lines +431 to +439
elif model_type == "gpt":
    if model_size not in GPT_MODEL_OPTIONS:
        raise ValueError(f"Invalid model size for GPT: {model_size}")
    config = GPT_MODEL_OPTIONS[model_size](
        forward_step_fn=hyena_predict_forward_step,
        data_step_fn=hyena_predict_data_step,
    )
else:
    raise ValueError(f"Invalid model type: {model_type}")

💡 Verification agent

🧩 Analysis chain

Verify GPT model reuses Hyena's forward and data step functions correctly.

The GPT model configuration reuses hyena_predict_forward_step and hyena_predict_data_step. While code reuse is good, ensure that these functions are fully compatible with GPT models, particularly regarding attention masks and packed sequences.


🏁 Script executed:

#!/bin/bash
# Verify that GPT models handle the same forward arguments as Hyena models
ast-grep --pattern 'class $GPT_CLASS {
  $$$
  def forward($$$) {
    $$$
  }
  $$$
}'

# Check if there are any GPT-specific forward step requirements
rg -n "gpt.*forward.*step" --type=py

Length of output: 359


🏁 Script executed:

#!/bin/bash
# Locate hyena_predict_forward_step definition and signature
rg -n -C3 "def hyena_predict_forward_step" -type py

# Locate hyena_predict_data_step definition and signature
rg -n -C3 "def hyena_predict_data_step" -type py

# Preview the GPT forward step implementation for comparison
sed -n '1,100p' sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py

Length of output: 4204


🏁 Script executed:

#!/bin/bash
# Locate hyena_predict_forward_step definition and signature
rg -n -C3 -t py 'def hyena_predict_forward_step' sub-packages/bionemo-evo2/src

# Locate hyena_predict_data_step definition and signature
rg -n -C3 -t py 'def hyena_predict_data_step' sub-packages/bionemo-evo2/src

Length of output: 1438


Use evo2_gpt_forward_step instead of hyena_predict_forward_step in the GPT config
Hyena’s forward_step unconditionally sets attention_mask=None, ignoring batch masks (docs.nvidia.com).

🤖 Prompt for AI Agents
In sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py around lines 431 to
439, the GPT config currently uses hyena_predict_forward_step which
unconditionally clears attention_mask; replace hyena_predict_forward_step with
evo2_gpt_forward_step when constructing the GPT config (keep data_step_fn as
hyena_predict_data_step unless otherwise required) and ensure
evo2_gpt_forward_step is imported at the top of the file or referenced correctly
so the GPT model respects batch masks.
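
Concretely, the fix described above would look roughly like this; it is a sketch that assumes evo2_gpt_forward_step and GPT_MODEL_OPTIONS are importable from bionemo.evo2.models.gpt, as the walkthrough indicates:

from bionemo.evo2.models.gpt import GPT_MODEL_OPTIONS, evo2_gpt_forward_step


def build_gpt_predict_config(model_size: str, data_step_fn):
    """Sketch: build the GPT predict config with the GPT-aware forward step."""
    if model_size not in GPT_MODEL_OPTIONS:
        raise ValueError(f"Invalid model size for GPT: {model_size}")
    # Keep the Hyena data step, but swap in evo2_gpt_forward_step so the model
    # respects attention_mask and packed sequences instead of discarding them.
    return GPT_MODEL_OPTIONS[model_size](
        forward_step_fn=evo2_gpt_forward_step,
        data_step_fn=data_step_fn,
    )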

Signed-off-by: Yang Zhang <[email protected]>
@dorotat-nv (Collaborator) commented Sep 3, 2025

@yzhang123 The docs point toward Evo2 Mamba when this is Evo2 GPT. What does Evo2GPT mean? Is it GPT with an Evo2-style data input? Can we add a description?

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py (1)

244-247: Rename hyena_no_weight_decay_cond_fn to a model-agnostic name.

Field name is misleading per the TODO; make it generic.

-    hyena_no_weight_decay_cond_fn: Callable = gpt_no_weight_decay_cond  # TODO rename to something more general
+    no_weight_decay_cond_fn: Callable = gpt_no_weight_decay_cond

Remember to update downstream attribute accesses.

🧹 Nitpick comments (6)
sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py (6)

53-62: Docstring still says “Mamba”; update to GPT.

This helper is for GPT, not Mamba.

-def evo2_gpt_forward_step(model, batch) -> torch.Tensor:
-    """Forward step function for Mamba models, similar to hyena_forward_step.
+def evo2_gpt_forward_step(model, batch) -> torch.Tensor:
+    """Forward step function for GPT models, similar to hyena_forward_step.
@@
-        model: The Mamba model
+        model: The GPT model

69-76: Attention mask/TE requirement: verify dataloaders or relax hard assert.

Hard-failing when TE isn’t present can break non-TE environments. Either ensure all dataloaders always provide attention_mask without TE, or gate with a clearer runtime error + fallback.

Proposed tweak:

-    if "attention_mask" not in batch:
-        assert HAVE_TE, (
-            "The dataloader did not provide an attention mask, however Transformer Engine was not detected. \
-            This requires Transformer Engine's implementation of fused or flash attention."
-        )
+    if "attention_mask" not in batch:
+        if not HAVE_TE:
+            raise RuntimeError(
+                "attention_mask missing and Transformer Engine not available. "
+                "Provide attention_mask via dataloader or enable TE (fused/flash attention)."
+            )

200-214: Comment still references “Mamba-specific parameters”.

This isn’t Mamba-only.

-    # Mamba-specific parameters that should not have weight decay
+    # Parameters that should not have weight decay

225-233: Class name vs size mismatch (3B vs 8B).

Class is named LLama31ConfigEvoLoss3B but inherits Llama3Config8B and uses 8B shapes; docstring also says 8B. Rename for clarity and consistency.

-@dataclass
-class LLama31ConfigEvoLoss3B(llm.Llama3Config8B):
-    """Config for 8B hybrid GPT model."""
+@dataclass
+class Llama31ConfigEvoLoss8B(llm.Llama3Config8B):
+    """Config for Llama 3.1 8B GPT model with Evo2-style loss."""

Update references:

-GPT_MODEL_OPTIONS: dict[str, type[Llama3Config]] = {
-    "llama3_8b": LLama31ConfigEvoLoss3B,
-}
+GPT_MODEL_OPTIONS: dict[str, type[Llama3Config]] = {
+    "llama3_8b": Llama31ConfigEvoLoss8B,
+}

371-383: Unsupported full TE layer spec path: consider a clearer error message.

The raised error includes no remediation hints; add guidance (e.g., set use_transformer_engine_full_layer_spec=False) to shorten iteration time.

-            raise ValueError("use_transformer_engine_full_layer_spec is not supported in this config.")
+            raise ValueError(
+                "use_transformer_engine_full_layer_spec is not supported in this config. "
+                "Set use_transformer_engine_full_layer_spec=False or select a compatible config."
+            )

50-50: Unused variable _grad_accum_fusion_available.

It’s defined but not used; drop it or wire it into feature gating.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro


📥 Commits

Reviewing files that changed from the base of the PR and between 52dacd6 and 47f3022.

📒 Files selected for processing (1)
  • sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py (3)
sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/lightning.py (1)
  • get_packed_seq_params (227-262)
sub-packages/bionemo-evo2/src/bionemo/evo2/utils/loss/embedding_variance.py (1)
  • SquaredErrorTargetedVarianceLoss (169-192)
sub-packages/bionemo-evo2/src/bionemo/evo2/models/mamba.py (3)
  • get_inference_wrapper (81-113)
  • forward (116-146)
  • forward (171-275)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Analyze (rust)
🔇 Additional comments (1)
sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py (1)

170-197: Move uppercase transform before loss computation and verify loss shape

  • Apply make_upper_case (and capture lowercase_mask) before calling super().forward so the loss is computed against the transformed labels.
  • Manually confirm that Evo2StyleMCoreGPTModel.forward returns an unreduced per-token loss tensor; if it returns a scalar, reweighted_cross_entropy will misbehave.
     @override
     def forward(self, *args, labels: torch.Tensor | None = None, loss_mask: torch.Tensor | None = None, **kwargs):
-        _forward_out = super().forward(*args, labels=labels, loss_mask=loss_mask, **kwargs)
+        lowercase_mask = None
+        if labels is not None:
+            labels, lowercase_mask = make_upper_case(labels)
+        _forward_out = super().forward(*args, labels=labels, loss_mask=loss_mask, **kwargs)

Comment on lines +338 to +346
model_init_device_context = contextlib.nullcontext
if self.init_model_with_meta_device:
    model_init_device_context = partial(torch.device, device="meta")

if "mtp_block_spec" in inspect.signature(Evo2StyleMCoreGPTModel.__init__).parameters:
    kwargs = {"mtp_block_spec": mtp_block_spec(self, vp_stage=vp_stage)}
else:
    kwargs = {}
with model_init_device_context():

💡 Verification agent

🧩 Analysis chain

Meta-device init context is incorrect; torch.device is not a context manager.

with model_init_device_context(): will fail when using partial(torch.device, device="meta"). Use torch.set_default_device("meta") which is a context manager in PyTorch 2.x.

-        model_init_device_context = contextlib.nullcontext
-        if self.init_model_with_meta_device:
-            model_init_device_context = partial(torch.device, device="meta")
+        model_init_device_context = contextlib.nullcontext
+        if self.init_model_with_meta_device:
+            # Set default device to 'meta' during construction.
+            model_init_device_context = partial(torch.set_default_device, "meta")

If your minimum PyTorch version doesn’t support set_default_device, consider accelerate.init_empty_weights() or Megatron’s meta-init utility instead.


🏁 Script executed:

#!/bin/bash
# Locate project requirement and configuration files
fd pyproject.toml setup.py requirements.txt environment.yml -t f

# Search for PyTorch version specifications
grep -R "torch" -n pyproject.toml setup.py requirements.txt environment.yml || true

Length of output: 549


Use torch.set_default_device instead of torch.device as context manager
In sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py (lines 338–346), replace:

- model_init_device_context = contextlib.nullcontext
- if self.init_model_with_meta_device:
-     model_init_device_context = partial(torch.device, device="meta")
+ model_init_device_context = contextlib.nullcontext
+ if self.init_model_with_meta_device:
+     model_init_device_context = partial(torch.set_default_device, "meta")

torch.device isn’t a context manager; PyTorch 2.3.* supports torch.set_default_device.


@codecov-commenter commented Sep 3, 2025

Codecov Report

❌ Patch coverage is 34.27230% with 140 lines in your changes missing coverage. Please review.
✅ Project coverage is 79.78%. Comparing base (a29272f) to head (47f3022).
⚠️ Report is 4 commits behind head on main.
✅ All tests successful. No failed tests found.

Files with missing lines Patch % Lines
...ckages/bionemo-evo2/src/bionemo/evo2/models/gpt.py 41.17% 90 Missing ⚠️
...ackages/bionemo-evo2/src/bionemo/evo2/run/train.py 7.14% 39 Missing ⚠️
...kages/bionemo-evo2/src/bionemo/evo2/run/predict.py 10.00% 9 Missing ⚠️
...ages/bionemo-evo2/src/bionemo/evo2/models/mamba.py 75.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1106      +/-   ##
==========================================
- Coverage   80.62%   79.78%   -0.84%     
==========================================
  Files         157      158       +1     
  Lines       11079    11283     +204     
==========================================
+ Hits         8932     9002      +70     
- Misses       2147     2281     +134     
Files with missing lines Coverage Δ
...ackages/bionemo-evo2/src/bionemo/evo2/run/infer.py 54.90% <ø> (ø)
...ages/bionemo-evo2/src/bionemo/evo2/models/mamba.py 77.33% <75.00%> (-0.45%) ⬇️
...kages/bionemo-evo2/src/bionemo/evo2/run/predict.py 61.68% <10.00%> (-2.54%) ⬇️
...ackages/bionemo-evo2/src/bionemo/evo2/run/train.py 13.62% <7.14%> (-0.73%) ⬇️
...ckages/bionemo-evo2/src/bionemo/evo2/models/gpt.py 41.17% <41.17%> (ø)

... and 1 file with indirect coverage changes

@yzhang123 (Collaborator, Author):

Closing in favor of #1142 and #1109.

@yzhang123 yzhang123 closed this Sep 10, 2025
