
Conversation

MaulanaYusufIkhsanRobbani

This commit introduces a new training script and a shell script to perform LoRA-based knowledge distillation for Wan-Video models.

The new script `examples/wanvideo/model_training/train_distill.py` is a modified version of the standard training script. It is designed to load both a student and a teacher model. The student model is trained with a LoRA adapter, and the loss function is a combination of the standard diffusion loss and a distillation loss. The distillation loss is the mean squared error between the student's and the teacher's predictions.
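For illustration, here is a minimal sketch of such a combined objective. The function and the `distill_weight` name are assumptions made for this example, not arguments of `train_distill.py`:

```python
# Minimal sketch of the combined objective described above, assuming the student
# and the teacher predict the same target (e.g., noise or velocity).
import torch
import torch.nn.functional as F

def combined_loss(student_pred: torch.Tensor,
                  teacher_pred: torch.Tensor,
                  training_target: torch.Tensor,
                  distill_weight: float = 1.0) -> torch.Tensor:
    # Standard diffusion loss against the ground-truth training target.
    diffusion_loss = F.mse_loss(student_pred.float(), training_target.float())
    # Distillation loss: MSE between the student's and the (frozen) teacher's predictions.
    distill_loss = F.mse_loss(student_pred.float(), teacher_pred.detach().float())
    return diffusion_loss + distill_weight * distill_loss
```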

A new shell script `examples/wanvideo/model_training/lora/Wan-Distill.sh` is provided to launch the distillation process with the correct arguments.

@gemini-code-assist (bot) left a comment


Summary of Changes

Hello @MaulanaYusufIkhsanRobbani, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the Wan-Video model training capabilities by integrating LoRA-based knowledge distillation. The primary goal is to enable the creation of more compact and performant video generation models by distilling knowledge from larger, pre-trained teacher models into smaller student models, thereby improving efficiency without significant performance degradation.

Highlights

  • LoRA-based Knowledge Distillation: Introduces a new training script, train_distill.py, specifically designed for performing LoRA-based knowledge distillation on Wan-Video models. This allows training smaller, more efficient student models by leveraging the knowledge of larger teacher models (a generic sketch of attaching a LoRA adapter follows this list).
  • Combined Loss Function: The new training script utilizes a combined loss function for the student model. This loss is a weighted sum of the standard diffusion loss and a distillation loss, which is calculated as the mean squared error between the student's and the teacher's predictions.
  • Dedicated Launch Script: A new shell script, Wan-Distill.sh, has been added to simplify the execution of the LoRA-based knowledge distillation process. It pre-configures the necessary arguments for launching train_distill.py.
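As a rough illustration of the LoRA part, the sketch below attaches a LoRA adapter to a student DiT using peft. This is only an assumption for illustration: the actual script presumably injects LoRA through DiffSynth-Studio's own utilities, and the target module names here are placeholders.

```python
# Generic sketch of wrapping a student model with a LoRA adapter via peft.
# The target module names are placeholders, not taken from the Wan DiT.
from peft import LoraConfig, get_peft_model

def wrap_student_with_lora(student_dit, rank: int = 16, alpha: int = 16):
    lora_config = LoraConfig(
        r=rank,
        lora_alpha=alpha,
        target_modules=["q", "k", "v", "o"],  # placeholder attention projections
        lora_dropout=0.0,
    )
    # Only the LoRA matrices remain trainable; the base weights are frozen.
    return get_peft_model(student_dit, lora_config)
```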


@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces knowledge distillation for Wan-Video models using LoRA. The changes include a new training script and a shell script to run it. My review focuses on improving the correctness, robustness, and maintainability of the new training script.

I've identified a few critical issues that would prevent the training from running correctly on a GPU. Specifically, the trainable models are not correctly registered, and there's a device mismatch for the timestep tensor. I've also provided several medium-severity suggestions to improve code quality, such as removing duplicated code, using a proper logger, and making the script more flexible. Addressing these points will make the new training feature more robust and easier to maintain.

```python
if model_id_with_origin_paths is not None:
    model_id_with_origin_paths = model_id_with_origin_paths.split(",")
    model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)
```

critical

The models within self.pipe (e.g., dit, vae) are not registered as nn.Module attributes of WanTrainingModule. Consequently, they won't be moved to the correct device by accelerate, and their parameters won't be collected for training by trainable_modules(). This will likely result in training running on the CPU with no trainable parameters, or device-mismatch errors.

To fix this, you should register the necessary models from the pipeline as attributes of the WanTrainingModule right after creating the pipeline. This will ensure they are correctly handled by PyTorch and Accelerate.

Example fix:

```python
        self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)
        self.dit = self.pipe.dit
        self.vae = self.pipe.vae
        self.text_encoder = self.pipe.text_encoder
        self.image_encoder = self.pipe.image_encoder
        self.motion_controller = self.pipe.motion_controller
        self.vace = self.pipe.vace
        if hasattr(self.pipe, 'dit2') and self.pipe.dit2 is not None:
            self.dit2 = self.pipe.dit2
```

Comment on lines +125 to +126
```python
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=torch.bfloat16, device=self.pipe.device)
```

critical

The timestep tensor is being moved to self.pipe.device, which is hardcoded to 'cpu'. When training with accelerate on a GPU, this will cause a device mismatch error because other tensors in the operation are on the GPU. You should create the timestep_id tensor on the correct device from the start.

Note: This fix assumes that the trainable models are correctly registered as attributes of this module, as suggested in another comment. Otherwise, next(self.parameters()) will fail.

Suggested change:

```diff
-timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
-timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=torch.bfloat16, device=self.pipe.device)
+timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,), device=next(self.parameters()).device)
+timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=torch.bfloat16)
```

```diff
@@ -0,0 +1,13 @@
+accelerate launch examples/wanvideo/model_training/train_distill.py \
```

medium

For robustness, it's a good practice to add set -e at the beginning of your shell scripts. This will cause the script to exit immediately if a command exits with a non-zero status, preventing unexpected behavior from errors in the script.

Suggested change:

```diff
-accelerate launch examples/wanvideo/model_training/train_distill.py \
+set -e
+accelerate launch examples/wanvideo/model_training/train_distill.py \
```

Comment on lines +25 to +41
```python
model_configs = []
if model_paths is not None:
    model_paths = json.loads(model_paths)
    model_configs += [ModelConfig(path=path) for path in model_paths]
if model_id_with_origin_paths is not None:
    model_id_with_origin_paths = model_id_with_origin_paths.split(",")
    model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)

# Load teacher model
teacher_model_configs = []
if teacher_model_paths is not None:
    teacher_model_paths = json.loads(teacher_model_paths)
    teacher_model_configs += [ModelConfig(path=path) for path in teacher_model_paths]
if teacher_model_id_with_origin_paths is not None:
    teacher_model_id_with_origin_paths = teacher_model_id_with_origin_paths.split(",")
    teacher_model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in teacher_model_id_with_origin_paths]
```

medium

The logic for parsing model configurations is duplicated for the student and teacher models. To improve maintainability and reduce redundancy, you could extract this logic into a helper method. For example:

```python
def _load_model_configs(model_paths, model_id_with_origin_paths):
    model_configs = []
    if model_paths is not None:
        model_paths = json.loads(model_paths)
        model_configs.extend([ModelConfig(path=path) for path in model_paths])
    if model_id_with_origin_paths is not None:
        model_id_with_origin_paths = model_id_with_origin_paths.split(",")
        model_configs.extend([
            ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1])
            for i in model_id_with_origin_paths
        ])
    return model_configs
```

You can then call this helper for both student and teacher models.
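A possible call site, assuming the helper keeps the plain-function signature shown above (the variable names mirror the constructor arguments quoted earlier):

```python
# Hypothetical call site for the helper sketched above.
model_configs = _load_model_configs(model_paths, model_id_with_origin_paths)
teacher_model_configs = _load_model_configs(teacher_model_paths, teacher_model_id_with_origin_paths)
```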

```python
if teacher_model_id_with_origin_paths is not None:
    teacher_model_id_with_origin_paths = teacher_model_id_with_origin_paths.split(",")
    teacher_model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in teacher_model_id_with_origin_paths]
self.teacher_pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cuda", model_configs=teacher_model_configs)
```

medium

The device for the teacher pipeline is hardcoded to 'cuda'. This can be problematic in multi-GPU setups where accelerate might assign a different device (e.g., cuda:1) for training. It would be more flexible to make the teacher's device configurable via a command-line argument, for example --teacher_device.
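One possible shape for such a flag, sketched with plain argparse; the --teacher_device name is the reviewer's suggestion, not an existing option of the script:

```python
import argparse

# Sketch of the suggested --teacher_device flag (assumed name, not an existing
# option). Its value would replace the hardcoded "cuda" when the teacher
# pipeline is constructed.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--teacher_device",
    type=str,
    default="cuda",
    help="Device on which the frozen teacher pipeline is loaded.",
)
args, _ = parser.parse_known_args()

# In the training module, roughly:
#   self.teacher_pipe = WanVideoPipeline.from_pretrained(
#       torch_dtype=torch.bfloat16,
#       device=args.teacher_device,
#       model_configs=teacher_model_configs,
#   )
```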

Comment on lines +64 to +66
print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
if len(load_result[1]) > 0:
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")

medium

Using print() for logging is generally discouraged. It's better to use the logging module, which provides more flexibility, such as logging levels, formatting, and routing output to files or other handlers. This is especially useful for debugging and monitoring training processes. You should import logging at the top of the file.

Suggested change:

```diff
-print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
-if len(load_result[1]) > 0:
-    print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
+import logging
+logging.info(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
+if len(load_result[1]) > 0:
+    logging.warning(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
```

Comment on lines +85 to +86
```python
# Assume you are using this pipeline for inference,
# please fill in the input parameters.
```

medium

This comment seems to be copied from an inference script and is misleading in a training context. It should be removed to avoid confusion.

@Artiprocher (Collaborator)

@MaulanaYusufIkhsanRobbani Thank you for providing the code. Recently, we have been researching how to build a more general training framework that encapsulates SFT, distillation, and DPO into a unified module, making it orthogonal to the model architecture. Therefore, we will merge this functionality at a later time.
