From d65d8fa3b3904736f8e336c64209708449799e75 Mon Sep 17 00:00:00 2001 From: John St John Date: Fri, 8 Aug 2025 23:03:12 +0000 Subject: [PATCH 1/8] Initial commit of gpt variant of evo2 model. Signed-off-by: John St John --- .../src/bionemo/evo2/models/gpt.py | 388 ++++++++++++++++++ .../src/bionemo/evo2/run/train.py | 63 ++- 2 files changed, 440 insertions(+), 11 deletions(-) create mode 100644 sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py diff --git a/sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py b/sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py new file mode 100644 index 0000000000..6b6d1da630 --- /dev/null +++ b/sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py @@ -0,0 +1,388 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import logging +from dataclasses import dataclass +from functools import partial +from typing import Callable + +import megatron.core.models.gpt.gpt_model +import torch +from megatron.core import parallel_state +from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper +from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig +from megatron.core.transformer.spec_utils import ModuleSpec +from nemo.collections import llm +from nemo.collections.llm.gpt.model.base import GPTModel, mtp_block_spec +from nemo.collections.llm.gpt.model.llama import Llama3Config, apply_rope_scaling +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_utils import make_upper_case, reweighted_cross_entropy +from nemo.lightning import get_vocab_size +from nemo.utils.import_utils import safe_import +from typing_extensions import override + +from bionemo.evo2.utils.loss.embedding_variance import SquaredErrorTargetedVarianceLoss + + +_, HAVE_TE = safe_import("transformer_engine") + +# Gradient accumulation fusion may be enabled if available, for more information see: +# https://github.com/NVIDIA/Megatron-LM/blob/01945b98d1ea3a2acb5e8301e181a328104f4856/megatron/core/tensor_parallel/layers.py#L575 +# TODO: Clean this up with a getter and install instructions +_grad_accum_fusion_available = True +try: + import fused_weight_gradient_mlp_cuda # noqa: F401 # pylint: disable=unused-import +except ImportError: + _grad_accum_fusion_available = False + +logger = logging.getLogger(__name__) + + +def evo2_gpt_forward_step(model, batch) -> torch.Tensor: + """Forward step function for Mamba models, similar to hyena_forward_step. + + Args: + model: The Mamba model + batch: Dictionary containing input batch data + + Returns: + torch.Tensor: Output from the model forward pass + """ + forward_args = { + "input_ids": batch["tokens"], + "position_ids": batch["position_ids"], + "labels": batch["labels"], + "loss_mask": batch["loss_mask"], + } + forward_args["attention_mask"] = None + return model(**forward_args) + + +class Evo2GPTModel(GPTModel): + """Mamba model that extends GPTModel for integration with NeMo. + + Note that the loss calculation is handled by CustomMCoreMambaModel instead. + """ + + @override + def get_inference_wrapper( + self, params_dtype, inference_batch_times_seqlen_threshold, inference_max_seq_length=8192 + ) -> GPTInferenceWrapper: + """Gets the inference wrapper for the Mamba model.""" + # Find MCoreMambaModel instance + mcore_model = self.module + while mcore_model: + if isinstance(mcore_model, ()): + break + mcore_model = getattr(mcore_model, "module", None) + if mcore_model is None or not isinstance( + mcore_model, (megatron.core.models.gpt.gpt_model.GPTModel, Evo2StyleMCoreGPTModel) + ): + raise ValueError("GPT model instance not found in the model structure.") + + vocab_size = None + if self.tokenizer is not None: + vocab_size = self.tokenizer.vocab_size + elif hasattr(self.config, "vocab_size"): + vocab_size = self.config.vocab_size + else: + raise ValueError("Unable to find vocab size.") + + inference_wrapper_config = InferenceWrapperConfig( + hidden_size=mcore_model.config.hidden_size, + params_dtype=params_dtype, + inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold, + padded_vocab_size=vocab_size, + inference_max_seq_length=inference_max_seq_length, + ) + + model_inference_wrapper = GPTInferenceWrapper(mcore_model, inference_wrapper_config) + return model_inference_wrapper + + @override + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: torch.Tensor | None = None, + labels: torch.Tensor | None = None, + decoder_input: torch.Tensor | None = None, + inference_context=None, + packed_seq_params=None, + inference_params=None, + runtime_gather_output: bool | None = None, + loss_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """Forward pass that delegates to CustomMCoreMambaModel, which handles loss calculation.""" + extra_kwargs = {"packed_seq_params": packed_seq_params} if packed_seq_params is not None else {} + output_tensor = self.module( + input_ids, + position_ids, + attention_mask, + decoder_input=decoder_input, + labels=labels, # Pass labels to the Megatron module + inference_params=inference_params, + inference_context=inference_context, + runtime_gather_output=runtime_gather_output, + loss_mask=loss_mask, # Pass loss_mask to the Megatron module + **extra_kwargs, + ) + + # Return whatever CustomMCoreMambaModel.forward returns + # (logits during inference, loss during training) + return output_tensor + + +# Custom MCoreMambaModel with reweighted loss calculation +class Evo2StyleMCoreGPTModel(megatron.core.models.gpt.gpt_model.GPTModel): + """Custom version of MCoreMambaModel that implements reweighted loss calculation. + + Note that this is similar to the HyenaModel for uppercase/lowercase handling. + """ + + def __init__(self, *args, **kwargs): + """Initializes `Evo2StyleMCoreMambaModel` with unique parameters for the Evo2 variant of `MCoreMambaModel`.""" + super().__init__(*args, **kwargs) + if self.config.use_targeted_variance_loss: + if not hasattr(self.config, "embedding_init_method_std"): + logger.warning("embedding_init_method_std is not supported in this config, please upgrade Megatron-LM") + # 1.0 is the suggested value for embedding_init_method_std from the + # [Spike No More](https://arxiv.org/abs/2312.16903) paper. + embedding_init_method_std: float = getattr(self.config, "embedding_init_method_std", 1.0) + self.targeted_variance_loss = SquaredErrorTargetedVarianceLoss( + loss_coeff=self.config.targeted_variance_loss_loss_coeff, + var_target=embedding_init_method_std**2, + ) + + @override + def forward(self, *args, labels: torch.Tensor | None = None, loss_mask: torch.Tensor | None = None, **kwargs): + """Forward pass that delegates to Evo2StyleMCoreGPTModel, which handles loss calculation.""" + _forward_out = super().forward(*args, labels=labels, loss_mask=loss_mask, **kwargs) + if labels is None or not self.post_process: + # These are the two short-circuit cases in megatron.core.models.gpt.gpt_model.GPTModel.forward + # 1. labels is None + # -> return the logits transposed to batch_size x seq_len x vocab_size + # 2. not self.post_process + # -> return the hidden states. + return _forward_out + # Now that the above is false, we know that _forward_out is the loss, as in: + # loss = self.compute_language_model_loss(labels, logits) + loss = _forward_out + + labels, lowercase_mask = make_upper_case(labels) + normalize_per_batch = True if self.config.to_upper == "normalized_weighted" else False + loss = reweighted_cross_entropy( + loss, + (labels, loss_mask, lowercase_mask), + lowercase_weight=self.config.lowercase_loss_reweighting, + normalize_per_batch=normalize_per_batch, + ) + if self.training and self.config.use_targeted_variance_loss: + # Only use this in training, not validation etc. + var_loss = self.targeted_variance_loss(self.embedding.word_embeddings.weight) + loss += var_loss + return loss + + +def gpt_no_weight_decay_cond(name, param, exclude_embeddings: bool = False): + """Condition for no weight decay for Mamba parameters. + + Note that this follows the same pattern as in the original Mamba implementation. + """ + # Mamba-specific parameters that should not have weight decay + if ("embedding" in name and exclude_embeddings) or getattr(param, "_no_weight_decay", False): + no_wd = True + # All other parameters - use default MCore behavior: + # Do not regularize biases and norm parameters + # (See megatron.core.optimizer._get_pram_groups) + # TODO exclude embeddings + else: + no_wd = name.endswith(".bias") or len(param.shape) == 1 + return no_wd + + +def gpt_no_weight_decay_cond_with_embeddings(name, param): + """Condition for no weight decay for Mamba parameters with embeddings. + + Note that this follows the same pattern as in the original Mamba implementation but also skips WD on embeddings. + """ + return gpt_no_weight_decay_cond(name, param, exclude_embeddings=True) + + +@dataclass +class LLama31ConfigEvoLoss3B(llm.Llama3Config8B): + """Config for 8B hybrid Mamba model.""" + + # RoPE/context length related block: + rotary_base: int = 500_000 + seq_length: int = 8192 + old_context_len: int = 8192 # should be set/updated based on the loaded checkpoint's seq_length if fine-tuning. + scale_factor: float = 1.0 # should be the ratio between the old context length and the new seq_length + low_freq_factor: float = 1.0 # this factor can be left as is when extending the context length + high_freq_factor: float = 4.0 # this factor can be left as is when extending the context length + + # vocab_size: int = 512 + num_layers: int = 32 + hidden_size: int = 4096 + ffn_hidden_size: int = 14336 + num_attention_heads: int = 32 + embedding_init_method_std: float = 1.0 + + init_method_std: float = 0.02 + hyena_no_weight_decay_cond_fn: Callable = gpt_no_weight_decay_cond # TODO rename to something more general + forward_step_fn: Callable = evo2_gpt_forward_step + + layernorm_embeddings: bool = False + # If set to true, use targeted variance loss which encourages the word embedding weight variances + # to be close to a target value (1.0). + share_embeddings_and_output_weights: bool = False + use_targeted_variance_loss: bool = False + targeted_variance_loss_loss_coeff: float = 0.1 + spike_no_more_embedding_init: bool = True + to_upper: str = "normalized_weighted" + lowercase_loss_reweighting: float = 0.1 + + def __post_init__(self): + """Post-init logic for Evo2 to enable backwards compatibility with old configs.""" + # Specific post_init logic for Evo2 to enable backwards compatibility with old configs. + if not hasattr(self, "embedding_init_method_std"): + raise ValueError("embedding_init_method_std is not supported in this config, please upgrade Megatron-LM") + if self.spike_no_more_embedding_init and self.embedding_init_method_std is None: + logger.warning( + "spike_no_more_embedding_init is deprecated, please set " + "embedding_init_method_std=[desired_stdev] in the future. To get the old behavior set to 1.0. " + "For now setting to 1.0." + ) + self.embedding_init_method_std = 1.0 + # Continue with the remaining post-init logic defined in NemotronHConfigBase and/or TransformerConfig. + super().__post_init__() + + @override + def configure_model( + self, tokenizer, pre_process=None, post_process=None, vp_stage: int | None = None + ) -> Evo2StyleMCoreGPTModel: + """Configure and instantiate a Megatron Core Llama 3.1 model. + + Extends the base configuration with Llama 3.1 specific RoPE scaling. + + Args: + tokenizer: Tokenizer used with the model + pre_process: Whether to include pre-processing in the model + post_process: Whether to include post-processing in the model + vp_stage: Virtual pipeline parallel stage (or None if not using virtual pipeline parallelism) + + Returns: + MCoreGPTModel: Configured Megatron Core GPT model instance + """ + if self.enable_cuda_graph: + assert HAVE_TE, "Transformer Engine is required for cudagraphs." + assert getattr(self, "use_te_rng_tracker", False), ( + "Transformer engine's RNG tracker is required for cudagraphs, it can be " + "enabled with use_te_rng_tracker=True'." + ) + + vp_size = self.virtual_pipeline_model_parallel_size + is_pipeline_asymmetric = getattr(self, "account_for_embedding_in_pipeline_split", False) or getattr( + self, "account_for_loss_in_pipeline_split", False + ) + is_pipeline_asymmetric |= ( + getattr(self, "num_layers_in_first_pipeline_stage", None) + or getattr(self, "num_layers_in_last_pipeline_stage", None) + ) is not None + is_flexible_pp_layout = is_pipeline_asymmetric or ( + getattr(self, "pipeline_model_parallel_layout", None) is not None + ) + if vp_size and not is_flexible_pp_layout: + p_size = self.pipeline_model_parallel_size + assert (self.num_layers // p_size) % vp_size == 0, ( + "Make sure the number of model chunks is the same across all pipeline stages." + ) + + import inspect + + # During fake lightning initialization, pass 0 to bypass the assertion that vp_stage must be + # non-None when using virtual pipeline model parallelism + vp_stage = vp_stage or 0 + + transformer_layer_spec = self.transformer_layer_spec + if not isinstance(transformer_layer_spec, ModuleSpec): + # Check if the transformer_layer_spec function accepts vp_stage parameter + if "vp_stage" in inspect.signature(transformer_layer_spec).parameters: + transformer_layer_spec = transformer_layer_spec(self, vp_stage=vp_stage) + else: + transformer_layer_spec = transformer_layer_spec(self) + + if self.vocab_size is not None: + vocab_size = self.vocab_size + if tokenizer is not None: + logging.info( + f"Use preset vocab_size: {vocab_size}, original vocab_size: {tokenizer.vocab_size}, dummy tokens:" + f" {vocab_size - tokenizer.vocab_size}." + ) + else: + vocab_size = get_vocab_size(self, tokenizer.vocab_size, self.make_vocab_size_divisible_by) + # Initialize model as meta data instead of allocating data on a device + model_init_device_context = contextlib.nullcontext + if self.init_model_with_meta_device: + model_init_device_context = partial(torch.device, device="meta") + + if "mtp_block_spec" in inspect.signature(Evo2StyleMCoreGPTModel.__init__).parameters: + kwargs = {"mtp_block_spec": mtp_block_spec(self, vp_stage=vp_stage)} + else: + kwargs = {} + with model_init_device_context(): + model = Evo2StyleMCoreGPTModel( + self, + transformer_layer_spec=transformer_layer_spec, + vocab_size=vocab_size, + max_sequence_length=self.seq_length, + fp16_lm_cross_entropy=self.fp16_lm_cross_entropy, + parallel_output=self.parallel_output, + share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, + position_embedding_type=self.position_embedding_type, + rotary_percent=self.rotary_percent, + rotary_base=self.rotary_base, + seq_len_interpolation_factor=self.seq_len_interpolation_factor, + pre_process=pre_process + or parallel_state.is_pipeline_first_stage(ignore_virtual=False, vp_stage=vp_stage), + post_process=post_process + or parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage), + scatter_embedding_sequence_parallel=self.scatter_embedding_sequence_parallel, + vp_stage=vp_stage, + **kwargs, + ) + + # If using full TE layer, need to set TP, CP group since the module call + # is not routed through megatron core, which normally handles passing the + # TP, CP group to the TE modules. + # Deep iterate but skip self to avoid infinite recursion. + if self.use_transformer_engine_full_layer_spec: + raise ValueError("use_transformer_engine_full_layer_spec is not supported in this config.") + + # Apply rope scaling for Llama3.1 model + model.rotary_pos_emb.inv_freq = apply_rope_scaling( + model.rotary_pos_emb.inv_freq, + factor=self.scale_factor, + low_freq_factor=self.low_freq_factor, + high_freq_factor=self.high_freq_factor, + old_context_len=self.old_context_len, + ) + return model + + +# Dictionary mapping model size names to config classes +GPT_MODEL_OPTIONS: dict[str, type[Llama3Config]] = { + "llama3_8b": LLama31ConfigEvoLoss3B, +} diff --git a/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py b/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py index 09f44dc6ee..02b6baf5ca 100644 --- a/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py +++ b/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py @@ -17,6 +17,7 @@ # limitations under the License. import argparse +import logging from pathlib import Path from typing import List, Optional @@ -46,6 +47,7 @@ from nemo.lightning.pytorch.strategies.utils import RestoreConfig from nemo.utils.exp_manager import TimingCallback +from bionemo.evo2.models.gpt import GPT_MODEL_OPTIONS, Evo2GPTModel, gpt_no_weight_decay_cond_with_embeddings from bionemo.evo2.models.mamba import MAMBA_MODEL_OPTIONS, MambaModel, mamba_no_weight_decay_cond_with_embeddings from bionemo.evo2.run.peft import Evo2LoRA from bionemo.evo2.utils.config import hyena_no_weight_decay_cond_with_embeddings @@ -54,6 +56,7 @@ from bionemo.llm.utils.logger_utils import WandbConfig, setup_nemo_lightning_logger +logger = logging.getLogger(__name__) torch._dynamo.config.suppress_errors = True @@ -180,7 +183,9 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace: parser.add_argument( "--model-size", type=str, - choices=sorted(list(HYENA_MODEL_OPTIONS.keys()) + list(MAMBA_MODEL_OPTIONS.keys())), + choices=sorted( + list(HYENA_MODEL_OPTIONS.keys()) + list(MAMBA_MODEL_OPTIONS.keys()) + list(GPT_MODEL_OPTIONS.keys()) + ), default="7b", help="Model size/configuration to use. Options depend on the selected model-type.", ) @@ -472,6 +477,11 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace: default=5, help="Number of best checkpoints to keep. Set to -1 to save all checkpoints.", ) + parser.add_argument( + "--old-context-len", + type=int, + help="Old context length for the GPT model. This is used to set the old context length for the GPT model when you supply a ckpt_dir.", + ) parser.add_argument( "--metric-to-monitor-for-checkpoints", type=str, @@ -586,9 +596,10 @@ def train(args: argparse.Namespace) -> nl.Trainer: "distribute_saved_activations": False if args.sequence_parallel else True, "cross_entropy_loss_fusion": args.cross_entropy_loss_fusion, "fp32_residual_connection": not args.no_fp32_residual_connection, - "add_bias_output": args.add_bias_output, **activation_checkpointing_args, } + if args.add_bias_output: + config_modifiers_init["add_bias_output"] = args.add_bias_output if args.spike_no_more_embedding_init: config_modifiers_init["embedding_init_method_std"] = 1.0 # When using spike_no_more_embedding_init, we don't want to share embeddings and outputs. @@ -607,6 +618,8 @@ def train(args: argparse.Namespace) -> nl.Trainer: model_type = "hyena" elif args.model_size in MAMBA_MODEL_OPTIONS: model_type = "mamba" + elif args.model_size in GPT_MODEL_OPTIONS: + model_type = "gpt" else: raise ValueError(f"Invalid model size: {args.model_size}") @@ -625,17 +638,45 @@ def train(args: argparse.Namespace) -> nl.Trainer: lora_transform = Evo2LoRA(peft_ckpt_path=args.lora_checkpoint_path) model = llm.HyenaModel(model_config, tokenizer=data_module.tokenizer, model_transform=lora_transform) - else: # mamba + elif model_type == "mamba": # mamba if args.no_weight_decay_embeddings: config_modifiers_init["hyena_no_weight_decay_cond_fn"] = mamba_no_weight_decay_cond_with_embeddings config_modifiers_init["lowercase_loss_reweighting"] = args.mamba_lowercase_loss_weight if args.model_size not in MAMBA_MODEL_OPTIONS: raise ValueError(f"Invalid model size for Mamba: {args.model_size}") - add_bias_output = config_modifiers_init.pop("add_bias_output") - if add_bias_output: - raise ValueError("Bias output is not supported for Mamba models.") model_config = MAMBA_MODEL_OPTIONS[args.model_size](**config_modifiers_init) model = MambaModel(model_config, tokenizer=data_module.tokenizer) + elif model_type == "gpt": + assert args.no_fp32_residual_connection, ( + "GPT models do not support fp32 residual connection, please run with --no-fp32-residual-connection." + ) + config_modifiers_init["lowercase_loss_reweighting"] = args.mamba_lowercase_loss_weight + if args.no_weight_decay_embeddings: + config_modifiers_init["hyena_no_weight_decay_cond_fn"] = gpt_no_weight_decay_cond_with_embeddings + if args.model_size not in GPT_MODEL_OPTIONS: + raise ValueError(f"Invalid model size for GPT: {args.model_size}") + if args.ckpt_dir is None or args.old_context_len: + # Set the old context length based on the initial pre-training run seq_length + # when you supply a ckpt_dir, assume that we will use whatever that value was set to previously + # for rope extension. + old_context_len = args.old_context_len or args.seq_length # set to the seq_length if not supplied + else: + if not args.old_context_len: + old_context_len = args.seq_length + logger.warning( + "No old context length supplied, using the seq_length as the old context length. " + "This is not recommended and if training at a different context length the RoPE scaling factors " + "will be incorrect. Please supply the old context length when fine-tuning especially if you are " + "extending the context length." + ) + else: + old_context_len = args.old_context_len + config_modifiers_init["old_context_len"] = old_context_len + # Set scale factor to the ratio between the old context length and the new seq_length, or at least 1.0 + config_modifiers_init["scale_factor"] = args.seq_length / max(old_context_len, args.seq_length) + + model_config = GPT_MODEL_OPTIONS[args.model_size](**config_modifiers_init) + model = Evo2GPTModel(model_config, tokenizer=data_module.tokenizer) # Setup callbacks. callbacks = [ @@ -658,7 +699,7 @@ def train(args: argparse.Namespace) -> nl.Trainer: flop_meas_callback = FLOPsMeasurementCallback( model_config, data_module, - "hyena", + model_type, ) callbacks.append(flop_meas_callback) @@ -717,7 +758,7 @@ def train(args: argparse.Namespace) -> nl.Trainer: f"-GBS{global_batch_size}-MBS{args.micro_batch_size}-SkipLossRenorm{args.no_renormalize_loss}" f"-NOAC{args.no_activation_checkpointing}-SELAC{args.selective_activation_checkpointing}" f"-ACRNL{model_config.recompute_num_layers}" - f"-PAT{model_config.hybrid_override_pattern}" + f"-PAT{getattr(model_config, 'hybrid_override_pattern', 'none')}" f"-F32R{model_config.fp32_residual_connection}" f"-FCE{model_config.cross_entropy_loss_fusion}" f"-AIC{average_in_collective}" @@ -737,9 +778,9 @@ def train(args: argparse.Namespace) -> nl.Trainer: f"-TVL{args.use_targeted_variance_loss}" f"-NODES{args.num_nodes}-FP8{args.fp8}" ) - if model_type == "mamba": - # Include this setting for mamba models. - wandb_run_name += f"-LLW{args.mamba_lowercase_loss_weight}" + if model_type in {"mamba", "gpt"}: + # Include this setting for mamba/GPT models. + wandb_run_name += f"-LLW{config_modifiers_init.get('lowercase_loss_reweighting')}" wandb_config: Optional[WandbConfig] = ( None From 6a30e47987a462fe310de659bb364e37d8e880b0 Mon Sep 17 00:00:00 2001 From: John St John Date: Mon, 11 Aug 2025 21:36:07 +0000 Subject: [PATCH 2/8] Allow user to have precise control over embedding init std Signed-off-by: John St John --- .../src/bionemo/evo2/run/train.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py b/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py index 02b6baf5ca..c940a299ac 100644 --- a/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py +++ b/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py @@ -364,6 +364,12 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace: "or with --use-targeted-variance-loss to maintain a 1.0 variance during training even with weight decay. This " "also turns off shared weights between embeddings and outputs.", ) + parser.add_argument( + "--embedding-init-std", + type=float, + help="Embedding init std. This can be used in place of --spike-no-more-embedding-init by setting the value " + "to 1.0. Use this or --spike-no-more-embedding-init, not both.", + ) parser.add_argument( "--no-weight-decay-embeddings", action="store_true", @@ -600,9 +606,17 @@ def train(args: argparse.Namespace) -> nl.Trainer: } if args.add_bias_output: config_modifiers_init["add_bias_output"] = args.add_bias_output - if args.spike_no_more_embedding_init: - config_modifiers_init["embedding_init_method_std"] = 1.0 - # When using spike_no_more_embedding_init, we don't want to share embeddings and outputs. + if args.embedding_init_std is not None or args.spike_no_more_embedding_init: + if args.embedding_init_std is not None and not args.spike_no_more_embedding_init: + config_modifiers_init["embedding_init_method_std"] = args.embedding_init_std + elif args.spike_no_more_embedding_init and args.embedding_init_std is None: + config_modifiers_init["embedding_init_method_std"] = 1.0 + else: + logger.warning( + "Both --spike-no-more-embedding-init and --embedding-init-std are set. Using --embedding-init-std" + ) + config_modifiers_init["embedding_init_method_std"] = args.embedding_init_std + # When using different embedding init methods, we don't want to share embeddings and outputs. config_modifiers_init["share_embeddings_and_output_weights"] = False if args.ffn_hidden_size: config_modifiers_init["ffn_hidden_size"] = args.ffn_hidden_size From b34ffb09b47599171bd89a3e17817efd5c60c5b5 Mon Sep 17 00:00:00 2001 From: John St John Date: Mon, 11 Aug 2025 21:50:53 +0000 Subject: [PATCH 3/8] Log the embedding init std Signed-off-by: John St John --- sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py b/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py index c940a299ac..acc7782097 100644 --- a/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py +++ b/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py @@ -787,7 +787,7 @@ def train(args: argparse.Namespace) -> nl.Trainer: f"-B1{args.adam_beta1}-B2{args.adam_beta2}-EPS{args.adam_eps}" f"-PAO{args.use_precision_aware_optimizer}" f"-B16MG{args.bf16_main_grads}" - f"-EWD{args.no_weight_decay_embeddings}-SNI{args.spike_no_more_embedding_init}" + f"-EWD{args.no_weight_decay_embeddings}-EMBI{model_config.embedding_init_method_std}" f"-OGR{args.overlap_grad_reduce}-OPG{args.overlap_param_gather}" f"-TVL{args.use_targeted_variance_loss}" f"-NODES{args.num_nodes}-FP8{args.fp8}" From 40b17c051176804b98e57bf497d9192f9b9eaab4 Mon Sep 17 00:00:00 2001 From: John St John Date: Tue, 12 Aug 2025 00:30:50 +0000 Subject: [PATCH 4/8] Handle cu_seqlens and packed sequences in forward Signed-off-by: John St John --- .../bionemo-evo2/src/bionemo/evo2/models/gpt.py | 14 ++++++++++++-- .../src/bionemo/evo2/models/mamba.py | 16 ++++++++++++++-- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py b/sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py index 6b6d1da630..6f8d0b4e83 100644 --- a/sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py +++ b/sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py @@ -26,7 +26,7 @@ from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig from megatron.core.transformer.spec_utils import ModuleSpec from nemo.collections import llm -from nemo.collections.llm.gpt.model.base import GPTModel, mtp_block_spec +from nemo.collections.llm.gpt.model.base import GPTModel, get_packed_seq_params, mtp_block_spec from nemo.collections.llm.gpt.model.llama import Llama3Config, apply_rope_scaling from nemo.collections.llm.gpt.model.megatron.hyena.hyena_utils import make_upper_case, reweighted_cross_entropy from nemo.lightning import get_vocab_size @@ -66,7 +66,17 @@ def evo2_gpt_forward_step(model, batch) -> torch.Tensor: "labels": batch["labels"], "loss_mask": batch["loss_mask"], } - forward_args["attention_mask"] = None + if "attention_mask" not in batch: + assert HAVE_TE, ( + "The dataloader did not provide an attention mask, however Transformer Engine was not detected. \ + This requires Transformer Engine's implementation of fused or flash attention." + ) + else: + forward_args["attention_mask"] = batch["attention_mask"] + + if "cu_seqlens" in batch: + forward_args["packed_seq_params"] = get_packed_seq_params(batch) + return model(**forward_args) diff --git a/sub-packages/bionemo-evo2/src/bionemo/evo2/models/mamba.py b/sub-packages/bionemo-evo2/src/bionemo/evo2/models/mamba.py index 13cf721754..5299a10db5 100644 --- a/sub-packages/bionemo-evo2/src/bionemo/evo2/models/mamba.py +++ b/sub-packages/bionemo-evo2/src/bionemo/evo2/models/mamba.py @@ -25,17 +25,19 @@ from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.utils import WrappedTensor, deprecate_inference_params -from nemo.collections.llm.gpt.model.base import GPTModel, gpt_data_step +from nemo.collections.llm.gpt.model.base import GPTModel, get_packed_seq_params, gpt_data_step from nemo.collections.llm.gpt.model.megatron.hyena.hyena_utils import make_upper_case, reweighted_cross_entropy from nemo.collections.llm.gpt.model.ssm import ( NemotronHConfigBase, ) from nemo.lightning import get_vocab_size +from nemo.utils.import_utils import safe_import from typing_extensions import override from bionemo.evo2.utils.loss.embedding_variance import SquaredErrorTargetedVarianceLoss +_, HAVE_TE = safe_import("transformer_engine") logger = logging.getLogger(__name__) @@ -55,7 +57,17 @@ def mamba_forward_step(model, batch) -> torch.Tensor: "labels": batch["labels"], "loss_mask": batch["loss_mask"], } - forward_args["attention_mask"] = None + if "attention_mask" not in batch: + assert HAVE_TE, ( + "The dataloader did not provide an attention mask, however Transformer Engine was not detected. \ + This requires Transformer Engine's implementation of fused or flash attention." + ) + else: + forward_args["attention_mask"] = batch["attention_mask"] + + if "cu_seqlens" in batch: + forward_args["packed_seq_params"] = get_packed_seq_params(batch) + return model(**forward_args) From bebfe84e39d0ad52b414438369331b519d870e33 Mon Sep 17 00:00:00 2001 From: John St John Date: Thu, 28 Aug 2025 19:59:11 +0000 Subject: [PATCH 5/8] Add ability to change the rope base to train Signed-off-by: John St John --- sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py b/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py index 6f0cbafa8b..3524821645 100644 --- a/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py +++ b/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py @@ -488,6 +488,11 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace: type=int, help="Old context length for the GPT model. This is used to set the old context length for the GPT model when you supply a ckpt_dir.", ) + parser.add_argument( + "--rope-base", + type=int, + help="RoPE base override. If set, will use this value for the RoPE base instead of the default value", + ) parser.add_argument( "--metric-to-monitor-for-checkpoints", type=str, @@ -604,6 +609,8 @@ def train(args: argparse.Namespace) -> nl.Trainer: "fp32_residual_connection": not args.no_fp32_residual_connection, **activation_checkpointing_args, } + if args.rope_base: + config_modifiers_init["rotary_base"] = args.rope_base if args.add_bias_output: config_modifiers_init["add_bias_output"] = args.add_bias_output if args.embedding_init_std is not None or args.spike_no_more_embedding_init: From bde26b97ac705b627057aab8add18b0ab8b32823 Mon Sep 17 00:00:00 2001 From: John St John Date: Thu, 28 Aug 2025 20:24:56 +0000 Subject: [PATCH 6/8] Add ability to also override the scale factor Signed-off-by: John St John --- .../bionemo-evo2/src/bionemo/evo2/run/train.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py b/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py index 3524821645..07aee7d0a6 100644 --- a/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py +++ b/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py @@ -486,7 +486,14 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace: parser.add_argument( "--old-context-len", type=int, - help="Old context length for the GPT model. This is used to set the old context length for the GPT model when you supply a ckpt_dir.", + help="Old context length for the GPT model. This is used to set the old context length for the GPT model " + "when you supply a ckpt_dir.", + ) + parser.add_argument( + "--scale-factor", + type=float, + help="Scale factor override. If set, will use this value for the scale factor instead of the default value " + "which is computed as the ratio between the old context length and the new seq_length.", ) parser.add_argument( "--rope-base", @@ -694,7 +701,12 @@ def train(args: argparse.Namespace) -> nl.Trainer: old_context_len = args.old_context_len config_modifiers_init["old_context_len"] = old_context_len # Set scale factor to the ratio between the old context length and the new seq_length, or at least 1.0 - config_modifiers_init["scale_factor"] = args.seq_length / max(old_context_len, args.seq_length) + if args.scale_factor: + # Use the user supplied scale factor if they know what they are doing, otherwise just use the default + # which is a ratio between the old context length and the new seq_length. + config_modifiers_init["scale_factor"] = args.scale_factor + else: + config_modifiers_init["scale_factor"] = args.seq_length / max(old_context_len, args.seq_length) model_config = GPT_MODEL_OPTIONS[args.model_size](**config_modifiers_init) model = Evo2GPTModel(model_config, tokenizer=data_module.tokenizer) From 8d22da636cd182ec700c0663f0f4fce0cafb113c Mon Sep 17 00:00:00 2001 From: Yang Zhang Date: Wed, 3 Sep 2025 09:13:12 -0700 Subject: [PATCH 7/8] add gpt to predict.py and infer.py Signed-off-by: Yang Zhang --- .../src/bionemo/evo2/models/gpt.py | 18 +++++------- .../src/bionemo/evo2/run/infer.py | 6 ++-- .../src/bionemo/evo2/run/predict.py | 29 +++++++++++++++---- 3 files changed, 34 insertions(+), 19 deletions(-) diff --git a/sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py b/sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py index 6f8d0b4e83..c810799a82 100644 --- a/sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py +++ b/sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py @@ -91,15 +91,13 @@ def get_inference_wrapper( self, params_dtype, inference_batch_times_seqlen_threshold, inference_max_seq_length=8192 ) -> GPTInferenceWrapper: """Gets the inference wrapper for the Mamba model.""" - # Find MCoreMambaModel instance - mcore_model = self.module - while mcore_model: - if isinstance(mcore_model, ()): + model = self + while model is not None: + if getattr(model, "module", None) is not None: + model = model.module + else: break - mcore_model = getattr(mcore_model, "module", None) - if mcore_model is None or not isinstance( - mcore_model, (megatron.core.models.gpt.gpt_model.GPTModel, Evo2StyleMCoreGPTModel) - ): + if not isinstance(model, megatron.core.models.gpt.gpt_model.GPTModel): raise ValueError("GPT model instance not found in the model structure.") vocab_size = None @@ -111,14 +109,14 @@ def get_inference_wrapper( raise ValueError("Unable to find vocab size.") inference_wrapper_config = InferenceWrapperConfig( - hidden_size=mcore_model.config.hidden_size, + hidden_size=model.config.hidden_size, params_dtype=params_dtype, inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold, padded_vocab_size=vocab_size, inference_max_seq_length=inference_max_seq_length, ) - model_inference_wrapper = GPTInferenceWrapper(mcore_model, inference_wrapper_config) + model_inference_wrapper = GPTInferenceWrapper(model, inference_wrapper_config) return model_inference_wrapper @override diff --git a/sub-packages/bionemo-evo2/src/bionemo/evo2/run/infer.py b/sub-packages/bionemo-evo2/src/bionemo/evo2/run/infer.py index 211e7768b4..16f1e6e48e 100644 --- a/sub-packages/bionemo-evo2/src/bionemo/evo2/run/infer.py +++ b/sub-packages/bionemo-evo2/src/bionemo/evo2/run/infer.py @@ -86,14 +86,12 @@ def parse_args(): ) ap.add_argument( "--fp8", - type=bool, action="store_true", default=False, help="Whether to use vortex style FP8. Defaults to False.", ) ap.add_argument( "--flash-decode", - type=bool, action="store_true", default=False, help="Whether to use flash decode. Defaults to True.", @@ -173,8 +171,8 @@ def infer( path=ckpt_dir, trainer=trainer, params_dtype=torch.bfloat16, - inference_batch_times_seqlen_threshold=8192, # TODO - inference_max_seq_length=8192, # TODO + inference_batch_times_seqlen_threshold=len(prompt) + max_new_tokens, # TODO + inference_max_seq_length=len(prompt) + max_new_tokens, # TODO recompute_granularity=None, recompute_num_layers=None, recompute_method=None, diff --git a/sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py b/sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py index 951c242a20..32ae78e8eb 100644 --- a/sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py +++ b/sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py @@ -37,6 +37,7 @@ from torch import Tensor from bionemo.evo2.data.fasta_dataset import SimpleFastaDataset +from bionemo.evo2.models.gpt import GPT_MODEL_OPTIONS # Add import for Mamba models from bionemo.evo2.models.mamba import MAMBA_MODEL_OPTIONS, MambaModel @@ -73,15 +74,17 @@ def parse_args(): ap.add_argument( "--model-type", type=str, - choices=["hyena", "mamba"], + choices=["hyena", "mamba", "gpt"], default="hyena", - help="Model architecture family to use. Choose between 'hyena' and 'mamba'.", + help="Model architecture family to use. Choose between 'hyena', 'mamba', and 'gpt'.", ) ap.add_argument( "--model-size", type=str, default="7b", - choices=sorted(list(HYENA_MODEL_OPTIONS.keys()) + list(MAMBA_MODEL_OPTIONS.keys())), + choices=sorted( + list(HYENA_MODEL_OPTIONS.keys()) + list(MAMBA_MODEL_OPTIONS.keys()) + list(GPT_MODEL_OPTIONS.keys()) + ), help="Model size to use. Defaults to '7b'.", ) # output args: @@ -416,7 +419,7 @@ def predict( vortex_style_fp8=fp8 and not full_fp8, **config_modifiers_init, ) - else: # mamba + elif model_type == "mamba": # mamba if model_size not in MAMBA_MODEL_OPTIONS: raise ValueError(f"Invalid model size for Mamba: {model_size}") config = MAMBA_MODEL_OPTIONS[model_size]( @@ -425,6 +428,15 @@ def predict( distribute_saved_activations=False if sequence_parallel and tensor_parallel_size > 1 else True, **config_modifiers_init, ) + elif model_type == "gpt": + if model_size not in GPT_MODEL_OPTIONS: + raise ValueError(f"Invalid model size for GPT: {model_size}") + config = GPT_MODEL_OPTIONS[model_size]( + forward_step_fn=hyena_predict_forward_step, + data_step_fn=hyena_predict_data_step, + ) + else: + raise ValueError(f"Invalid model type: {model_type}") trainer.strategy._setup_optimizers = False @@ -451,13 +463,20 @@ def predict( output_log_prob_seqs=output_log_prob_seqs, log_prob_collapse_option=log_prob_collapse_option, ) - else: # mamba + elif model_type == "mamba": # mamba model = MambaPredictor( config, tokenizer=tokenizer, output_log_prob_seqs=output_log_prob_seqs, log_prob_collapse_option=log_prob_collapse_option, ) + elif model_type == "gpt": + model = HyenaPredictor( + config, + tokenizer=tokenizer, + output_log_prob_seqs=output_log_prob_seqs, + log_prob_collapse_option=log_prob_collapse_option, + ) resume.setup(trainer, model) # this pulls weights from the starting checkpoint. From 47f3022a0d73fc6c7093eeeda3fc24f5adc11aba Mon Sep 17 00:00:00 2001 From: Yang Zhang Date: Wed, 3 Sep 2025 10:56:12 -0700 Subject: [PATCH 8/8] fix docs Signed-off-by: Yang Zhang --- .../src/bionemo/evo2/models/gpt.py | 27 +++++++------------ 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py b/sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py index c810799a82..61cc04381c 100644 --- a/sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py +++ b/sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py @@ -81,16 +81,13 @@ def evo2_gpt_forward_step(model, batch) -> torch.Tensor: class Evo2GPTModel(GPTModel): - """Mamba model that extends GPTModel for integration with NeMo. - - Note that the loss calculation is handled by CustomMCoreMambaModel instead. - """ + """GPT model that extends GPTModel for integration with NeMo.""" @override def get_inference_wrapper( self, params_dtype, inference_batch_times_seqlen_threshold, inference_max_seq_length=8192 ) -> GPTInferenceWrapper: - """Gets the inference wrapper for the Mamba model.""" + """Gets the inference wrapper for the GPT model.""" model = self while model is not None: if getattr(model, "module", None) is not None: @@ -133,7 +130,7 @@ def forward( runtime_gather_output: bool | None = None, loss_mask: torch.Tensor | None = None, ) -> torch.Tensor: - """Forward pass that delegates to CustomMCoreMambaModel, which handles loss calculation.""" + """Forward pass that delegates to GPTModel, which handles loss calculation.""" extra_kwargs = {"packed_seq_params": packed_seq_params} if packed_seq_params is not None else {} output_tensor = self.module( input_ids, @@ -147,21 +144,17 @@ def forward( loss_mask=loss_mask, # Pass loss_mask to the Megatron module **extra_kwargs, ) - - # Return whatever CustomMCoreMambaModel.forward returns - # (logits during inference, loss during training) return output_tensor -# Custom MCoreMambaModel with reweighted loss calculation class Evo2StyleMCoreGPTModel(megatron.core.models.gpt.gpt_model.GPTModel): - """Custom version of MCoreMambaModel that implements reweighted loss calculation. + """Custom version of GPTModel that implements reweighted loss calculation. Note that this is similar to the HyenaModel for uppercase/lowercase handling. """ def __init__(self, *args, **kwargs): - """Initializes `Evo2StyleMCoreMambaModel` with unique parameters for the Evo2 variant of `MCoreMambaModel`.""" + """Initializes `Evo2StyleMCoreGPTModel` with unique parameters for the Evo2 variant of `GPTModel`.""" super().__init__(*args, **kwargs) if self.config.use_targeted_variance_loss: if not hasattr(self.config, "embedding_init_method_std"): @@ -205,9 +198,9 @@ def forward(self, *args, labels: torch.Tensor | None = None, loss_mask: torch.Te def gpt_no_weight_decay_cond(name, param, exclude_embeddings: bool = False): - """Condition for no weight decay for Mamba parameters. + """Condition for no weight decay for GPT parameters. - Note that this follows the same pattern as in the original Mamba implementation. + Note that this follows the same pattern as in the original GPT implementation. """ # Mamba-specific parameters that should not have weight decay if ("embedding" in name and exclude_embeddings) or getattr(param, "_no_weight_decay", False): @@ -222,16 +215,16 @@ def gpt_no_weight_decay_cond(name, param, exclude_embeddings: bool = False): def gpt_no_weight_decay_cond_with_embeddings(name, param): - """Condition for no weight decay for Mamba parameters with embeddings. + """Condition for no weight decay for GPT parameters with embeddings. - Note that this follows the same pattern as in the original Mamba implementation but also skips WD on embeddings. + Note that this follows the same pattern as in the original GPT implementation but also skips WD on embeddings. """ return gpt_no_weight_decay_cond(name, param, exclude_embeddings=True) @dataclass class LLama31ConfigEvoLoss3B(llm.Llama3Config8B): - """Config for 8B hybrid Mamba model.""" + """Config for 8B hybrid GPT model.""" # RoPE/context length related block: rotary_base: int = 500_000