Evo2 gpt inference #1106
Conversation
Signed-off-by: John St John <[email protected]>
Signed-off-by: John St John <[email protected]>
Signed-off-by: John St John <[email protected]>
Signed-off-by: John St John <[email protected]>
Signed-off-by: John St John <[email protected]>
Signed-off-by: John St John <[email protected]>
Signed-off-by: John St John <[email protected]>
Signed-off-by: Yang Zhang <[email protected]>
Walkthrough

Adds an Evo2 GPT integration with new model/config classes and forward helpers; updates Mamba forward for Transformer Engine and packed sequences; wires GPT support into predict/train/infer CLIs and introduces MambaPredictor and dynamic inference sizing; adds RoPE and embedding configuration knobs and weight-decay helpers.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
autonumber
actor User
participant CLI as train.py CLI
participant Builder as Config Builder
participant Options as GPT_MODEL_OPTIONS
participant Model as Evo2GPTModel
participant Core as Evo2StyleMCoreGPTModel
User->>CLI: run train --model-size <size> [GPT/RoPE/emb args]
CLI->>Builder: parse args, detect model_type
alt model_size in GPT options
Builder->>Options: resolve config class
Options-->>Builder: LLama31ConfigEvoLoss3B
Builder->>Builder: compute old_context_len, scale_factor, rotary_base, embedding_init_method_std
Builder->>Model: instantiate with tokenizer and config
Model->>Core: configure_model(...)
Core-->>Model: constructed core GPT model
else other model types
Builder->>Builder: follow existing paths
end
```

```mermaid
sequenceDiagram
autonumber
participant Pred as predict.py
participant Opts as GPT_MODEL_OPTIONS
participant MPred as MambaPredictor
participant HPred as HyenaPredictor
participant Run as infer.py
Pred->>Pred: parse args model_type
alt model_type == "gpt"
Pred->>Opts: validate model_size
Pred->>HPred: init with GPT-configured settings
else model_type == "mamba"
Pred->>MPred: init MambaPredictor
else invalid
Pred-->>Pred: raise error
end
Run->>Run: compute thresholds = len(prompt) + max_new_tokens
Run-->>Pred: inference executes with dynamic thresholds
```

```mermaid
sequenceDiagram
autonumber
participant Step as evo2_gpt_forward_step
participant Batch as batch
participant TE as TransformerEngine check
participant GPT as Evo2GPTModel/Evo2StyleMCoreGPTModel
Step->>Batch: read tokens, position_ids, labels, loss_mask
alt attention_mask in batch
Step->>Step: use batch.attention_mask
else
Step->>TE: assert Transformer Engine available
end
opt cu_seqlens in batch
Step->>Step: packed_seq_params = get_packed_seq_params(batch)
end
Step->>GPT: forward(input_ids, position_ids, [attention_mask], labels, loss_mask, [packed_seq_params])
GPT-->>Step: logits or loss
```

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Actionable comments posted: 9
🧹 Nitpick comments (6)
sub-packages/bionemo-evo2/src/bionemo/evo2/models/mamba.py (1)
60-66: Fix grammatical error in assertion message.

There's a grammatical error in the assertion message.

```diff
-    assert HAVE_TE, (
-        "The dataloader did not provide an attention mask, however Transformer Engine was not detected. \
-        This requires Transformer Engine's implementation of fused or flash attention."
-    )
+    assert HAVE_TE, (
+        "The dataloader did not provide an attention mask; however, Transformer Engine was not detected. "
+        "This requires Transformer Engine's implementation of fused or flash attention."
+    )
```

sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py (2)
521-628: Consider extracting common prediction logic to reduce duplication.

The MambaPredictor class shares significant logic with HyenaPredictor (e.g., log probability calculations, tensor gathering). Consider extracting the common prediction functionality into a base class or utility functions to improve maintainability; a sketch follows below.

Would you like me to help refactor this to create a shared base class for common predictor functionality?
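A minimal sketch of what such a shared base could look like. The class and method names here are illustrative assumptions, not code from the PR; the actual refactor would move whichever log-probability and gathering helpers the two predictors really duplicate:

```python
import torch
import torch.nn.functional as F


class BaseLogProbPredictor:
    """Hypothetical shared base for HyenaPredictor and MambaPredictor."""

    def sequence_log_probs(
        self, logits: torch.Tensor, tokens: torch.Tensor, loss_mask: torch.Tensor
    ) -> torch.Tensor:
        # Per-token log-probabilities of the observed tokens, masked and summed per sequence.
        log_probs = F.log_softmax(logits.float(), dim=-1)
        token_log_probs = log_probs.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
        return (token_log_probs * loss_mask).sum(dim=-1)
```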
473-479: Document why HyenaPredictor is used for GPT models.

Using HyenaPredictor for GPT models might be confusing. Add a comment explaining why this reuse is appropriate, or consider creating a dedicated GPTPredictor class if there are GPT-specific requirements.

```diff
 elif model_type == "gpt":
+    # GPT models can reuse HyenaPredictor since they share the same prediction interface
+    # and forward step logic. The model configuration handles the architectural differences.
     model = HyenaPredictor(
         config,
         tokenizer=tokenizer,
         output_log_prob_seqs=output_log_prob_seqs,
         log_prob_collapse_option=log_prob_collapse_option,
     )
```

sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py (2)
693-706: Improve clarity of old context length determination.

The logic for determining old_context_len is complex, with multiple conditionals. Consider refactoring for clarity and ensure the warning message provides actionable guidance.

```diff
-    if args.ckpt_dir is None or args.old_context_len:
-        # Set the old context length based on the initial pre-training run seq_length
-        # when you supply a ckpt_dir, assume that we will use whatever that value was set to previously
-        # for rope extension.
-        old_context_len = args.old_context_len or args.seq_length  # set to the seq_length if not supplied
-    else:
-        if not args.old_context_len:
-            old_context_len = args.seq_length
-            logger.warning(
-                "No old context length supplied, using the seq_length as the old context length. "
-                "This is not recommended and if training at a different context length the RoPE scaling factors "
-                "will be incorrect. Please supply the old context length when fine-tuning especially if you are "
-                "extending the context length."
-            )
-        else:
-            old_context_len = args.old_context_len
+    # Determine old context length for RoPE scaling
+    if args.old_context_len:
+        old_context_len = args.old_context_len
+    elif args.ckpt_dir is None:
+        # New training run: use current seq_length as baseline
+        old_context_len = args.seq_length
+    else:
+        # Fine-tuning from checkpoint without explicit old_context_len
+        old_context_len = args.seq_length
+        logger.warning(
+            f"Fine-tuning from checkpoint without --old-context-len specified. "
+            f"Defaulting to current seq_length ({args.seq_length}). "
+            f"If the original model was trained with a different context length, "
+            f"please specify it using --old-context-len for correct RoPE scaling."
+        )
```
685-687: Add informative error for GPT without required flag.

The assertion provides a clear requirement but could be more helpful by suggesting the fix directly.

```diff
-    assert args.no_fp32_residual_connection, (
-        "GPT models do not support fp32 residual connection, please run with --no-fp32-residual-connection."
-    )
+    assert args.no_fp32_residual_connection, (
+        "GPT models require --no-fp32-residual-connection flag to be set. "
+        "Please add this flag to your command."
+    )
```

sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py (1)
380-381: Consider implementing TE full layer spec support.

The code raises an error for use_transformer_engine_full_layer_spec. If this is a planned feature, consider adding a TODO comment or tracking issue reference. Consider updating the error message to be more informative:

```diff
 if self.use_transformer_engine_full_layer_spec:
-    raise ValueError("use_transformer_engine_full_layer_spec is not supported in this config.")
+    raise NotImplementedError(
+        "use_transformer_engine_full_layer_spec is not yet supported for Evo2 GPT models. "
+        "Please use the standard transformer layer spec instead."
+    )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (5)
- sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py (1 hunks)
- sub-packages/bionemo-evo2/src/bionemo/evo2/models/mamba.py (2 hunks)
- sub-packages/bionemo-evo2/src/bionemo/evo2/run/infer.py (1 hunks)
- sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py (5 hunks)
- sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py (12 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py (3)
- sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/lightning.py (1): get_packed_seq_params (227-262)
- sub-packages/bionemo-evo2/src/bionemo/evo2/utils/loss/embedding_variance.py (1): SquaredErrorTargetedVarianceLoss (169-192)
- sub-packages/bionemo-evo2/src/bionemo/evo2/models/mamba.py (4): get_inference_wrapper (81-113), forward (116-146), forward (171-275), configure_model (368-395)

sub-packages/bionemo-evo2/src/bionemo/evo2/models/mamba.py (1)
- sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/lightning.py (1): get_packed_seq_params (227-262)

sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py (2)
- sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py (2): Evo2GPTModel (83-153), gpt_no_weight_decay_cond_with_embeddings (224-229)
- sub-packages/bionemo-evo2/src/bionemo/evo2/models/mamba.py (2): mamba_no_weight_decay_cond_with_embeddings (301-306), MambaModel (74-146)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-bionemo-image
- GitHub Check: Analyze (rust)
🔇 Additional comments (10)
sub-packages/bionemo-evo2/src/bionemo/evo2/run/infer.py (2)
88-98: LGTM! CLI improvements enhance usability.

The change from requiring explicit boolean values to using store_true makes the CLI more intuitive and consistent with standard argparse conventions.
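For readers less familiar with the pattern, this is the standard argparse idiom being described; the flag name below is illustrative rather than copied from infer.py:

```python
import argparse

parser = argparse.ArgumentParser()
# Before: the option required an explicit value, e.g. --output-log-prob-seqs True
# After: the flag's presence enables the behavior; its absence keeps the default of False
parser.add_argument(
    "--output-log-prob-seqs",
    action="store_true",
    help="Emit per-sequence log probabilities.",
)
args = parser.parse_args(["--output-log-prob-seqs"])
assert args.output_log_prob_seqs is True
```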
174-175: Verify that dynamic threshold calculations don't cause memory issues.

The dynamic calculation of inference_batch_times_seqlen_threshold and inference_max_seq_length based on len(prompt) + max_new_tokens is more flexible than the previous fixed value of 8192. However, this could potentially lead to very large memory allocations for long prompts or large max_new_tokens values.

Are there any safeguards in the underlying inference wrapper to prevent OOM errors when these thresholds are set too high? Consider adding a maximum cap or warning when the calculated value exceeds a reasonable threshold (e.g., 32768), along the lines of the sketch below.
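One possible shape for such a safeguard. The cap value and the helper name are illustrative assumptions, not part of the PR:

```python
import logging

logger = logging.getLogger(__name__)

MAX_INFERENCE_SEQ_LENGTH = 32_768  # illustrative cap; tune to available GPU memory


def compute_inference_seq_length(prompt: str, max_new_tokens: int) -> int:
    """Derive the dynamic inference sizing, clamping and warning when it exceeds the cap."""
    requested = len(prompt) + max_new_tokens
    if requested > MAX_INFERENCE_SEQ_LENGTH:
        logger.warning(
            "Requested inference length %d exceeds cap %d; clamping to avoid OOM.",
            requested,
            MAX_INFERENCE_SEQ_LENGTH,
        )
        return MAX_INFERENCE_SEQ_LENGTH
    return requested
```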
sub-packages/bionemo-evo2/src/bionemo/evo2/models/mamba.py (1)
68-69: LGTM! Proper handling of packed sequences.

The addition of packed sequence support through get_packed_seq_params is well-implemented and follows the established pattern from other parts of the codebase.

sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py (1)
630-641: LGTM! Well-structured embedding initialization logic.

The handling of embedding initialization parameters is comprehensive, with clear precedence rules and appropriate warnings for conflicting options.
sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py (6)
53-81: LGTM! Clean forward step implementation.

The forward step function properly handles batch processing with optional attention masks and packed sequences. The TE requirement check for missing attention masks is appropriate.
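For context, a stripped-down sketch of the forward-step pattern being praised here. This is an illustration of the described behavior, not the PR's exact code; get_packed_seq_params is the bionemo-llm helper referenced earlier in this review:

```python
from bionemo.llm.model.biobert.lightning import get_packed_seq_params


def forward_step_sketch(model, batch):
    # Required inputs always come from the batch.
    forward_args = {
        "input_ids": batch["tokens"],
        "position_ids": batch["position_ids"],
        "labels": batch.get("labels"),
        "loss_mask": batch.get("loss_mask"),
    }
    # attention_mask is optional: when absent, fused/flash attention (Transformer Engine) is assumed.
    if "attention_mask" in batch:
        forward_args["attention_mask"] = batch["attention_mask"]
    # Packed-sequence metadata is only forwarded when the dataloader provides cu_seqlens.
    if "cu_seqlens" in batch:
        forward_args["packed_seq_params"] = get_packed_seq_params(batch)
    return model(**forward_args)
```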
163-176: Good implementation of targeted variance loss initialization.

The conditional initialization of the targeted variance loss module, with a proper warning for missing config attributes, is well implemented. The use of the Spike No More paper's suggested value (1.0) as the default is appropriate.
177-205: Well-structured forward pass with proper loss handling.

The forward method correctly handles different execution paths (inference vs. training, pipeline parallelism) and properly applies the reweighted cross-entropy loss and optional variance loss only during training.
265-279: Good backward compatibility handling in post-init.

The post-init method properly handles backward compatibility for the embedding_init_method_std attribute, with clear warning messages and sensible defaults.
280-391: Comprehensive model configuration with proper RoPE scaling.

The configure_model method is well-implemented with:
- Proper CUDA graph validation
- Virtual pipeline parallelism handling
- Flexible transformer layer spec resolution
- Correct RoPE scaling application for Llama3.1 (see the sketch after this list)
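As background, the Llama 3.1-style frequency rescaling driven by old_context_len and scale_factor works roughly as follows. This is a generic sketch of the published Llama 3.1 recipe, not the PR's code, and the low/high frequency factors shown are the commonly used defaults assumed here:

```python
import math

import torch


def llama31_scale_inv_freqs(
    inv_freqs: torch.Tensor,
    scale_factor: float = 8.0,
    low_freq_factor: float = 1.0,
    high_freq_factor: float = 4.0,
    old_context_len: int = 8192,
) -> torch.Tensor:
    """Rescale RoPE inverse frequencies for long-context fine-tuning (Llama 3.1 style)."""
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    scaled = []
    for inv_freq in inv_freqs.tolist():
        wavelen = 2 * math.pi / inv_freq
        if wavelen < high_freq_wavelen:
            scaled.append(inv_freq)  # high-frequency components are left untouched
        elif wavelen > low_freq_wavelen:
            scaled.append(inv_freq / scale_factor)  # low-frequency components are fully rescaled
        else:
            # Smooth interpolation between the two regimes.
            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
            scaled.append((1 - smooth) * inv_freq / scale_factor + smooth * inv_freq)
    return torch.tensor(scaled, dtype=inv_freqs.dtype)
```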
393-397: Model registry properly configured.

The GPT_MODEL_OPTIONS dictionary correctly registers the Llama3 8B configuration for use in the model selection system.
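As a usage illustration of that registry (a hedged sketch: the import path is inferred from the file layout above, and the helper name is hypothetical):

```python
from bionemo.evo2.models.gpt import GPT_MODEL_OPTIONS


def resolve_gpt_config_class(model_size: str):
    """Look up the registered config class for a CLI-provided model size."""
    try:
        return GPT_MODEL_OPTIONS[model_size]
    except KeyError as err:
        raise ValueError(
            f"Invalid model size for GPT: {model_size}. Known sizes: {sorted(GPT_MODEL_OPTIONS)}"
        ) from err


config_cls = resolve_gpt_config_class("llama3_8b")  # resolves to the registered Llama3 8B config
```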
```python
hyena_no_weight_decay_cond_fn: Callable = gpt_no_weight_decay_cond  # TODO rename to something more general
forward_step_fn: Callable = evo2_gpt_forward_step
```
🛠️ Refactor suggestion
Rename the misleadingly named configuration field.
The field hyena_no_weight_decay_cond_fn should be renamed to something more general as indicated by the TODO comment, since this is not specific to Hyena models.
Consider renaming this field to no_weight_decay_cond_fn or weight_decay_cond_fn for clarity:
```diff
-hyena_no_weight_decay_cond_fn: Callable = gpt_no_weight_decay_cond  # TODO rename to something more general
+no_weight_decay_cond_fn: Callable = gpt_no_weight_decay_cond
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
-hyena_no_weight_decay_cond_fn: Callable = gpt_no_weight_decay_cond  # TODO rename to something more general
-forward_step_fn: Callable = evo2_gpt_forward_step
+no_weight_decay_cond_fn: Callable = gpt_no_weight_decay_cond
+forward_step_fn: Callable = evo2_gpt_forward_step
```
🤖 Prompt for AI Agents
In sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py around lines
252-253, the config field named hyena_no_weight_decay_cond_fn is misleadingly
specific to Hyena per the TODO; rename it to a more general name such as
no_weight_decay_cond_fn (or weight_decay_cond_fn) and update all references and
type annotations accordingly to preserve the same Callable assignment
(gpt_no_weight_decay_cond) and behavior; ensure imports, dataclass/constructor
defaults, serialization keys, and any external usages (attribute access, tests)
are updated to the new name to avoid breaking runtime code.
```python
elif model_type == "gpt":
    if model_size not in GPT_MODEL_OPTIONS:
        raise ValueError(f"Invalid model size for GPT: {model_size}")
    config = GPT_MODEL_OPTIONS[model_size](
        forward_step_fn=hyena_predict_forward_step,
        data_step_fn=hyena_predict_data_step,
    )
else:
    raise ValueError(f"Invalid model type: {model_type}")
```
💡 Verification agent
🧩 Analysis chain
Verify GPT model reuses Hyena's forward and data step functions correctly.
The GPT model configuration reuses hyena_predict_forward_step and hyena_predict_data_step. While code reuse is good, ensure that these functions are fully compatible with GPT models, particularly regarding attention masks and packed sequences.
🏁 Script executed:
```bash
#!/bin/bash
# Verify that GPT models handle the same forward arguments as Hyena models
ast-grep --pattern 'class $GPT_CLASS {
  $$$
  def forward($$$) {
    $$$
  }
  $$$
}'
# Check if there are any GPT-specific forward step requirements
rg -n "gpt.*forward.*step" --type=py
```

Length of output: 359
🏁 Script executed:
```bash
#!/bin/bash
# Locate hyena_predict_forward_step definition and signature
rg -n -C3 "def hyena_predict_forward_step" -type py
# Locate hyena_predict_data_step definition and signature
rg -n -C3 "def hyena_predict_data_step" -type py
# Preview the GPT forward step implementation for comparison
sed -n '1,100p' sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py
```

Length of output: 4204
🏁 Script executed:
```bash
#!/bin/bash
# Locate hyena_predict_forward_step definition and signature
rg -n -C3 -t py 'def hyena_predict_forward_step' sub-packages/bionemo-evo2/src
# Locate hyena_predict_data_step definition and signature
rg -n -C3 -t py 'def hyena_predict_data_step' sub-packages/bionemo-evo2/src
```

Length of output: 1438
Use evo2_gpt_forward_step instead of hyena_predict_forward_step in the GPT config
Hyena’s forward_step unconditionally sets attention_mask=None, ignoring batch masks (docs.nvidia.com).
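A sketch of the suggested wiring in predict.py. The helper function here is hypothetical, and the import locations are assumptions based on the file paths discussed in this review:

```python
from bionemo.evo2.models.gpt import GPT_MODEL_OPTIONS, evo2_gpt_forward_step
from bionemo.evo2.run.predict import hyena_predict_data_step  # assumed location of the shared data step


def build_gpt_predict_config(model_size: str):
    """Resolve the GPT predict config, wiring in the GPT-aware forward step."""
    if model_size not in GPT_MODEL_OPTIONS:
        raise ValueError(f"Invalid model size for GPT: {model_size}")
    return GPT_MODEL_OPTIONS[model_size](
        forward_step_fn=evo2_gpt_forward_step,  # honors attention_mask and packed sequences from the batch
        data_step_fn=hyena_predict_data_step,  # the data step can remain shared with Hyena
    )
```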
🤖 Prompt for AI Agents
In sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py around lines 431 to
439, the GPT config currently uses hyena_predict_forward_step which
unconditionally clears attention_mask; replace hyena_predict_forward_step with
evo2_gpt_forward_step when constructing the GPT config (keep data_step_fn as
hyena_predict_data_step unless otherwise required) and ensure
evo2_gpt_forward_step is imported at the top of the file or referenced correctly
so the GPT model respects batch masks.
Signed-off-by: Yang Zhang <[email protected]>
@yzhang123 Docs are pointing towards Evo2 mamba when it is Evo2 GPT. What does Evo2GPT mean? Is it GPT with Evo2-type data input? Can we have some description?
Actionable comments posted: 1
♻️ Duplicate comments (1)
sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py (1)
244-247: Rename hyena_no_weight_decay_cond_fn to a model-agnostic name.

Field name is misleading per the TODO; make it generic.

```diff
-hyena_no_weight_decay_cond_fn: Callable = gpt_no_weight_decay_cond  # TODO rename to something more general
+no_weight_decay_cond_fn: Callable = gpt_no_weight_decay_cond
```

Remember to update downstream attribute accesses.
🧹 Nitpick comments (6)
sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py (6)
53-62: Docstring still says "Mamba"; update to GPT.

This helper is for GPT, not Mamba.

```diff
-def evo2_gpt_forward_step(model, batch) -> torch.Tensor:
-    """Forward step function for Mamba models, similar to hyena_forward_step.
+def evo2_gpt_forward_step(model, batch) -> torch.Tensor:
+    """Forward step function for GPT models, similar to hyena_forward_step.
@@
-        model: The Mamba model
+        model: The GPT model
```
69-76: Attention mask/TE requirement: verify dataloaders or relax hard assert.

Hard-failing when TE isn't present can break non-TE environments. Either ensure all dataloaders always provide attention_mask without TE, or gate with a clearer runtime error + fallback.
Proposed tweak:
- if "attention_mask" not in batch: - assert HAVE_TE, ( - "The dataloader did not provide an attention mask, however Transformer Engine was not detected. \ - This requires Transformer Engine's implementation of fused or flash attention." - ) + if "attention_mask" not in batch: + if not HAVE_TE: + raise RuntimeError( + "attention_mask missing and Transformer Engine not available. " + "Provide attention_mask via dataloader or enable TE (fused/flash attention)." + )
200-214: Comment still references "Mamba-specific parameters".

This isn't Mamba-only.

```diff
-    # Mamba-specific parameters that should not have weight decay
+    # Parameters that should not have weight decay
```
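For context, a no-weight-decay condition function of the kind this comment lives in typically looks like the sketch below. The function name and parameter-name patterns are illustrative assumptions, not the PR's actual exclusion list:

```python
import torch


def no_weight_decay_cond_sketch(name: str, param: torch.nn.Parameter) -> bool:
    """Return True if this parameter should be excluded from weight decay."""
    # Biases and normalization weights are conventionally exempt from decay;
    # a "..._with_embeddings" variant would additionally exempt embedding tables.
    if name.endswith(".bias"):
        return True
    if "norm" in name.lower():
        return True
    return param.ndim < 2  # scalars/vectors (e.g., gains) generally skip decay
```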
225-233: Class name vs size mismatch (3B vs 8B).

Class is named LLama31ConfigEvoLoss3B but inherits Llama3Config8B and uses 8B shapes; docstring also says 8B. Rename for clarity and consistency.

```diff
-@dataclass
-class LLama31ConfigEvoLoss3B(llm.Llama3Config8B):
-    """Config for 8B hybrid GPT model."""
+@dataclass
+class Llama31ConfigEvoLoss8B(llm.Llama3Config8B):
+    """Config for Llama 3.1 8B GPT model with Evo2-style loss."""
```

Update references:
```diff
-GPT_MODEL_OPTIONS: dict[str, type[Llama3Config]] = {
-    "llama3_8b": LLama31ConfigEvoLoss3B,
-}
+GPT_MODEL_OPTIONS: dict[str, type[Llama3Config]] = {
+    "llama3_8b": Llama31ConfigEvoLoss8B,
+}
```
371-383: Full TE layer not supported path: consider a clearer error message.

Raise includes no remediation hints; add guidance (e.g., set use_transformer_engine_full_layer_spec=False) to shorten iteration time.

```diff
-    raise ValueError("use_transformer_engine_full_layer_spec is not supported in this config.")
+    raise ValueError(
+        "use_transformer_engine_full_layer_spec is not supported in this config. "
+        "Set use_transformer_engine_full_layer_spec=False or select a compatible config."
+    )
```
50-50: Unused variable _grad_accum_fusion_available.

It's defined but not used; drop it or wire it into feature gating.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (1)
- sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py (3)
- sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/lightning.py (1): get_packed_seq_params (227-262)
- sub-packages/bionemo-evo2/src/bionemo/evo2/utils/loss/embedding_variance.py (1): SquaredErrorTargetedVarianceLoss (169-192)
- sub-packages/bionemo-evo2/src/bionemo/evo2/models/mamba.py (3): get_inference_wrapper (81-113), forward (116-146), forward (171-275)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Analyze (rust)
🔇 Additional comments (1)
sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py (1)
170-197: Move uppercase transform before loss computation and verify loss shape
- Apply make_upper_case (and capture lowercase_mask) before calling super().forward so the loss is computed against the transformed labels.
- Manually confirm that Evo2StyleMCoreGPTModel.forward returns an unreduced per-token loss tensor; if it returns a scalar, reweighted_cross_entropy will misbehave.

```diff
 @override
 def forward(self, *args, labels: torch.Tensor | None = None, loss_mask: torch.Tensor | None = None, **kwargs):
-    _forward_out = super().forward(*args, labels=labels, loss_mask=loss_mask, **kwargs)
+    lowercase_mask = None
+    if labels is not None:
+        labels, lowercase_mask = make_upper_case(labels)
+    _forward_out = super().forward(*args, labels=labels, loss_mask=loss_mask, **kwargs)
```
```python
model_init_device_context = contextlib.nullcontext
if self.init_model_with_meta_device:
    model_init_device_context = partial(torch.device, device="meta")

if "mtp_block_spec" in inspect.signature(Evo2StyleMCoreGPTModel.__init__).parameters:
    kwargs = {"mtp_block_spec": mtp_block_spec(self, vp_stage=vp_stage)}
else:
    kwargs = {}
with model_init_device_context():
```
💡 Verification agent
🧩 Analysis chain
Meta-device init context is incorrect; torch.device is not a context manager.
with model_init_device_context(): will fail when using partial(torch.device, device="meta"). Use torch.set_default_device("meta") which is a context manager in PyTorch 2.x.
```diff
-model_init_device_context = contextlib.nullcontext
-if self.init_model_with_meta_device:
-    model_init_device_context = partial(torch.device, device="meta")
+model_init_device_context = contextlib.nullcontext
+if self.init_model_with_meta_device:
+    # Set default device to 'meta' during construction.
+    model_init_device_context = partial(torch.set_default_device, "meta")
```

If your minimum PyTorch version doesn't support set_default_device, consider accelerate.init_empty_weights() or Megatron's meta-init utility instead.
🏁 Script executed:
```bash
#!/bin/bash
# Locate project requirement and configuration files
fd pyproject.toml setup.py requirements.txt environment.yml -t f
# Search for PyTorch version specifications
grep -R "torch" -n pyproject.toml setup.py requirements.txt environment.yml || true
```

Length of output: 549
Use torch.set_default_device instead of torch.device as context manager
In sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py (lines 338–346), replace:
```diff
-model_init_device_context = contextlib.nullcontext
-if self.init_model_with_meta_device:
-    model_init_device_context = partial(torch.device, device="meta")
+model_init_device_context = contextlib.nullcontext
+if self.init_model_with_meta_device:
+    model_init_device_context = partial(torch.set_default_device, "meta")
```

torch.device isn't a context manager; PyTorch 2.3.* supports torch.set_default_device.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
 model_init_device_context = contextlib.nullcontext
 if self.init_model_with_meta_device:
-    model_init_device_context = partial(torch.device, device="meta")
+    model_init_device_context = partial(torch.set_default_device, "meta")

 if "mtp_block_spec" in inspect.signature(Evo2StyleMCoreGPTModel.__init__).parameters:
     kwargs = {"mtp_block_spec": mtp_block_spec(self, vp_stage=vp_stage)}
 else:
     kwargs = {}
 with model_init_device_context():
```
Codecov Report

❌ Patch coverage is …

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main    #1106      +/-   ##
==========================================
- Coverage   80.62%   79.78%   -0.84%     
==========================================
  Files         157      158       +1     
  Lines       11079    11283     +204     
==========================================
+ Hits         8932     9002      +70     
- Misses       2147     2281     +134     
```
Description
make evo2 gpt model compatible with predict.py and infer.py
Type of changes
CI Pipeline Configuration
Configure CI behavior by applying the relevant labels:
Note
By default, the notebooks validation tests are skipped unless explicitly enabled.
Authorizing CI Runs
We use copy-pr-bot to manage authorization of CI
runs on NVIDIA's compute resources.
automatically be copied to a pull-request/ prefixed branch in the source repository (e.g. pull-request/123)
/ok to test comment on the pull request to trigger CI. This will need to be done for each new commit.

Usage
# TODO: Add code snippet

Pre-submit Checklist
Summary by CodeRabbit
New Features
Chores
Bug Fixes