
Commit 9320373

LoRA for Conv2d layer, script to convert kohya_ss LoRA to PEFT (#461)
* Added LoRA for Conv2d layer and a script to convert kohya_ss linear LoRA to PEFT
* Fixed code style; added missing safetensors dependency for the kohya_ss to PEFT conversion script
1 parent 019b7ff commit 9320373

File tree

3 files changed: +328, −8 lines changed
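
The new script takes --sd_checkpoint, --kohya_lora_path, --dump_path and an optional --half flag (see the argparse block in the diff below), and writes PEFT-format adapters to <dump_path>/text_encoder and <dump_path>/unet. A minimal sketch of loading the converted adapters back afterwards; the checkpoint id and paths are placeholders, not part of the commit:

import os

from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel

from peft import PeftModel

# Placeholder paths for illustration only.
sd_checkpoint = "runwayml/stable-diffusion-v1-5"
dump_path = "./converted_lora"

# Load the base models, then attach the converted PEFT adapters.
text_encoder = CLIPTextModel.from_pretrained(sd_checkpoint, subfolder="text_encoder")
text_encoder = PeftModel.from_pretrained(text_encoder, os.path.join(dump_path, "text_encoder"))

unet = UNet2DConditionModel.from_pretrained(sd_checkpoint, subfolder="unet")
unet = PeftModel.from_pretrained(unet, os.path.join(dump_path, "unet"))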
Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
import argparse
import os
from typing import List, Optional

import safetensors
import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel

from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, set_peft_model_state_dict


# Default kohya_ss LoRA replacement modules
# https://github.com/kohya-ss/sd-scripts/blob/c924c47f374ac1b6e33e71f82948eb1853e2243f/networks/lora.py#L661
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"


def get_modules_names(
    root_module: nn.Module,
    target_replace_modules_linear: Optional[List[str]] = [],
    target_replace_modules_conv2d: Optional[List[str]] = [],
):
    # Combine replacement modules
    target_replace_modules = target_replace_modules_linear + target_replace_modules_conv2d

    # Store result
    modules_names = set()
    # https://github.com/kohya-ss/sd-scripts/blob/c924c47f374ac1b6e33e71f82948eb1853e2243f/networks/lora.py#L720
    for name, module in root_module.named_modules():
        if module.__class__.__name__ in target_replace_modules:
            if len(name) == 0:
                continue
            for child_name, child_module in module.named_modules():
                if len(child_name) == 0:
                    continue
                is_linear = child_module.__class__.__name__ == "Linear"
                is_conv2d = child_module.__class__.__name__ == "Conv2d"

                if (is_linear and module.__class__.__name__ in target_replace_modules_linear) or (
                    is_conv2d and module.__class__.__name__ in target_replace_modules_conv2d
                ):
                    modules_names.add(f"{name}.{child_name}")

    return sorted(modules_names)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--sd_checkpoint", default=None, type=str, required=True, help="SD checkpoint to use")

    parser.add_argument(
        "--kohya_lora_path", default=None, type=str, required=True, help="Path to kohya_ss trained LoRA"
    )

    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

    parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
    args = parser.parse_args()

    # Find text encoder modules to add LoRA to
    text_encoder = CLIPTextModel.from_pretrained(args.sd_checkpoint, subfolder="text_encoder")
    text_encoder_modules_names = get_modules_names(
        text_encoder, target_replace_modules_linear=TEXT_ENCODER_TARGET_REPLACE_MODULE
    )

    # Find unet2d modules to add LoRA to
    unet = UNet2DConditionModel.from_pretrained(args.sd_checkpoint, subfolder="unet")
    unet_modules_names = get_modules_names(
        unet,
        target_replace_modules_linear=UNET_TARGET_REPLACE_MODULE,
        # kohya_ss's default LoRA also wraps the 1x1 Conv2d projections inside
        # these same transformer blocks, so the same module list is reused here
        target_replace_modules_conv2d=UNET_TARGET_REPLACE_MODULE,
    )

    # Open kohya_ss checkpoint
    with safetensors.safe_open(args.kohya_lora_path, framework="pt", device="cpu") as f:
        # Extract information about LoRA structure
        metadata = f.metadata()
        lora_r = lora_text_encoder_r = int(metadata["ss_network_dim"])
        lora_alpha = lora_text_encoder_alpha = float(metadata["ss_network_alpha"])

        # Create LoRA for text encoder
        text_encoder_config = LoraConfig(
            r=lora_text_encoder_r,
            lora_alpha=lora_text_encoder_alpha,
            target_modules=text_encoder_modules_names,
            lora_dropout=0.0,
            bias="none",
        )
        text_encoder = get_peft_model(text_encoder, text_encoder_config)
        text_encoder_lora_state_dict = {x: None for x in get_peft_model_state_dict(text_encoder).keys()}

        # Load text encoder values from kohya_ss LoRA
        for peft_te_key in text_encoder_lora_state_dict.keys():
            kohya_ss_te_key = peft_te_key.replace("base_model.model", LORA_PREFIX_TEXT_ENCODER)
            kohya_ss_te_key = kohya_ss_te_key.replace("lora_A", "lora_down")
            kohya_ss_te_key = kohya_ss_te_key.replace("lora_B", "lora_up")
            kohya_ss_te_key = kohya_ss_te_key.replace(".", "_", kohya_ss_te_key.count(".") - 2)
            text_encoder_lora_state_dict[peft_te_key] = f.get_tensor(kohya_ss_te_key).to(text_encoder.dtype)

        # Load converted kohya_ss text encoder LoRA back to PEFT
        set_peft_model_state_dict(text_encoder, text_encoder_lora_state_dict)

        if args.half:
            text_encoder.to(torch.float16)

        # Save text encoder result
        text_encoder.save_pretrained(
            os.path.join(args.dump_path, "text_encoder"),
        )

        # Create LoRA for unet2d
        unet_config = LoraConfig(
            r=lora_r, lora_alpha=lora_alpha, target_modules=unet_modules_names, lora_dropout=0.0, bias="none"
        )
        unet = get_peft_model(unet, unet_config)
        unet_lora_state_dict = {x: None for x in get_peft_model_state_dict(unet).keys()}

        # Load unet2d values from kohya_ss LoRA
        for peft_unet_key in unet_lora_state_dict.keys():
            kohya_ss_unet_key = peft_unet_key.replace("base_model.model", LORA_PREFIX_UNET)
            kohya_ss_unet_key = kohya_ss_unet_key.replace("lora_A", "lora_down")
            kohya_ss_unet_key = kohya_ss_unet_key.replace("lora_B", "lora_up")
            kohya_ss_unet_key = kohya_ss_unet_key.replace(".", "_", kohya_ss_unet_key.count(".") - 2)
            unet_lora_state_dict[peft_unet_key] = f.get_tensor(kohya_ss_unet_key).to(unet.dtype)

        # Load converted kohya_ss unet LoRA back to PEFT
        set_peft_model_state_dict(unet, unet_lora_state_dict)

        if args.half:
            unet.to(torch.float16)

        # Save unet result
        unet.save_pretrained(
            os.path.join(args.dump_path, "unet"),
        )
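
To make the key remapping in the two loops above concrete, here is the same replace chain applied to one hypothetical PEFT key; the module path is illustrative only:

LORA_PREFIX_TEXT_ENCODER = "lora_te"

# A hypothetical PEFT key for one text encoder LoRA weight.
peft_key = "base_model.model.text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight"

# Same transformation as in the script above.
kohya_key = peft_key.replace("base_model.model", LORA_PREFIX_TEXT_ENCODER)
kohya_key = kohya_key.replace("lora_A", "lora_down")
kohya_key = kohya_key.replace("lora_B", "lora_up")
# Replace every dot except the last two with underscores, matching
# kohya_ss's flat "<prefix>_<module_path>.lora_down.weight" naming.
kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)

print(kohya_key)
# -> lora_te_text_model_encoder_layers_0_self_attn_q_proj.lora_down.weight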

examples/lora_dreambooth/requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -7,4 +7,5 @@ datasets
 diffusers
 Pillow
 torchvision
-huggingface_hub
+huggingface_hub
+safetensors

src/peft/tuners/lora.py

Lines changed: 185 additions & 7 deletions
@@ -17,7 +17,7 @@
 import warnings
 from dataclasses import asdict, dataclass, field
 from enum import Enum
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -262,6 +262,12 @@ def _create_new_module(self, lora_config, adapter_name, target):
             embedding_kwargs.pop("fan_in_fan_out", None)
             in_features, out_features = target.num_embeddings, target.embedding_dim
             new_module = Embedding(adapter_name, in_features, out_features, **embedding_kwargs)
+        elif isinstance(target, torch.nn.Conv2d):
+            out_channels, in_channels = target.weight.size()[:2]
+            kernel_size = target.weight.size()[2:]
+            stride = target.stride
+            padding = target.padding
+            new_module = Conv2d(adapter_name, in_channels, out_channels, kernel_size, stride, padding, **kwargs)
         else:
             if isinstance(target, torch.nn.Linear):
                 in_features, out_features = target.in_features, target.out_features
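
For reference (an illustrative sketch, not part of the diff): a Conv2d weight is laid out as (out_channels, in_channels, kH, kW), which is what the size slicing above relies on; the channel sizes here are made up:

import torch

conv = torch.nn.Conv2d(320, 640, kernel_size=3, stride=2, padding=1)
out_channels, in_channels = conv.weight.size()[:2]  # 640, 320
kernel_size = conv.weight.size()[2:]                # torch.Size([3, 3])
print(out_channels, in_channels, kernel_size, conv.stride, conv.padding)
# 640 320 torch.Size([3, 3]) (2, 2) (1, 1)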
@@ -303,7 +309,15 @@ def _find_and_replace(self, adapter_name):
             is_target_modules_in_base_model = True
             parent, target, target_name = _get_submodules(self.model, key)

-            if isinstance(target, LoraLayer):
+            if isinstance(target, LoraLayer) and isinstance(target, torch.nn.Conv2d):
+                target.update_layer_conv2d(
+                    adapter_name,
+                    lora_config.r,
+                    lora_config.lora_alpha,
+                    lora_config.lora_dropout,
+                    lora_config.init_lora_weights,
+                )
+            elif isinstance(target, LoraLayer):
                 target.update_layer(
                     adapter_name,
                     lora_config.r,
@@ -489,11 +503,7 @@ def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:


 class LoraLayer:
-    def __init__(
-        self,
-        in_features: int,
-        out_features: int,
-    ):
+    def __init__(self, in_features: int, out_features: int, **kwargs):
         self.r = {}
         self.lora_alpha = {}
         self.scaling = {}
@@ -508,6 +518,7 @@ def __init__(
         self.disable_adapters = False
         self.in_features = in_features
         self.out_features = out_features
+        self.kwargs = kwargs

     def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
         self.r[adapter_name] = r
@@ -527,6 +538,31 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig
             self.reset_lora_parameters(adapter_name)
         self.to(self.weight.device)

+    def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
+        self.r[adapter_name] = r
+        self.lora_alpha[adapter_name] = lora_alpha
+        if lora_dropout > 0.0:
+            lora_dropout_layer = nn.Dropout(p=lora_dropout)
+        else:
+            lora_dropout_layer = nn.Identity()
+
+        self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
+        # Actual trainable parameters
+        if r > 0:
+            kernel_size = self.kwargs["kernel_size"]
+            stride = self.kwargs["stride"]
+            padding = self.kwargs["padding"]
+            self.lora_A.update(
+                nn.ModuleDict({adapter_name: nn.Conv2d(self.in_features, r, kernel_size, stride, padding, bias=False)})
+            )
+            self.lora_B.update(
+                nn.ModuleDict({adapter_name: nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False)})
+            )
+            self.scaling[adapter_name] = lora_alpha / r
+            if init_lora_weights:
+                self.reset_lora_parameters(adapter_name)
+            self.to(self.weight.device)
+
     def update_layer_embedding(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
         self.r[adapter_name] = r
         self.lora_alpha[adapter_name] = lora_alpha
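
To see what this factorization costs in parameters (an illustrative sketch; the channel sizes are made up): lora_A is a full kernel_size conv into r channels and lora_B a 1x1 conv back out, so the adapter stays small relative to the frozen conv it sits beside:

import torch.nn as nn

in_f, out_f, r, k = 320, 640, 8, 3
lora_A = nn.Conv2d(in_f, r, kernel_size=k, stride=1, padding=1, bias=False)
lora_B = nn.Conv2d(r, out_f, kernel_size=(1, 1), stride=(1, 1), bias=False)

base_params = in_f * out_f * k * k  # 1,843,200 in the frozen 3x3 conv
lora_params = sum(p.numel() for p in lora_A.parameters()) + sum(
    p.numel() for p in lora_B.parameters()
)
print(lora_params)  # 320*8*3*3 + 8*640 = 28,160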
@@ -728,6 +764,148 @@ def forward(self, x: torch.Tensor):
         return nn.Embedding.forward(self, x)


+class Conv2d(nn.Conv2d, LoraLayer):
+    # LoRA implemented in a Conv2d layer
+    def __init__(
+        self,
+        adapter_name: str,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int]],
+        stride: Union[int, Tuple[int]] = 1,
+        padding: Union[int, Tuple[int]] = 0,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        **kwargs,
+    ):
+        init_lora_weights = kwargs.pop("init_lora_weights", True)
+
+        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride, padding)
+        LoraLayer.__init__(
+            self,
+            in_features=in_channels,
+            out_features=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+        )
+        # Freezing the pre-trained weight matrix
+        self.weight.requires_grad = False
+
+        nn.Conv2d.reset_parameters(self)
+        self.update_layer_conv2d(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
+        self.active_adapter = adapter_name
+
+    def merge(self):
+        if self.active_adapter not in self.lora_A.keys():
+            return
+        if self.merged:
+            warnings.warn("Already merged. Nothing to do.")
+            return
+        if self.r[self.active_adapter] > 0:
+            # https://github.com/bmaltais/kohya_ss/blob/feb6728762a8f463d15ba936d189d4c3abfaa1ab/networks/lora.py#L117
+            if self.weight.size()[2:4] == (1, 1):
+                # conv2d 1x1
+                self.weight.data += (
+                    self.lora_B[self.active_adapter].weight.squeeze(3).squeeze(2)
+                    @ self.lora_A[self.active_adapter].weight.squeeze(3).squeeze(2)
+                ).unsqueeze(2).unsqueeze(3) * self.scaling[self.active_adapter]
+            else:
+                # conv2d 3x3
+                self.weight.data += (
+                    F.conv2d(
+                        self.lora_A[self.active_adapter].weight.permute(1, 0, 2, 3),
+                        self.lora_B[self.active_adapter].weight,
+                    ).permute(1, 0, 2, 3)
+                    * self.scaling[self.active_adapter]
+                )
+            self.merged = True
+
+    def unmerge(self):
+        if self.active_adapter not in self.lora_A.keys():
+            return
+        if not self.merged:
+            warnings.warn("Already unmerged. Nothing to do.")
+            return
+        if self.r[self.active_adapter] > 0:
+            if self.weight.size()[2:4] == (1, 1):
+                # conv2d 1x1
+                self.weight.data -= (
+                    self.lora_B[self.active_adapter].weight.squeeze(3).squeeze(2)
+                    @ self.lora_A[self.active_adapter].weight.squeeze(3).squeeze(2)
+                ).unsqueeze(2).unsqueeze(3) * self.scaling[self.active_adapter]
+            else:
+                # conv2d 3x3: subtract the same composed kernel that merge() added
+                self.weight.data -= (
+                    F.conv2d(
+                        self.lora_A[self.active_adapter].weight.permute(1, 0, 2, 3),
+                        self.lora_B[self.active_adapter].weight,
+                    ).permute(1, 0, 2, 3)
+                    * self.scaling[self.active_adapter]
+                )
+            self.merged = False
+
+    def forward(self, x: torch.Tensor):
+        previous_dtype = x.dtype
+
+        if self.active_adapter not in self.lora_A.keys():
+            return F.conv2d(
+                x,
+                self.weight,
+                bias=self.bias,
+                stride=self.stride,
+                padding=self.padding,
+                dilation=self.dilation,
+                groups=self.groups,
+            )
+        if self.disable_adapters:
+            if self.r[self.active_adapter] > 0 and self.merged:
+                self.unmerge()
+            result = F.conv2d(
+                x,
+                self.weight,
+                bias=self.bias,
+                stride=self.stride,
+                padding=self.padding,
+                dilation=self.dilation,
+                groups=self.groups,
+            )
+        elif self.r[self.active_adapter] > 0 and not self.merged:
+            result = F.conv2d(
+                x,
+                self.weight,
+                bias=self.bias,
+                stride=self.stride,
+                padding=self.padding,
+                dilation=self.dilation,
+                groups=self.groups,
+            )
+
+            x = x.to(self.lora_A[self.active_adapter].weight.dtype)
+
+            result += (
+                self.lora_B[self.active_adapter](
+                    self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
+                )
+                * self.scaling[self.active_adapter]
+            )
+        else:
+            result = F.conv2d(
+                x,
+                self.weight,
+                bias=self.bias,
+                stride=self.stride,
+                padding=self.padding,
+                dilation=self.dilation,
+                groups=self.groups,
+            )
+
+        result = result.to(previous_dtype)
+
+        return result
+
+
 if is_bnb_available():

     class Linear8bitLt(bnb.nn.Linear8bitLt, LoraLayer):
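
A quick standalone check of the 3x3 merge arithmetic above (a sketch in plain torch, independent of the PEFT classes, assuming stride 1 and no padding so the two paths compose cleanly): the F.conv2d-on-permuted-weights trick in merge() produces exactly the composite kernel of lora_B applied after lora_A:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
in_c, out_c, r = 16, 32, 4
x = torch.randn(1, in_c, 8, 8)

# Rank-r conv pair: 3x3 down-projection, 1x1 up-projection.
lora_A = torch.nn.Conv2d(in_c, r, 3, 1, 0, bias=False)
lora_B = torch.nn.Conv2d(r, out_c, 1, 1, 0, bias=False)

# Composite 3x3 kernel, computed as in Conv2d.merge() for the 3x3 case.
delta_w = F.conv2d(lora_A.weight.permute(1, 0, 2, 3), lora_B.weight).permute(1, 0, 2, 3)

merged = F.conv2d(x, delta_w)  # one conv with the composed kernel
two_step = lora_B(lora_A(x))   # the unmerged B(A(x)) path
print(torch.allclose(merged, two_step, atol=1e-6))  # True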
