enable predict.py for llama #1142
base: main
Conversation
Signed-off-by: Yang Zhang <[email protected]>
Walkthrough
Adds Llama support to the Evo2 prediction CLI using LLAMA_MODEL_OPTIONS, makes --output_dir required, introduces a new MambaPredictor with per-token log-prob handling and predict_step, and refactors branching into explicit hyena/mamba/llama paths; Llama inference reuses HyenaPredictor configuration.
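As a rough illustration of that dispatch, here is a minimal self-contained sketch. The option-table and predictor names come from the summary above; the stub configs, the placeholder classes, and the build_predictor helper are illustrative assumptions, not the actual predict.py code, which also wires up forward/data-step functions and the Lightning trainer.

# Hypothetical stand-ins for the real option tables and predictor classes in predict.py.
HYENA_MODEL_OPTIONS = {"7b": {"kind": "hyena"}}
MAMBA_MODEL_OPTIONS = {"hybrid": {"kind": "mamba"}}
LLAMA_MODEL_OPTIONS = {"8B": {"kind": "llama"}}


class HyenaPredictor:  # placeholder for the real Hyena prediction module
    def __init__(self, config):
        self.config = config


class MambaPredictor:  # placeholder for the predictor introduced in this PR
    def __init__(self, config):
        self.config = config


def build_predictor(model_type: str, model_size: str):
    """Dispatch on --model_type / --model_size the way the walkthrough describes."""
    if model_type == "hyena":
        return HyenaPredictor(HYENA_MODEL_OPTIONS[model_size])
    elif model_type == "mamba":
        return MambaPredictor(MAMBA_MODEL_OPTIONS[model_size])
    elif model_type == "llama":
        # Llama reuses the Hyena predictor path; only the config table differs.
        return HyenaPredictor(LLAMA_MODEL_OPTIONS[model_size])
    raise ValueError(f"Unsupported model_type: {model_type}")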
Sequence Diagram(s)
sequenceDiagram
autonumber
actor User
participant CLI as predict.py (CLI)
participant Config as Model Config Builder
participant Predictor as Predictor (Hyena / Mamba)
participant Trainer as Lightning Trainer
Note over CLI,Config: User supplies --model_type, --model_size, --output_dir
User->>CLI: invoke predict
CLI->>Config: validate model_type & size
alt model_type == "hyena"
Config-->>CLI: hyena config
CLI->>Predictor: instantiate HyenaPredictor
else model_type == "mamba"
Config-->>CLI: mamba config
CLI->>Predictor: instantiate MambaPredictor
else model_type == "llama"
Config-->>CLI: LLAMA_MODEL_OPTIONS -> hyena-style config
CLI->>Predictor: instantiate HyenaPredictor (llama config)
end
CLI->>Trainer: trainer.predict(Predictor, dataloaders)
loop per batch
Trainer->>Predictor: predict_step(batch)
Predictor-->>Trainer: {predictions, log_probabilities, tokens}
end
Trainer-->>User: save aggregated outputs to output_dir
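The predict loop in the diagram above maps onto the standard Lightning API. The toy module below is a generic sketch, not the repo's NeMo-wrapped predictors: predict_step returns a dict per batch, and Trainer.predict collects those dicts for a writer callback to save to --output_dir. All class and key names here are illustrative.

import torch
import lightning.pytorch as pl
from torch.utils.data import DataLoader, TensorDataset


class TinyPredictor(pl.LightningModule):
    """Generic stand-in for HyenaPredictor/MambaPredictor returning a dict from predict_step."""

    def __init__(self, vocab_size: int = 16):
        super().__init__()
        self.proj = torch.nn.Linear(8, vocab_size)

    def predict_step(self, batch, batch_idx):
        (tokens,) = batch
        logits = self.proj(torch.nn.functional.one_hot(tokens, 8).float())
        return {
            "predictions": logits,
            "log_probabilities": torch.log_softmax(logits, dim=-1),
            "tokens": tokens,
        }


tokens = torch.randint(0, 8, (4, 6))
loader = DataLoader(TensorDataset(tokens), batch_size=2)
trainer = pl.Trainer(accelerator="cpu", devices=1, logger=False, enable_checkpointing=False)
predictions = trainer.predict(TinyPredictor(), dataloaders=loader)  # list of per-batch dicts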
sequenceDiagram
autonumber
participant M as MambaPredictor
participant Gather as Parallel Gather
participant Cache as Prediction Cache
M->>M: forward(input_ids, position_ids, ...)
alt output_log_prob_seqs enabled
M->>M: compute per-token log_probs
alt collapse == "sum" / "mean" / "per_token"
M->>M: collapse as requested
end
end
M->>Gather: emit logits/tokens/log_probs
Gather-->>M: gathered tensors
M->>Cache: store predictions/log_probs/tokens
M-->>Caller: return dict(predictions, log_probabilities, tokens)
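To make the log-prob branch of this diagram concrete, here is a small sketch of per-token log-prob computation with the sum / mean / per_token collapse options. The collapse_log_probs helper, tensor shapes, and the next-token shift are illustrative assumptions, not the actual MambaPredictor implementation.

import torch
import torch.nn.functional as F


def collapse_log_probs(logits: torch.Tensor, tokens: torch.Tensor, collapse: str = "mean") -> torch.Tensor:
    """logits: [batch, seq_len, vocab]; tokens: [batch, seq_len].

    Scores each observed token under the model's next-token distribution, then
    collapses per sequence according to the requested option.
    """
    log_probs = F.log_softmax(logits[:, :-1].float(), dim=-1)
    token_log_probs = log_probs.gather(-1, tokens[:, 1:].unsqueeze(-1)).squeeze(-1)

    if collapse == "sum":
        return token_log_probs.sum(dim=-1)
    if collapse == "mean":
        return token_log_probs.mean(dim=-1)
    return token_log_probs  # "per_token": keep the full [batch, seq_len - 1] tensor


# Example: batch of 2 sequences, 5 tokens each, vocabulary of 16.
logits = torch.randn(2, 5, 16)
tokens = torch.randint(0, 16, (2, 5))
print(collapse_log_probs(logits, tokens, collapse="sum").shape)  # torch.Size([2])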
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks (1 passed, 2 warnings)
❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
Signed-off-by: Yang Zhang <[email protected]>
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (5)
sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py (5)
520-627: MambaPredictor class placement needs adjustment.
The MambaPredictor class is defined at the end of the file (after line 520), but it's referenced and used earlier in the code (line 466). This will cause a NameError at runtime.
Move the MambaPredictor class definition to be before its first usage. Apply this restructuring:
-# Create MambaPredictor class, similar to HyenaPredictor
-class MambaPredictor(MambaModel, LightningPassthroughPredictionMixin):
+# Move this entire class definition to line 288 (after PredictDataModule class)
+class MambaPredictor(MambaModel, LightningPassthroughPredictionMixin):
The class should be placed after the PredictDataModule class definition (around line 288) and before the predict function.
577-581: Fix incorrect conditional logic in log probability calculation.
The else branch on line 580 assumes log_prob_collapse_option is "mean", but it should also handle the "per_token" option. Apply this diff to fix the logic:
-        if self.log_prob_collapse_option == "sum":
-            sequence_log_prob = sum(token_log_probs)
-        else:  # mean
-            sequence_log_prob = sum(token_log_probs) / len(token_log_probs) if token_log_probs else 0
+        if self.log_prob_collapse_option == "sum":
+            sequence_log_prob = sum(token_log_probs)
+        elif self.log_prob_collapse_option == "mean":
+            sequence_log_prob = sum(token_log_probs) / len(token_log_probs) if token_log_probs else 0
+        else:  # per_token
+            sequence_log_prob = token_log_probs
543-546: Unused instance variables in MambaPredictor.
The instance variables self.predictions, self.log_probabilities, and self.tokens are initialized but never cleared between batches, which could lead to accumulating results across multiple predict calls. Consider either:
- Clearing these lists at the beginning of each predict_step
- Using local variables instead of instance variables
Apply this diff for option 2 (preferred):
-        # Storage for the predictions
-        self.predictions = []
-        self.log_probabilities = []
-        self.tokens = []
Then update the predict_step method to use local variables.
589-626: Inconsistent return format in MambaPredictor.predict_step.
The MambaPredictor returns a different data structure compared to HyenaPredictor. HyenaPredictor returns log_probs_seqs and seq_idx, while MambaPredictor returns predictions, log_probabilities, and tokens. This inconsistency could cause issues with the PredictionWriter callback. Align the return format with HyenaPredictor for consistency:
-        # Return as dict for PredictionWriter callback
-        return {"predictions": self.predictions, "log_probabilities": self.log_probabilities, "tokens": self.tokens}
+        # Return as dict for PredictionWriter callback (matching HyenaPredictor format)
+        if self.output_log_prob_seqs:
+            return {"log_probs_seqs": self.log_probabilities, "seq_idx": batch.get("seq_idx", torch.arange(len(batch["tokens"]))).cpu()}
+        else:
+            return {"token_logits": output.cpu(), "pad_mask": batch.get("loss_mask", torch.ones_like(batch["tokens"])).cpu(), "seq_idx": batch.get("seq_idx", torch.arange(len(batch["tokens"]))).cpu()}
592-594: Empty batch handling creates incorrect data structure.
The empty batch handling creates a batch with empty lists, which doesn't match the expected tensor format.
Apply this diff to properly handle empty batches:
-        if batch == {}:
-            batch = {"tokens": [], "position_ids": [], "attention_mask": []}
+        if not batch or "tokens" not in batch:
+            return {}
🧹 Nitpick comments (1)
sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py (1)
85-87: Consider deduplicating model size choices.
The current approach concatenates keys from all three model option dictionaries, which could lead to duplicate entries if models share size options (e.g., "7b" might be in multiple dictionaries).
Apply this diff to ensure unique choices:
-        choices=sorted(
-            list(HYENA_MODEL_OPTIONS.keys()) + list(MAMBA_MODEL_OPTIONS.keys()) + list(LLAMA_MODEL_OPTIONS.keys())
-        ),
+        choices=sorted(
+            set(list(HYENA_MODEL_OPTIONS.keys()) + list(MAMBA_MODEL_OPTIONS.keys()) + list(LLAMA_MODEL_OPTIONS.keys()))
+        ),
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py (5 hunks)
🔇 Additional comments (9)
sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py (9)
40-40: LGTM! Clean import addition for Llama support.
The import is correctly placed alongside other model imports and follows the existing naming convention.
77-79: LGTM! Model type choices properly extended.
The choices are correctly expanded to include "llama" and the help text is appropriately updated.
94-95: LGTM! Output directory is now appropriately required.
Making the output directory required ensures that prediction results have a proper destination, preventing potential runtime errors.
422-430: LGTM! Consistent implementation for Mamba model type.
The explicit branching with elif improves code clarity, and the configuration setup correctly reuses the Hyena forward and data step functions.
431-438: LGTM! Llama configuration properly integrated.
The Llama model path correctly validates against LLAMA_MODEL_OPTIONS and reuses the Hyena infrastructure for forward and data steps, which is appropriate given the architectural similarities.
465-471: LGTM! Model instantiation handles all types correctly.
The Mamba branch correctly instantiates MambaPredictor with appropriate parameters.
472-478: LGTM! Llama correctly uses HyenaPredictor.
The implementation correctly reuses HyenaPredictor for Llama models, which aligns with the architectural similarities mentioned in the AI summary.
568-575: Remove false positive offset mismatch warning.
The token-logit alignment is handled upstream in forward_step, which prepends an EOS token and drops the first output as noted in the comments at lines 181-182, ensuring log_probs[j] correctly corresponds to token j.
Likely an incorrect or invalid review comment.
405-409: No action needed for Llama models.
The seq_len_interpolation_factor override is specific to the Hyena "-1m" convolutional mixer and isn't used or supported by the NeMo Llama configs (which rely on rotary embeddings for longer contexts).
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_predict.py (1)
207-207: Fix end-of-file newline to satisfy pre-commit
CI shows "End-of-file fixer modified file." Add a trailing newline and re-run pre-commit locally.
Run:
#!/bin/bash
pre-commit run -a
git add -A && git commit -m "Fix EOF newline and pre-commit issues"
🧹 Nitpick comments (6)
sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_predict.py (6)
35-45: Deduplicate helper and remove redundant CLI arg
- This helper duplicates small_training_llama_cmd from tests/bionemo/evo2/run/test_train.py. Prefer a shared test utility to avoid drift.
- "--limit-val-batches 1" is passed twice.
Apply to remove the duplicate flag:
 def small_training_llama_cmd(path, max_steps, val_check, devices: int = 1, additional_args: str = ""):
     cmd = (
         f"train_evo2 --no-fp32-residual-connection --mock-data --result-dir {path} --devices {devices} "
-        "--model-size 8B --num-layers 2 --limit-val-batches 1 "
+        "--model-size 8B --num-layers 2 "
         "--no-activation-checkpointing --create-tensorboard-logger --create-tflops-callback "
-        f"--max-steps {max_steps} --warmup-steps 1 --val-check-interval {val_check} --limit-val-batches 1 "
+        f"--max-steps {max_steps} --warmup-steps 1 --val-check-interval {val_check} --limit-val-batches 1 "
         f"--seq-length 8 --hidden-dropout 0.1 --attention-dropout 0.1 {additional_args}"
     )
     return cmd
If you'd like, I can move this helper into bionemo.testing (or a conftest.py fixture) and update both tests.
122-126: Timeout is reasonable, but guard CI variability
512s is generous; fine. If CI is occasionally slower, consider parametrizing via env (e.g., BIONEMO_TEST_TIMEOUT) with a default, to avoid editing code later.
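A possible shape for that parametrization, relying on the pytest-timeout plugin already declared for the project (noted later in this review); the BIONEMO_TEST_TIMEOUT variable name and the test name are hypothetical.

import os

import pytest

# Hypothetical env override; 512 s matches the current hard-coded timeout.
TEST_TIMEOUT_SECONDS = int(os.environ.get("BIONEMO_TEST_TIMEOUT", "512"))


@pytest.mark.timeout(TEST_TIMEOUT_SECONDS)
def test_predict_evo2_llama(tmp_path):  # hypothetical test name for illustration
    ...  # existing test body unchanged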
136-136: Avoid brittle stdout assertion about restore path
Training logs can legitimately include "Restoring …" for various reasons; asserting its absence risks flakes. Prefer removing this check or converting it to a soft warning.
Apply:
- assert "Restoring model weights from RestoreConfig(path='" not in stdout_pretrain + # Intentionally not asserting on restore messages to avoid log-format flakes.
143-151: Make checkpoint discovery robust when multiple matches exist
Asserting exactly one match is brittle. Select the most recent matching checkpoint instead.
-    assert matching_subfolders, (
-        f"No checkpoint subfolder ending with '{expected_checkpoint_suffix}' found in {checkpoints_dir}."
-    )
-    assert len(matching_subfolders) == 1, "Only one checkpoint subfolder should be found."
-    checkpoint_path = matching_subfolders[0]
+    assert matching_subfolders, (
+        f"No checkpoint subfolder ending with '{expected_checkpoint_suffix}' found in {checkpoints_dir}."
+    )
+    checkpoint_path = max(matching_subfolders, key=lambda p: p.stat().st_mtime)
192-205: Don't hardcode vocab dimension; derive from predictions
Hardcoding 512 will break if LLAMA_MODEL_OPTIONS changes tokenizer/vocab. Read it from the tensor.
-    preds = torch.load(pred_files[0])
+    preds = torch.load(pred_files[0])
+    vocab_size = preds["token_logits"][0].shape[1]
@@
-    assert preds["token_logits"][idx].shape == (max(target_sequence_lengths), 512)
+    assert preds["token_logits"][idx].shape == (max(target_sequence_lengths), vocab_size)
186-187: Pathlib over os.path/glob for consistency
Since you already use Path, prefer output_dir.glob("predictions__rank_*.pt") for readability.
-    pred_files = glob.glob(os.path.join(output_dir, "predictions__rank_*.pt"))
+    pred_files = list(output_dir.glob("predictions__rank_*.pt"))
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_predict.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_predict.py (4)
sub-packages/bionemo-noodles/src/bionemo/noodles/nvfaidx.py (2)
  NvFaidx (108-253)
  keys (228-229)
sub-packages/bionemo-testing/src/bionemo/testing/data/fasta.py (1)
  create_fasta_file (28-63)
sub-packages/bionemo-testing/src/bionemo/testing/subprocess_utils.py (1)
  run_command_in_subprocess (108-129)
sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_train.py (1)
  small_training_llama_cmd (99-107)
🪛 GitHub Actions: BioNeMo Framework CI
sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_predict.py
[error] 1-1: End-of-file fixer modified file. Pre-commit hook failed (exit code 1).
🔇 Additional comments (2)
sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_predict.py (2)
32-32: LGTM: Good reuse of subprocess helper
Using run_command_in_subprocess for training is appropriate here. Consider reusing it for prediction too (see below).
25-25: pytest-timeout already declared in project
The pytest-timeout plugin (v2.2.0) is listed in requirements-test.txt, so the timeout marker is available.
# Run the command in a subshell, using the temporary directory as the current working directory.
result = subprocess.run(
    command,
    shell=True,  # Use the shell to interpret wildcards (e.g. SDH*)
    cwd=tmp_path,  # Run in the temporary directory
    capture_output=True,  # Capture stdout and stderr for debugging
    text=True,  # Decode output as text
)
🛠️ Refactor suggestion
Use the subprocess helper for prediction to set MASTER_PORT and reduce flakes
predict_evo2 will initialize distributed backends; without a unique MASTER_PORT this can collide in CI. Reuse run_command_in_subprocess like you did for training.
- # Run the command in a subshell, using the temporary directory as the current working directory.
- result = subprocess.run(
- command,
- shell=True, # Use the shell to interpret wildcards (e.g. SDH*)
- cwd=tmp_path, # Run in the temporary directory
- capture_output=True, # Capture stdout and stderr for debugging
- text=True, # Decode output as text
- )
-
- # For debugging purposes, print the output if the test fails.
- if result.returncode != 0:
- sys.stderr.write("STDOUT:\n" + result.stdout + "\n")
- sys.stderr.write("STDERR:\n" + result.stderr + "\n")
-
- # Assert that the command completed successfully.
- assert result.returncode == 0, "predict_evo2 command with Llama model failed."
+ # Run the command in a subshell with an isolated MASTER_PORT.
+ _ = run_command_in_subprocess(command=command, path=str(tmp_path))
Also applies to: 177-184
🤖 Prompt for AI Agents
In sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_predict.py around lines
168-175 (and also apply the same change to lines 177-184), the test invokes
subprocess.run directly which can cause distributed backend MASTER_PORT
collisions in CI; replace the direct subprocess.run calls with the existing
run_command_in_subprocess helper used by training tests so it sets a unique
MASTER_PORT and consistent env handling. Use run_command_in_subprocess(command,
tmp_path) (or the helper’s exact signature used elsewhere) to execute the
command in the tmp_path and capture output/text, ensuring the helper provides
the env var and port isolation to reduce flakes.
/ok to test 8dd1d12
8dd1d12 to 0a9a039 (Compare)
0a9a039 to a39524d (Compare)
Description
enable evo2 predict.py for llama
Type of changes
CI Pipeline Configuration
Configure CI behavior by applying the relevant labels:
Note
By default, the notebooks validation tests are skipped unless explicitly enabled.
Authorizing CI Runs
We use copy-pr-bot to manage authorization of CI runs on NVIDIA's compute resources.
Once authorized, changes will automatically be copied to a pull-request/ prefixed branch in the source repository (e.g. pull-request/123), and a maintainer must leave an
/ok to test
comment on the pull request to trigger CI. This will need to be done for each new commit.
Pre-submit Checklist
Summary by CodeRabbit
New Features
Refactor
Tests