
Conversation

@yzhang123 (Collaborator) commented Sep 10, 2025

Description

enable evo2 predict.py for llama

Type of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Refactor
  • Documentation update
  • Other (please describe):

CI Pipeline Configuration

Configure CI behavior by applying the relevant labels:

Note

By default, the notebooks validation tests are skipped unless explicitly enabled.

Authorizing CI Runs

We use copy-pr-bot to manage authorization of CI
runs on NVIDIA's compute resources.

  • If a pull request is opened by a trusted user and contains only trusted changes, the pull request's code will
    automatically be copied to a pull-request/ prefixed branch in the source repository (e.g. pull-request/123)
  • If a pull request is opened by an untrusted user or contains untrusted changes, an NVIDIA org member must leave an
    /ok to test comment on the pull request to trigger CI. This will need to be done for each new commit.

Usage

python predict.py --fasta [FILE] --ckpt-dir [Llama ckpt] --output-dir /tmp/ --model-size 8B --model-type llama 
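
To sanity-check the results, the per-rank prediction files written to --output-dir can be loaded with torch.load. The snippet below is a minimal sketch based on the output layout exercised by the tests discussed further down (predictions__rank_*.pt files containing token_logits, pad_mask, and seq_idx entries); the exact keys, shapes, and output path are assumptions of this example, not a documented API.

from pathlib import Path

import torch

# Predictions are written per rank into the directory passed via --output-dir.
output_dir = Path("/tmp")  # assumption: matches the --output-dir used in the command above
pred_files = sorted(output_dir.glob("predictions__rank_*.pt"))

preds = torch.load(pred_files[0])
# token_logits: one (sequence_length, vocab_size) tensor per input sequence.
# pad_mask:     marks which positions hold real tokens rather than padding.
# seq_idx:      maps each prediction back to its sequence in the input FASTA.
for idx in range(len(preds["seq_idx"])):
    print(idx, preds["token_logits"][idx].shape, int(preds["pad_mask"][idx].sum()))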

Pre-submit Checklist

  • I have tested these changes locally
  • I have updated the documentation accordingly
  • I have added/updated tests as needed
  • All existing tests pass successfully

Summary by CodeRabbit

  • New Features

    • Support for running predictions with Llama models.
    • New Mamba-backed prediction mode with optional per-token log probabilities and configurable aggregation (sum, mean, or per-token).
  • Refactor

    • Model-selection logic clarified to explicitly handle multiple model types.
    • CLI: output directory is now required; help text updated.
  • Tests

    • End-to-end test: train a small Llama model, run predictions, and validate per-sequence outputs.

coderabbitai bot (Contributor) commented Sep 10, 2025

Caution

Review failed

The head commit changed during the review from 0a9a039 to a39524d.

Walkthrough

Adds Llama support to the Evo2 prediction CLI using LLAMA_MODEL_OPTIONS, makes --output_dir required, introduces a new MambaPredictor with per-token log-prob handling and predict_step, and refactors branching into explicit hyena/mamba/llama paths; Llama inference reuses HyenaPredictor configuration.

Changes

Cohort / File(s) Summary of edits
Prediction CLI & branching
sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py
- Imported LLAMA_MODEL_OPTIONS, added "llama" to model_type choices and extended model-size options.
- Made --output_dir required and updated help text.
- Replaced generic else-branching with explicit elif branches for "mamba" and "llama".
Mamba predictor implementation
sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py
- Added public class MambaPredictor(MambaModel, LightningPassthroughPredictionMixin) with __init__, forward, and predict_step.
- forward computes per-token log probabilities and supports collapse modes (sum, mean, per_token).
- predict_step performs inference with context/tensor-parallel gathering, caches results, and returns dict with predictions, log_probabilities, and tokens.
Llama path using HyenaPredictor
sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py
- For model_type == "llama", validate model_size via LLAMA_MODEL_OPTIONS, construct a config using Hyena-style forward/data steps, and instantiate HyenaPredictor for inference.
Tests: full train-then-predict workflow
sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_predict.py
- Added small_training_llama_cmd helper and test_predict_evo2_llama_runs which trains a small Llama model, locates its checkpoint, runs predict on a generated FASTA, verifies outputs exist, and asserts shapes/contents of token_logits, pad_mask, and seq_idx per sequence.
- Adjusted test_predict_evo2_runs signature to accept num_sequences and target_sequence_lengths.
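
The dispatch described in this table reduces to a three-way branch over the model-option dictionaries. The sketch below is illustrative only, assuming the HYENA_MODEL_OPTIONS / MAMBA_MODEL_OPTIONS / LLAMA_MODEL_OPTIONS dictionaries and the predictor classes named in the walkthrough; the real predict.py additionally wires in the Hyena-style forward/data step functions, tokenizer, and checkpoint loading.

from typing import Mapping


def select_predictor(model_type: str, model_size: str,
                     hyena_options: Mapping, mamba_options: Mapping, llama_options: Mapping):
    """Illustrative dispatch: pick the option dict and predictor kind for a model type/size."""
    if model_type == "hyena":
        options, predictor_kind = hyena_options, "HyenaPredictor"
    elif model_type == "mamba":
        options, predictor_kind = mamba_options, "MambaPredictor"
    elif model_type == "llama":
        # Llama validates its size against LLAMA_MODEL_OPTIONS but reuses HyenaPredictor for inference.
        options, predictor_kind = llama_options, "HyenaPredictor"
    else:
        raise ValueError(f"Unsupported model type: {model_type!r}")
    if model_size not in options:
        raise ValueError(f"Invalid model size {model_size!r} for model type {model_type!r}")
    return options[model_size], predictor_kind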

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor User
  participant CLI as predict.py (CLI)
  participant Config as Model Config Builder
  participant Predictor as Predictor (Hyena / Mamba)
  participant Trainer as Lightning Trainer
  Note over CLI,Config: User supplies --model_type, --model_size, --output_dir

  User->>CLI: invoke predict
  CLI->>Config: validate model_type & size
  alt model_type == "hyena"
    Config-->>CLI: hyena config
    CLI->>Predictor: instantiate HyenaPredictor
  else model_type == "mamba"
    Config-->>CLI: mamba config
    CLI->>Predictor: instantiate MambaPredictor
  else model_type == "llama"
    Config-->>CLI: LLAMA_MODEL_OPTIONS -> hyena-style config
    CLI->>Predictor: instantiate HyenaPredictor (llama config)
  end

  CLI->>Trainer: trainer.predict(Predictor, dataloaders)
  loop per batch
    Trainer->>Predictor: predict_step(batch)
    Predictor-->>Trainer: {predictions, log_probabilities, tokens}
  end
  Trainer-->>User: save aggregated outputs to output_dir
sequenceDiagram
  autonumber
  participant M as MambaPredictor
  participant Gather as Parallel Gather
  participant Cache as Prediction Cache

  M->>M: forward(input_ids, position_ids, ...)
  alt output_log_prob_seqs enabled
    M->>M: compute per-token log_probs
    alt collapse == "sum" / "mean" / "per_token"
      M->>M: collapse as requested
    end
  end
  M->>Gather: emit logits/tokens/log_probs
  Gather-->>M: gathered tensors
  M->>Cache: store predictions/log_probs/tokens
  M-->>Caller: return dict(predictions, log_probabilities, tokens)
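In plain code, the per-token log-probability handling sketched in the diagram amounts to a log-softmax over the vocabulary, gathered at the observed next tokens, and then reduced according to the collapse option. The snippet below is a hedged illustration of that technique, not the code in predict.py; in particular, the actual token/logit alignment there is handled upstream in forward_step, as one of the review comments below notes.

import torch
import torch.nn.functional as F


def collapse_log_probs(logits: torch.Tensor, tokens: torch.Tensor, collapse: str = "mean") -> torch.Tensor:
    """Per-token log-probs with sum / mean / per_token collapse (illustrative only)."""
    # logits: (batch, seq_len, vocab); tokens: (batch, seq_len) integer token ids.
    # Shift so that the logits at position i score the token observed at position i + 1.
    log_probs = F.log_softmax(logits[:, :-1].float(), dim=-1)
    token_log_probs = log_probs.gather(-1, tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
    if collapse == "sum":
        return token_log_probs.sum(dim=-1)
    if collapse == "mean":
        return token_log_probs.mean(dim=-1)
    return token_log_probs  # "per_token": one log-prob per predicted position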

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Poem

I hop on keys with whiskers twitching bright,
Mamba counts each token through the night.
Llama borrows Hyena's steady stride,
Log-probs folded, summed, or left aside.
One burrow, three models—predictions take flight. 🐇✨

Pre-merge checks (1 passed, 2 warnings)

❌ Failed checks (2 warnings)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 66.67%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve coverage.
  • Description Check (⚠️ Warning): The PR description is too brief and does not follow the repository template's requirement for a detailed "Description" section: it only states "enable evo2 predict.py for llama", while the template expects a clear summary of the code changes, their rationale, and their impact. The Type of changes and Usage example are present and CI guidance appears, but the description omits file-level changes, new public APIs/classes, test details, documentation/migration notes, and confirmation of test outcomes. Resolution: expand the Description to list the key code changes (files modified/added, new classes such as MambaPredictor, and the llama prediction path), explain behavioral or API impacts and compatibility considerations, and summarize the tests added/updated and their results; update the Pre-submit Checklist to reflect documentation and test status, apply any CI labels needed to control pipeline behavior, and re-run the pre-merge checks.

✅ Passed checks (1 passed)

  • Title Check (✅ Passed): The title clearly and concisely describes that predict.py now supports Llama models, which is the central change of this PR.

@yzhang123 mentioned this pull request Sep 10, 2025
Signed-off-by: Yang Zhang <[email protected]>
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (5)
sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py (5)

520-627: MambaPredictor class placement needs adjustment.

The MambaPredictor class is defined at the end of the file (after line 520), but it's referenced and used earlier in the code (line 466). This will cause a NameError at runtime.

Move the MambaPredictor class definition to be before its first usage. Apply this restructuring:

-# Create MambaPredictor class, similar to HyenaPredictor
-class MambaPredictor(MambaModel, LightningPassthroughPredictionMixin):
+# Move this entire class definition to line 288 (after PredictDataModule class)
+class MambaPredictor(MambaModel, LightningPassthroughPredictionMixin):

The class should be placed after the PredictDataModule class definition (around line 288) and before the predict function.


577-581: Fix incorrect conditional logic in log probability calculation.

The else branch on line 580 assumes log_prob_collapse_option is "mean", but it should also handle "per_token" option.

Apply this diff to fix the logic:

-                if self.log_prob_collapse_option == "sum":
-                    sequence_log_prob = sum(token_log_probs)
-                else:  # mean
-                    sequence_log_prob = sum(token_log_probs) / len(token_log_probs) if token_log_probs else 0
+                if self.log_prob_collapse_option == "sum":
+                    sequence_log_prob = sum(token_log_probs)
+                elif self.log_prob_collapse_option == "mean":
+                    sequence_log_prob = sum(token_log_probs) / len(token_log_probs) if token_log_probs else 0
+                else:  # per_token
+                    sequence_log_prob = token_log_probs

543-546: Unused instance variables in MambaPredictor.

The instance variables self.predictions, self.log_probabilities, and self.tokens are initialized but never cleared between batches, which could lead to accumulating results across multiple predict calls.

Consider either:

  1. Clearing these lists at the beginning of each predict_step
  2. Using local variables instead of instance variables

Apply this diff for option 2 (preferred):

-        # Storage for the predictions
-        self.predictions = []
-        self.log_probabilities = []
-        self.tokens = []

Then update the predict_step method to use local variables.


589-626: Inconsistent return format in MambaPredictor.predict_step.

The MambaPredictor returns a different data structure compared to HyenaPredictor. HyenaPredictor returns log_probs_seqs and seq_idx, while MambaPredictor returns predictions, log_probabilities, and tokens. This inconsistency could cause issues with the PredictionWriter callback.

Align the return format with HyenaPredictor for consistency:

-        # Return as dict for PredictionWriter callback
-        return {"predictions": self.predictions, "log_probabilities": self.log_probabilities, "tokens": self.tokens}
+        # Return as dict for PredictionWriter callback (matching HyenaPredictor format)
+        if self.output_log_prob_seqs:
+            return {"log_probs_seqs": self.log_probabilities, "seq_idx": batch.get("seq_idx", torch.arange(len(batch["tokens"]))).cpu()}
+        else:
+            return {"token_logits": output.cpu(), "pad_mask": batch.get("loss_mask", torch.ones_like(batch["tokens"])).cpu(), "seq_idx": batch.get("seq_idx", torch.arange(len(batch["tokens"]))).cpu()}

592-594: Empty batch handling creates incorrect data structure.

The empty batch handling creates a batch with empty lists, which doesn't match the expected tensor format.

Apply this diff to properly handle empty batches:

-        if batch == {}:
-            batch = {"tokens": [], "position_ids": [], "attention_mask": []}
+        if not batch or "tokens" not in batch:
+            return {}
🧹 Nitpick comments (1)
sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py (1)

85-87: Consider deduplicating model size choices.

The current approach concatenates keys from all three model option dictionaries, which could lead to duplicate entries if models share size options (e.g., "7b" might be in multiple dictionaries).

Apply this diff to ensure unique choices:

-        choices=sorted(
-            list(HYENA_MODEL_OPTIONS.keys()) + list(MAMBA_MODEL_OPTIONS.keys()) + list(LLAMA_MODEL_OPTIONS.keys())
-        ),
+        choices=sorted(
+            set(list(HYENA_MODEL_OPTIONS.keys()) + list(MAMBA_MODEL_OPTIONS.keys()) + list(LLAMA_MODEL_OPTIONS.keys()))
+        ),
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 664a9b9 and a39524d.

📒 Files selected for processing (1)
  • sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py (5 hunks)
🔇 Additional comments (9)
sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py (9)

40-40: LGTM! Clean import addition for Llama support.

The import is correctly placed alongside other model imports and follows the existing naming convention.


77-79: LGTM! Model type choices properly extended.

The choices are correctly expanded to include "llama" and the help text is appropriately updated.


94-95: LGTM! Output directory is now appropriately required.

Making the output directory required ensures that prediction results have a proper destination, preventing potential runtime errors.


422-430: LGTM! Consistent implementation for Mamba model type.

The explicit branching with elif improves code clarity and the configuration setup correctly reuses the Hyena forward and data step functions.


431-438: LGTM! Llama configuration properly integrated.

The Llama model path correctly validates against LLAMA_MODEL_OPTIONS and reuses the Hyena infrastructure for forward and data steps, which is appropriate given the architectural similarities.


465-471: LGTM! Model instantiation handles all types correctly.

The Mamba branch correctly instantiates MambaPredictor with appropriate parameters.


472-478: LGTM! Llama correctly uses HyenaPredictor.

The implementation correctly reuses HyenaPredictor for Llama models, which aligns with the architectural similarities mentioned in the AI summary.


568-575: Remove false positive offset mismatch warning. The token-to-logit alignment is handled upstream in forward_step, which prepends an EOS token and drops the first output, as noted in the comments at lines 181–182, ensuring log_probs[j] correctly corresponds to token j.

Likely an incorrect or invalid review comment.


405-409: No action needed for Llama models
The seq_len_interpolation_factor override is specific to the Hyena “-1m” convolutional mixer and isn’t used or supported by the NeMo Llama configs (which rely on rotary embeddings for longer contexts).

@yzhang123 yzhang123 enabled auto-merge September 10, 2025 16:11
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_predict.py (1)

207-207: Fix end-of-file newline to satisfy pre-commit

CI shows “End-of-file fixer modified file.” Add a trailing newline and re-run pre-commit locally.

Run:

#!/bin/bash
pre-commit run -a
git add -A && git commit -m "Fix EOF newline and pre-commit issues"
🧹 Nitpick comments (6)
sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_predict.py (6)

35-45: Deduplicate helper and remove redundant CLI arg

  • This helper duplicates small_training_llama_cmd from tests/bionemo/evo2/run/test_train.py. Prefer a shared test utility to avoid drift.
  • "--limit-val-batches 1" is passed twice.

Apply to remove the duplicate flag:

 def small_training_llama_cmd(path, max_steps, val_check, devices: int = 1, additional_args: str = ""):
     cmd = (
         f"train_evo2 --no-fp32-residual-connection --mock-data --result-dir {path} --devices {devices} "
-        "--model-size 8B --num-layers 2 --limit-val-batches 1 "
+        "--model-size 8B --num-layers 2 "
         "--no-activation-checkpointing --create-tensorboard-logger --create-tflops-callback "
-        f"--max-steps {max_steps} --warmup-steps 1 --val-check-interval {val_check} --limit-val-batches 1 "
+        f"--max-steps {max_steps} --warmup-steps 1 --val-check-interval {val_check} --limit-val-batches 1 "
         f"--seq-length 8 --hidden-dropout 0.1 --attention-dropout 0.1 {additional_args}"
     )
     return cmd

If you’d like, I can move this helper into bionemo.testing (or a conftest.py fixture) and update both tests.


122-126: Timeout is reasonable, but guard CI variability

512s is generous; fine. If CI is occasionally slower, consider parametrizing via env (e.g., BIONEMO_TEST_TIMEOUT) with a default, to avoid editing code later.


136-136: Avoid brittle stdout assertion about restore path

Training logs can legitimately include “Restoring …” for various reasons; asserting its absence risks flakes. Prefer removing this check or converting it to a soft warning.

Apply:

-    assert "Restoring model weights from RestoreConfig(path='" not in stdout_pretrain
+    # Intentionally not asserting on restore messages to avoid log-format flakes.

143-151: Make checkpoint discovery robust when multiple matches exist

Asserting exactly one match is brittle. Select the most recent matching checkpoint instead.

-    assert matching_subfolders, (
-        f"No checkpoint subfolder ending with '{expected_checkpoint_suffix}' found in {checkpoints_dir}."
-    )
-    assert len(matching_subfolders) == 1, "Only one checkpoint subfolder should be found."
-    checkpoint_path = matching_subfolders[0]
+    assert matching_subfolders, (
+        f"No checkpoint subfolder ending with '{expected_checkpoint_suffix}' found in {checkpoints_dir}."
+    )
+    checkpoint_path = max(matching_subfolders, key=lambda p: p.stat().st_mtime)

192-205: Don’t hardcode vocab dimension; derive from predictions

Hardcoding 512 will break if LLAMA_MODEL_OPTIONS changes tokenizer/vocab. Read it from the tensor.

-    preds = torch.load(pred_files[0])
+    preds = torch.load(pred_files[0])
+    vocab_size = preds["token_logits"][0].shape[1]
@@
-        assert preds["token_logits"][idx].shape == (max(target_sequence_lengths), 512)
+        assert preds["token_logits"][idx].shape == (max(target_sequence_lengths), vocab_size)

186-187: Pathlib over os.path/glob for consistency

Since you already use Path, prefer output_dir.glob("predictions__rank_*.pt") for readability.

-    pred_files = glob.glob(os.path.join(output_dir, "predictions__rank_*.pt"))
+    pred_files = list(output_dir.glob("predictions__rank_*.pt"))
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a39524d and 342d8be.

📒 Files selected for processing (1)
  • sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_predict.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_predict.py (4)
sub-packages/bionemo-noodles/src/bionemo/noodles/nvfaidx.py (2)
  • NvFaidx (108-253)
  • keys (228-229)
sub-packages/bionemo-testing/src/bionemo/testing/data/fasta.py (1)
  • create_fasta_file (28-63)
sub-packages/bionemo-testing/src/bionemo/testing/subprocess_utils.py (1)
  • run_command_in_subprocess (108-129)
sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_train.py (1)
  • small_training_llama_cmd (99-107)
🪛 GitHub Actions: BioNeMo Framework CI
sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_predict.py

[error] 1-1: End-of-file fixer modified file. Pre-commit hook failed (exit code 1).

🔇 Additional comments (2)
sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_predict.py (2)

32-32: LGTM: Good reuse of subprocess helper

Using run_command_in_subprocess for training is appropriate here. Consider reusing it for prediction too (see below).


25-25: pytest-timeout already declared in project
The pytest-timeout plugin (v2.2.0) is listed in requirements-test.txt, so the timeout marker is available.

Comment on lines 168 to 175
    # Run the command in a subshell, using the temporary directory as the current working directory.
    result = subprocess.run(
        command,
        shell=True,  # Use the shell to interpret wildcards (e.g. SDH*)
        cwd=tmp_path,  # Run in the temporary directory
        capture_output=True,  # Capture stdout and stderr for debugging
        text=True,  # Decode output as text
    )

🛠️ Refactor suggestion

Use the subprocess helper for prediction to set MASTER_PORT and reduce flakes

predict_evo2 will initialize distributed backends; without a unique MASTER_PORT this can collide in CI. Reuse run_command_in_subprocess like you did for training.

-    # Run the command in a subshell, using the temporary directory as the current working directory.
-    result = subprocess.run(
-        command,
-        shell=True,  # Use the shell to interpret wildcards (e.g. SDH*)
-        cwd=tmp_path,  # Run in the temporary directory
-        capture_output=True,  # Capture stdout and stderr for debugging
-        text=True,  # Decode output as text
-    )
-
-    # For debugging purposes, print the output if the test fails.
-    if result.returncode != 0:
-        sys.stderr.write("STDOUT:\n" + result.stdout + "\n")
-        sys.stderr.write("STDERR:\n" + result.stderr + "\n")
-
-    # Assert that the command completed successfully.
-    assert result.returncode == 0, "predict_evo2 command with Llama model failed."
+    # Run the command in a subshell with an isolated MASTER_PORT.
+    _ = run_command_in_subprocess(command=command, path=str(tmp_path))

Also applies to: 177-184

🤖 Prompt for AI Agents
In sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_predict.py around lines
168-175 (and also apply the same change to lines 177-184), the test invokes
subprocess.run directly which can cause distributed backend MASTER_PORT
collisions in CI; replace the direct subprocess.run calls with the existing
run_command_in_subprocess helper used by training tests so it sets a unique
MASTER_PORT and consistent env handling. Use run_command_in_subprocess(command,
tmp_path) (or the helper’s exact signature used elsewhere) to execute the
command in the tmp_path and capture output/text, ensuring the helper provides
the env var and port isolation to reduce flakes.

@yzhang123 (Collaborator, Author) commented

/ok to test 8dd1d12

1 similar comment
