29 changes: 29 additions & 0 deletions examples/diffusion/finetune/finetune.py
@@ -0,0 +1,29 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from nemo_automodel.components.config._arg_parser import parse_args_and_load_config
from nemo_automodel.recipes.diffusion.finetune import TrainWan21DiffusionRecipe


def main(default_config_path="examples/diffusion/finetune/wan2_1_t2v_flow.yaml"):
    cfg = parse_args_and_load_config(default_config_path)
    recipe = TrainWan21DiffusionRecipe(cfg)
    recipe.setup()
    recipe.run_train_validation_loop()


if __name__ == "__main__":
    main()
61 changes: 61 additions & 0 deletions examples/diffusion/finetune/wan2_1_t2v_flow.yaml
@@ -0,0 +1,61 @@
seed: 42

wandb:
  project: wan-t2v-flow-matching
  mode: online
  name: wan2_1_t2v_fm_updated

dist_env:
  backend: nccl
  timeout_minutes: 30

model:
  pretrained_model_name_or_path: Wan-AI/Wan2.1-T2V-1.3B-Diffusers

data:
  dataloader:
    _target_: nemo_automodel.components.datasets.diffusion.build_wan21_dataloader
    meta_folder: /lustre/fsw/portfolios/coreai/users/linnanw/hdvilla_sample/pika/wan21_codes/1.3B_meta/
    batch_size: 1
    num_workers: 2
    device: cpu
    num_nodes: 1

batch:
  batch_size_per_node: 8

training:
  num_epochs: 20

optim:
  learning_rate: 5e-6
  optimizer:
    weight_decay: 0.01
    betas: [0.9, 0.999]

flow_matching:
  use_sigma_noise: true
  timestep_sampling: uniform
  logit_mean: 0.0
  logit_std: 1.0
  flow_shift: 3.0
  mix_uniform_ratio: 0.1

fsdp:
  cpu_offload: true
  tp_size: 1
  cp_size: 1
  pp_size: 1

logging:
  save_every: 50
  log_every: 2

checkpoint:
  enabled: true
  checkpoint_dir: /opt/Automodel/wan_t2v_flow_outputs_base_recipe_checkpoint_NEW_new/
  model_save_format: torch_save
  save_consolidated: false
  restore_from: null
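
A rough sketch of how the flow_matching block above could feed the sampling and weighting helpers added later in this diff. This is not the recipe's actual wiring; it assumes the three functions from the new utilities file below are importable, and the batch size is arbitrary:

# Values mirroring the flow_matching section of wan2_1_t2v_flow.yaml.
timestep_sampling = "uniform"
flow_shift = 3.0

# Sample per-example timesteps in [0, 1) for an (assumed) batch of 4.
u = compute_density_for_timestep_sampling(
    weighting_scheme=timestep_sampling,
    batch_size=4,
    logit_mean=0.0,
    logit_std=1.0,
)

# Shift them into sigmas; image_seq_len is ignored in "constant" mode.
sigma = time_shift(u, image_seq_len=1024, shift_type="constant", constant=flow_shift)

# Per-example loss weights: noisier samples receive larger weights.
loss_weight = get_flow_match_loss_weight(sigma, shift=flow_shift)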


31 changes: 29 additions & 2 deletions nemo_automodel/_diffusers/auto_diffusion_pipeline.py
@@ -68,6 +68,23 @@ def _move_module_to_device(module: nn.Module, device: torch.device, torch_dtype:
module.to(device=device)


def _ensure_params_trainable(module: nn.Module, module_name: Optional[str] = None) -> int:
    """
    Ensure that all parameters in the given module are trainable.

    Returns the number of parameters marked trainable. If a module name is
    provided, it will be used in the log message for clarity.
    """
    num_trainable_parameters = 0
    for parameter in module.parameters():
        parameter.requires_grad = True
        num_trainable_parameters += parameter.numel()
    if module_name is None:
        module_name = module.__class__.__name__
    logger.info("[Trainable] %s: %s parameters set requires_grad=True", module_name, f"{num_trainable_parameters:,}")
    return num_trainable_parameters


class NeMoAutoDiffusionPipeline(DiffusionPipeline):
"""
Drop-in Diffusers pipeline that adds optional FSDP2/TP parallelization during from_pretrained.
@@ -90,6 +107,8 @@ def from_pretrained(
        device: Optional[torch.device] = None,
        torch_dtype: Any = "auto",
        move_to_device: bool = True,
        load_for_training: bool = False,
        components_to_load: Optional[Iterable[str]] = None,
        **kwargs,
    ) -> DiffusionPipeline:
        pipe: DiffusionPipeline = DiffusionPipeline.from_pretrained(
@@ -98,14 +117,22 @@
            torch_dtype=torch_dtype,
            **kwargs,
        )

        # Decide device
        dev = _choose_device(device)

        # Move modules to device/dtype first (helps avoid initial OOM during sharding)
        if move_to_device:
            for name, module in _iter_pipeline_modules(pipe):
                _move_module_to_device(module, dev, torch_dtype)
                if not components_to_load or name in components_to_load:
                    logger.info("[INFO] Moving module: %s to device/dtype", name)
                    _move_module_to_device(module, dev, torch_dtype)

        # If loading for training, ensure the target module parameters are trainable
        if load_for_training:
            for name, module in _iter_pipeline_modules(pipe):
                if not components_to_load or name in components_to_load:
                    logger.info("[INFO] Ensuring params trainable: %s", name)
                    _ensure_params_trainable(module, module_name=name)

        # Use per-component FSDP2Manager mappings to parallelize components
        if parallel_scheme is not None:
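
A minimal usage sketch of the two new from_pretrained arguments. The checkpoint id matches the YAML above and the import path matches this file; treating "transformer" as the denoiser component name of the Wan2.1 Diffusers pipeline is an assumption:

import torch

from nemo_automodel._diffusers.auto_diffusion_pipeline import NeMoAutoDiffusionPipeline

# All pipeline components are still loaded; components_to_load only controls which
# ones are moved to the target device/dtype and, with load_for_training=True,
# have their parameters flipped to requires_grad=True.
pipe = NeMoAutoDiffusionPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    torch_dtype=torch.bfloat16,
    load_for_training=True,
    components_to_load=["transformer"],
)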
@@ -0,0 +1,118 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import math

import numpy as np
import torch


def time_shift(
    t: torch.Tensor,
    image_seq_len: int,
    shift_type: str = "constant",
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    constant: float = 3.0,
):
    """
    Convert timesteps to sigmas with sequence-length-aware shifting.

    Args:
        t: timesteps in range [0, 1]
        image_seq_len: number of tokens (frames * height * width / patch_size^2)
        shift_type: "linear", "sqrt", or "constant"
        base_shift: base shift for linear mode
        max_shift: max shift for linear mode
        constant: shift value for constant mode (default 3.0 matches Pika)

    Returns:
        sigma values for noise scheduling
    """
    if shift_type == "linear":
        # Linear interpolation based on sequence length
        mu = base_shift + (max_shift - base_shift) * (image_seq_len / 4096)
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1))

    elif shift_type == "sqrt":
        # Square root scaling (Flux-style)
        # Assuming 128x128 latent space (1024x1024 image) gives mu=3
        mu = np.maximum(1.0, np.sqrt(image_seq_len / (128.0 * 128.0)) * 3.0)
        return mu / (mu + (1 / t - 1))

    elif shift_type == "constant":
        # Constant shift (Pika default)
        return constant / (constant + (1 / t - 1))

    else:
        # No shift, return original t
        return t
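
For intuition, a quick numeric check of the constant shift at constant=3.0, the flow_shift used in wan2_1_t2v_flow.yaml; the input timesteps here are illustrative:

import torch

t = torch.tensor([0.25, 0.5, 0.75])
# constant / (constant + (1/t - 1)) maps 0.25 -> 0.50, 0.5 -> 0.75, 0.75 -> 0.90,
# i.e. sigmas are pushed toward the noisier end of the schedule.
sigma = time_shift(t, image_seq_len=1024, shift_type="constant", constant=3.0)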


def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    logit_mean: float = 0.0,
    logit_std: float = 1.0,
    mode_scale: float = 1.29,
):
    """
    Sample timesteps from different distributions for better training coverage.

    Args:
        weighting_scheme: "uniform", "logit_normal", or "mode"
        batch_size: number of samples to generate
        logit_mean: mean for logit-normal distribution
        logit_std: std for logit-normal distribution
        mode_scale: scale for mode-based sampling

    Returns:
        Tensor of shape (batch_size,) with values in [0, 1]
    """
    if weighting_scheme == "logit_normal":
        # SD3-style logit-normal sampling
        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
        u = torch.nn.functional.sigmoid(u)

    elif weighting_scheme == "mode":
        # Mode-based sampling (concentrates around certain timesteps)
        u = torch.rand(size=(batch_size,), device="cpu")
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)

    else:
        # Uniform sampling (default)
        u = torch.rand(size=(batch_size,), device="cpu")

    return u
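
A short sketch contrasting the sampling schemes; the config above uses timestep_sampling: uniform, and logit_normal is shown for comparison (seed and batch size are illustrative):

import torch

torch.manual_seed(0)  # illustrative only
# Uniform coverage of [0, 1).
u_uniform = compute_density_for_timestep_sampling("uniform", batch_size=8)
# SD3-style logit-normal sampling, which concentrates mass around mid-range timesteps.
u_logit = compute_density_for_timestep_sampling("logit_normal", batch_size=8, logit_mean=0.0, logit_std=1.0)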


def get_flow_match_loss_weight(sigma: torch.Tensor, shift: float = 3.0):
    """
    Compute loss weights for flow matching based on sigma values.

    Higher sigma (more noise) typically gets higher weight.

    Args:
        sigma: sigma values in range [0, 1]
        shift: weight scaling factor

    Returns:
        Loss weights with same shape as sigma
    """
    # Flow matching weight: weight = 1 + shift * sigma
    # This gives more weight to noisier timesteps
    weight = 1.0 + shift * sigma
    return weight
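
A quick check of the resulting weights: with the default shift=3.0, a nearly clean sample (sigma=0.1) gets weight 1.3 while a very noisy one (sigma=0.9) gets 3.7; the sigma values are illustrative:

import torch

sigma = torch.tensor([0.1, 0.5, 0.9])
# 1 + 3.0 * sigma -> tensor([1.3000, 2.5000, 3.7000])
weights = get_flow_match_loss_weight(sigma, shift=3.0)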