Fix examples not loading LoRA adapter weights from checkpoint #12690
base: main
Conversation
Hi! Tagging @sayakpaul and @yiyixuxu for review since this PR modifies a training example.
sayakpaul left a comment
Thanks for your PR!
Could we follow this
    def save_model_hook(models, weights, output_dir):
for saving hooks and this
    def load_model_hook(models, input_dir):
for loading?
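For reference, a rough sketch of that hook pattern, simplified and not the exact diffusers code; it assumes the accelerator and unet objects already defined in the training script:

from peft import get_peft_model_state_dict, set_peft_model_state_dict

from diffusers import StableDiffusionPipeline
from diffusers.utils import convert_state_dict_to_diffusers, convert_unet_state_dict_to_peft


def save_model_hook(models, weights, output_dir):
    # Called by accelerator.save_state(): write the LoRA weights next to the checkpoint.
    if accelerator.is_main_process:
        unet_lora_layers_to_save = None
        for model in models:
            unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
            # Pop the weight so accelerate does not serialize the full model a second time.
            weights.pop()
        StableDiffusionPipeline.save_lora_weights(output_dir, unet_lora_layers=unet_lora_layers_to_save)


def load_model_hook(models, input_dir):
    # Called by accelerator.load_state(): restore the LoRA weights into the UNet.
    unet_ = models.pop()
    lora_state_dict, _ = StableDiffusionPipeline.lora_state_dict(input_dir)
    unet_state_dict = {k.replace("unet.", ""): v for k, v in lora_state_dict.items() if k.startswith("unet.")}
    unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
    set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")


accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)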
Great idea to update the accelerate hooks to save the model.
@sayakpaul
Hello!
    # Make sure the trainable params are in float32
    if args.mixed_precision in ["fp16", "bf16"]:
We don't need to upcast when using Bfloat16.
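Something along these lines, a minimal sketch assuming the script's args and unet plus the cast_training_params helper from diffusers.training_utils:

import torch

from diffusers.training_utils import cast_training_params

# Upcast trainable (LoRA) params to float32 only for fp16; per the review,
# bf16 does not need the upcast.
if args.mixed_precision == "fp16":
    cast_training_params(unet, dtype=torch.float32)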
        convert_unet_state_dict_to_peft
    except ImportError:
        try:
            from diffusers.loaders.peft import convert_unet_state_dict_to_peft
It's available from diffusers.utils.
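For instance, the fallback chain can presumably be replaced with a single import:

# convert_unet_state_dict_to_peft is available from diffusers.utils,
# so no try/except fallback to diffusers.loaders.peft should be needed.
from diffusers.utils import convert_unet_state_dict_to_peft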
Hello @sayakpaul!!!
sayakpaul left a comment
Thanks!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
What does this PR do?
This PR fixes an issue where LoRA adapter weights are not loaded when resuming training from a checkpoint in the training example train_text_to_image_lora.py.

Currently, when using --resume_from_checkpoint, the script restores the UNet, optimizer, and scheduler states, but does not restore pytorch_lora_weights.safetensors, even though the checkpoint folder contains it. This results in resumed training not applying the previously trained LoRA weights, causing quality degradation and training inconsistency.
This PR adds the missing logic to:
- locate pytorch_lora_weights.safetensors in the checkpoint folder
- load it with safetensors.torch.load_file()
- apply it with unwrap_model(unet).load_state_dict(..., strict=False)

This makes checkpoint resume behavior consistent and correct; a short sketch of the logic is shown below.
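A minimal sketch of that resume path, assuming the training script's unet and unwrap_model, and a hypothetical checkpoint_dir variable pointing at the resumed checkpoint folder:

import os

import safetensors.torch

# Load the LoRA adapter weights saved alongside the accelerate checkpoint
# (checkpoint_dir is a hypothetical name for the resumed checkpoint folder).
lora_weights_path = os.path.join(checkpoint_dir, "pytorch_lora_weights.safetensors")
if os.path.isfile(lora_weights_path):
    lora_state_dict = safetensors.torch.load_file(lora_weights_path)
    # strict=False: the file contains only the LoRA adapter parameters,
    # not the full UNet state dict.
    unwrap_model(unet).load_state_dict(lora_state_dict, strict=False)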
Fixes
Fixes #12689
Before submitting
Tested locally by resuming a training run with --resume_from_checkpoint.

Who can review?