Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from torch import Tensor

from bionemo.evo2.data.fasta_dataset import SimpleFastaDataset
from bionemo.evo2.models.llama import LLAMA_MODEL_OPTIONS

# Add import for Mamba models
from bionemo.evo2.models.mamba import MAMBA_MODEL_OPTIONS, MambaModel
Expand Down Expand Up @@ -73,23 +74,25 @@ def parse_args():
ap.add_argument(
"--model-type",
type=str,
choices=["hyena", "mamba"],
choices=["hyena", "mamba", "llama"],
default="hyena",
help="Model architecture family to use. Choose between 'hyena' and 'mamba'.",
help="Model architecture family to use. Choose between 'hyena', 'mamba', and 'llama'.",
)
ap.add_argument(
"--model-size",
type=str,
default="7b",
choices=sorted(list(HYENA_MODEL_OPTIONS.keys()) + list(MAMBA_MODEL_OPTIONS.keys())),
choices=sorted(
list(HYENA_MODEL_OPTIONS.keys()) + list(MAMBA_MODEL_OPTIONS.keys()) + list(LLAMA_MODEL_OPTIONS.keys())
),
help="Model size to use. Defaults to '7b'.",
)
# output args:
ap.add_argument(
"--output-dir",
type=Path,
default=None,
help="Output dir that will contain the generated text produced by the Evo2 model. If not provided, the output will be logged.",
required=True,
help="Output dir that will contain the generated text produced by the Evo2 model.",
)
ap.add_argument(
"--full-fp8",
Expand Down Expand Up @@ -416,7 +419,7 @@ def predict(
vortex_style_fp8=fp8 and not full_fp8,
**config_modifiers_init,
)
else: # mamba
elif model_type == "mamba": # mamba
if model_size not in MAMBA_MODEL_OPTIONS:
raise ValueError(f"Invalid model size for Mamba: {model_size}")
config = MAMBA_MODEL_OPTIONS[model_size](
Expand All @@ -425,6 +428,14 @@ def predict(
distribute_saved_activations=False if sequence_parallel and tensor_parallel_size > 1 else True,
**config_modifiers_init,
)
elif model_type == "llama":
if model_size not in LLAMA_MODEL_OPTIONS:
raise ValueError(f"Invalid model size for Llama: {model_size}")
config = LLAMA_MODEL_OPTIONS[model_size](
forward_step_fn=hyena_predict_forward_step,
data_step_fn=hyena_predict_data_step,
**config_modifiers_init,
)

trainer.strategy._setup_optimizers = False

Expand All @@ -451,13 +462,20 @@ def predict(
output_log_prob_seqs=output_log_prob_seqs,
log_prob_collapse_option=log_prob_collapse_option,
)
else: # mamba
elif model_type == "mamba": # mamba
model = MambaPredictor(
config,
tokenizer=tokenizer,
output_log_prob_seqs=output_log_prob_seqs,
log_prob_collapse_option=log_prob_collapse_option,
)
elif model_type == "llama":
model = HyenaPredictor(
config,
tokenizer=tokenizer,
output_log_prob_seqs=output_log_prob_seqs,
log_prob_collapse_option=log_prob_collapse_option,
)

resume.setup(trainer, model) # this pulls weights from the starting checkpoint.

Expand Down
Loading