feat: Add LoRA-based knowledge distillation for Wan-Video #922
Conversation
This commit introduces a new training script and a shell script to perform LoRA-based knowledge distillation for Wan-Video models. The new script `examples/wanvideo/model_training/train_distill.py` is a modified version of the standard training script. It is designed to load both a student and a teacher model. The student model is trained with a LoRA adapter, and the loss function is a combination of the standard diffusion loss and a distillation loss. The distillation loss is the mean squared error between the student's and the teacher's predictions. A new shell script `examples/wanvideo/model_training/lora/Wan-Distill.sh` is provided to launch the distillation process with the correct arguments.
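For readers skimming the diff, the combined objective described above amounts to something like the sketch below. This is illustrative only: names such as `combined_loss`, `training_target`, and `distill_loss_weight` are assumptions, not identifiers taken from `train_distill.py`.

```python
import torch.nn.functional as F

def combined_loss(student_pred, teacher_pred, training_target, distill_loss_weight=1.0):
    # Standard diffusion loss: student prediction vs. the training target.
    diffusion_loss = F.mse_loss(student_pred.float(), training_target.float())
    # Distillation loss: MSE between the student's and the frozen teacher's predictions.
    distill_loss = F.mse_loss(student_pred.float(), teacher_pred.detach().float())
    # Weighted sum; only the student's LoRA parameters receive gradients.
    return diffusion_loss + distill_loss_weight * distill_loss

# In a training step, the teacher's prediction would typically be computed under
# torch.no_grad() so that only the student's LoRA adapter is updated (hypothetical calls):
#     with torch.no_grad():
#         teacher_pred = teacher_model(noisy_input, timestep)
#     student_pred = student_model(noisy_input, timestep)
#     loss = combined_loss(student_pred, teacher_pred, training_target)
```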
Summary of Changes
Hello @MaulanaYusufIkhsanRobbani, I'm Gemini Code Assist[^1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request enhances the Wan-Video model training capabilities by integrating LoRA-based knowledge distillation. The primary goal is to enable the creation of more compact and performant video generation models by distilling knowledge from larger, pre-trained teacher models into smaller student models, thereby improving efficiency without significant performance degradation.
Highlights
- LoRA-based Knowledge Distillation: Introduces a new training script, `train_distill.py`, specifically designed for performing LoRA-based knowledge distillation on Wan-Video models. This allows for training smaller, more efficient student models by leveraging the knowledge of larger teacher models.
- Combined Loss Function: The new training script utilizes a combined loss function for the student model. This loss is a weighted sum of the standard diffusion loss and a distillation loss, which is calculated as the mean squared error between the student's and the teacher's predictions.
- Dedicated Launch Script: A new shell script, `Wan-Distill.sh`, has been added to simplify the execution of the LoRA-based knowledge distillation process. It pre-configures the necessary arguments for launching `train_distill.py`.
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point by creating a comment using either `/gemini <command>` or `@gemini-code-assist <command>`. Below is a summary of the supported commands.

| Feature | Command | Description |
| --- | --- | --- |
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in pull request comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a `.gemini/` folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.
You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
[^1]: Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.
Code Review
This pull request introduces knowledge distillation for Wan-Video models using LoRA. The changes include a new training script and a shell script to run it. My review focuses on improving the correctness, robustness, and maintainability of the new training script.
I've identified a few critical issues that would prevent the training from running correctly on a GPU. Specifically, the trainable models are not correctly registered, and there's a device mismatch for the timestep tensor. I've also provided several medium-severity suggestions to improve code quality, such as removing duplicated code, using a proper logger, and making the script more flexible. Addressing these points will make the new training feature more robust and easier to maintain.
if model_id_with_origin_paths is not None:
    model_id_with_origin_paths = model_id_with_origin_paths.split(",")
    model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)
The models within `self.pipe` (e.g., `dit`, `vae`) are not registered as `nn.Module` attributes of `WanTrainingModule`. Consequently, they won't be moved to the correct device by `accelerate`, and their parameters won't be collected for training by `trainable_modules()`. This will likely result in training running on the CPU with no trainable parameters, or device-mismatch errors.
To fix this, you should register the necessary models from the pipeline as attributes of the `WanTrainingModule` right after creating the pipeline. This will ensure they are correctly handled by PyTorch and Accelerate.
Example fix:
self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)
self.dit = self.pipe.dit
self.vae = self.pipe.vae
self.text_encoder = self.pipe.text_encoder
self.image_encoder = self.pipe.image_encoder
self.motion_controller = self.pipe.motion_controller
self.vace = self.pipe.vace
if hasattr(self.pipe, 'dit2') and self.pipe.dit2 is not None:
    self.dit2 = self.pipe.dit2
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=torch.bfloat16, device=self.pipe.device)
The `timestep` tensor is being moved to `self.pipe.device`, which is hardcoded to `'cpu'`. When training with `accelerate` on a GPU, this will cause a device mismatch error because other tensors in the operation are on the GPU. You should create the `timestep_id` tensor on the correct device from the start.
Note: This fix assumes that the trainable models are correctly registered as attributes of this module, as suggested in another comment. Otherwise, `next(self.parameters())` will fail.
Suggested change:
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,), device=next(self.parameters()).device)
timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=torch.bfloat16)
accelerate launch examples/wanvideo/model_training/train_distill.py \
For robustness, it's a good practice to add `set -e` at the beginning of your shell scripts. This will cause the script to exit immediately if a command exits with a non-zero status, preventing unexpected behavior from errors in the script.
Suggested change:
set -e
accelerate launch examples/wanvideo/model_training/train_distill.py \
model_configs = []
if model_paths is not None:
    model_paths = json.loads(model_paths)
    model_configs += [ModelConfig(path=path) for path in model_paths]
if model_id_with_origin_paths is not None:
    model_id_with_origin_paths = model_id_with_origin_paths.split(",")
    model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)

# Load teacher model
teacher_model_configs = []
if teacher_model_paths is not None:
    teacher_model_paths = json.loads(teacher_model_paths)
    teacher_model_configs += [ModelConfig(path=path) for path in teacher_model_paths]
if teacher_model_id_with_origin_paths is not None:
    teacher_model_id_with_origin_paths = teacher_model_id_with_origin_paths.split(",")
    teacher_model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in teacher_model_id_with_origin_paths]
The logic for parsing model configurations is duplicated for the student and teacher models. To improve maintainability and reduce redundancy, you could extract this logic into a helper method. For example:
def _load_model_configs(model_paths, model_id_with_origin_paths):
    model_configs = []
    if model_paths is not None:
        model_paths = json.loads(model_paths)
        model_configs.extend([ModelConfig(path=path) for path in model_paths])
    if model_id_with_origin_paths is not None:
        model_id_with_origin_paths = model_id_with_origin_paths.split(",")
        model_configs.extend([
            ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1])
            for i in model_id_with_origin_paths
        ])
    return model_configs
You can then call this helper for both student and teacher models.
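For instance, under that refactor the constructor could reduce to something like this sketch (it reuses the argument names from the snippets quoted above; the exact constructor code in `train_distill.py` may differ):

```python
# Sketch only: build both pipelines via the shared helper.
model_configs = _load_model_configs(model_paths, model_id_with_origin_paths)
self.pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs
)

teacher_model_configs = _load_model_configs(teacher_model_paths, teacher_model_id_with_origin_paths)
# Teacher device handling is discussed in the next review comment.
self.teacher_pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16, device="cuda", model_configs=teacher_model_configs
)
```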
if teacher_model_id_with_origin_paths is not None:
    teacher_model_id_with_origin_paths = teacher_model_id_with_origin_paths.split(",")
    teacher_model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in teacher_model_id_with_origin_paths]
self.teacher_pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cuda", model_configs=teacher_model_configs)
The device for the teacher pipeline is hardcoded to `'cuda'`. This can be problematic in multi-GPU setups where `accelerate` might assign a different device (e.g., `cuda:1`) for training. It would be more flexible to make the teacher's device configurable via a command-line argument, for example `--teacher_device`.
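A minimal sketch of that suggestion, assuming the script exposes an argparse-style CLI like the other training scripts (the `--teacher_device` flag is the reviewer's proposal, not an existing option):

```python
import argparse

parser = argparse.ArgumentParser()
# Hypothetical flag: lets multi-GPU jobs place the frozen teacher explicitly
# instead of always defaulting to "cuda".
parser.add_argument(
    "--teacher_device",
    type=str,
    default="cuda",
    help="Device for the frozen teacher pipeline, e.g. 'cuda:1' or 'cpu'.",
)
args, _ = parser.parse_known_args()

# Later, when constructing the teacher pipeline:
#     self.teacher_pipe = WanVideoPipeline.from_pretrained(
#         torch_dtype=torch.bfloat16,
#         device=args.teacher_device,
#         model_configs=teacher_model_configs,
#     )
```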
print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys") | ||
if len(load_result[1]) > 0: | ||
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}") |
Using `print()` for logging is generally discouraged. It's better to use the `logging` module, which provides more flexibility, such as logging levels, formatting, and routing output to files or other handlers. This is especially useful for debugging and monitoring training processes. You should import `logging` at the top of the file.
print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys") | |
if len(load_result[1]) > 0: | |
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}") | |
import logging | |
logging.info(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys") | |
if len(load_result[1]) > 0: | |
logging.warning(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}") |
# Assume you are using this pipeline for inference,
# please fill in the input parameters.
@MaulanaYusufIkhsanRobbani Thank you for providing the code. Recently, we have been researching how to build a more general training framework that encapsulates SFT, distillation, and DPO into a unified module, making it orthogonal to the model architecture. Therefore, we will merge this functionality at a later time.