diff --git a/lora_diffusion/cli_lora_add.py b/lora_diffusion/cli_lora_add.py
index 3a416af..3d11ec2 100644
--- a/lora_diffusion/cli_lora_add.py
+++ b/lora_diffusion/cli_lora_add.py
@@ -75,13 +75,13 @@ def add(
             path_1,
         ).to("cpu")
 
-        weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), alpha=alpha)
+        weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), scale=alpha)
         if with_text_lora:
             weight_apply_lora(
                 loaded_pipeline.text_encoder,
                 torch.load(_text_lora_path(path_2)),
-                alpha=alpha,
+                scale=alpha,
                 target_replace_module=["CLIPAttention"],
             )
 
@@ -93,12 +93,12 @@ def add(
             path_1,
         ).to("cpu")
 
-        weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), alpha=alpha)
+        weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), scale=alpha)
         if with_text_lora:
             weight_apply_lora(
                 loaded_pipeline.text_encoder,
                 torch.load(_text_lora_path(path_2)),
-                alpha=alpha,
+                scale=alpha,
                 target_replace_module=["CLIPAttention"],
             )
 
diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py
index 7672b48..a04bee3 100644
--- a/lora_diffusion/lora.py
+++ b/lora_diffusion/lora.py
@@ -10,24 +10,39 @@ class LoraInjectedLinear(nn.Module):
-    def __init__(self, in_features, out_features, bias=False, r=4):
+    def __init__(self, in_features, out_features, bias=False, r=4, scale=1.0, init=None, nonlin: nn.Module = None):
         super().__init__()
 
         if r > min(in_features, out_features):
             raise ValueError(
                 f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
             )
-
+
+        if scale <= 0:
+            raise ValueError(
+                f"LoRA scale {scale} must be greater than 0"
+            )
+
+        self.r = r
+        self.scale = scale
         self.linear = nn.Linear(in_features, out_features, bias)
         self.lora_down = nn.Linear(in_features, r, bias=False)
+        self.nonlin = nonlin if nonlin else None
         self.lora_up = nn.Linear(r, out_features, bias=False)
-        self.scale = 1.0
 
-        nn.init.normal_(self.lora_down.weight, std=1 / r)
+        if init == "kaiming":
+            # Kaiming with a=math.sqrt(5) is already the default for nn.Linear
+            pass
+        else:
+            nn.init.normal_(self.lora_down.weight, std=1 / r)
+
         nn.init.zeros_(self.lora_up.weight)
 
     def forward(self, input):
-        return self.linear(input) + self.lora_up(self.lora_down(input)) * self.scale
+        if self.nonlin:
+            return self.linear(input) + self.lora_up(self.nonlin(self.lora_down(input))) * self.scale
+        else:
+            return self.linear(input) + self.lora_up(self.lora_down(input)) * self.scale
 
 
 UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}
@@ -116,6 +131,9 @@ def inject_trainable_lora(
     model: nn.Module,
     target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE,
     r: int = 4,
+    scale: float = 1.0,
+    init=None,
+    nonlin=None,
     loras=None,  # path to lora .pt
 ):
     """
@@ -137,7 +155,10 @@ def inject_trainable_lora(
             _child_module.in_features,
             _child_module.out_features,
             _child_module.bias is not None,
-            r,
+            r=r,
+            scale=scale,
+            init=init,
+            nonlin=nonlin,
         )
         _tmp.linear.weight = weight
         if bias is not None:
@@ -333,9 +354,13 @@ def load_safeloras(path, device="cpu"):
 
 
 def weight_apply_lora(
-    model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, alpha=1.0
-):
-
+    model,
+    loras,
+    target_replace_module=DEFAULT_TARGET_REPLACE,
+    r: int = 4,
+    scale: float = 1.0,
+    nonlin: nn.Module = None,
+):
     for _m, _n, _child_module in _find_modules(
         model, target_replace_module, search_class=[nn.Linear]
     ):
@@ -344,13 +369,22 @@ def weight_apply_lora(
         weight = _child_module.weight
 
         up_weight = loras.pop(0).detach().to(weight.device)
         down_weight = loras.pop(0).detach().to(weight.device)
 
-        # W <- W + U * D
-        weight = weight + alpha * (up_weight @ down_weight).type(weight.dtype)
+        if nonlin is None:
+            # W <- W + U * D
+            weight = weight + scale * (up_weight @ down_weight).type(weight.dtype)
+        else:
+            # W <- W + U * nonlin(D)
+            weight = weight + scale * (up_weight @ nonlin(down_weight)).type(weight.dtype)
+
         _child_module.weight = nn.Parameter(weight)
-
 
 def monkeypatch_lora(
-    model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, r: int = 4
+    model,
+    loras,
+    target_replace_module=DEFAULT_TARGET_REPLACE,
+    r: int = 4,
+    scale: float = 1.0,
+    nonlin: nn.Module = None,
 ):
     for _module, name, _child_module in _find_modules(
         model, target_replace_module, search_class=[nn.Linear]
@@ -362,6 +396,8 @@ def monkeypatch_lora(
             _child_module.out_features,
             _child_module.bias is not None,
             r=r,
+            scale=scale,
+            nonlin=nonlin,
         )
         _tmp.linear.weight = weight
 
@@ -385,7 +421,12 @@ def monkeypatch_lora(
 
 
 def monkeypatch_replace_lora(
-    model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, r: int = 4
+    model,
+    loras,
+    target_replace_module=DEFAULT_TARGET_REPLACE,
+    r: int = 4,
+    scale: float = 1.0,
+    nonlin: nn.Module = None,
 ):
     for _module, name, _child_module in _find_modules(
         model, target_replace_module, search_class=[LoraInjectedLinear]
@@ -397,7 +438,9 @@ def monkeypatch_replace_lora(
             _child_module.linear.out_features,
             _child_module.linear.bias is not None,
             r=r,
-        )
+            scale=scale,
+            nonlin=nonlin,
+        )
         _tmp.linear.weight = weight
 
         if bias is not None:
@@ -424,6 +467,8 @@ def monkeypatch_or_replace_lora(
     loras,
     target_replace_module=DEFAULT_TARGET_REPLACE,
     r: Union[int, List[int]] = 4,
+    scale: Union[float, List[float]] = 1.0,
+    nonlin: Union[nn.Module, List[nn.Module]] = None,
 ):
     for _module, name, _child_module in _find_modules(
         model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear]
@@ -441,6 +486,8 @@ def monkeypatch_or_replace_lora(
             _source.out_features,
             _source.bias is not None,
             r=r.pop(0) if isinstance(r, list) else r,
+            scale=scale.pop(0) if isinstance(scale, list) else scale,
+            nonlin=nonlin.pop(0) if isinstance(nonlin, list) else nonlin,
         )
 
         _tmp.linear.weight = weight
@@ -496,7 +543,7 @@ def monkeypatch_add_lora(
     model,
     loras,
     target_replace_module=DEFAULT_TARGET_REPLACE,
-    alpha: float = 1.0,
+    scale: float = 1.0,
     beta: float = 1.0,
 ):
     for _module, name, _child_module in _find_modules(
@@ -519,12 +566,16 @@ def monkeypatch_add_lora(
         _module._modules[name].to(weight.device)
 
 
-def tune_lora_scale(model, alpha: float = 1.0):
+def tune_lora_scale(model, alpha: float = 1.0, scale: float = None):
+    # Keep alpha for backwards compatibility (it has always meant the scale); an explicit scale takes precedence.
+    if scale is None:
+        scale = alpha
+
     for _module in model.modules():
         if _module.__class__.__name__ == "LoraInjectedLinear":
-            _module.scale = alpha
-
+            _module.scale = scale
+
 
 def _text_lora_path(path: str) -> str:
     assert path.endswith(".pt"), "Only .pt files are supported"
     return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
@@ -576,6 +627,8 @@ def patch_pipe(
     unet_path,
     token: str,
     r: int = 4,
+    scale: float = 1.0,
+    nonlin: nn.Module = None,
     patch_unet=True,
     patch_text=False,
     patch_ti=False,
@@ -596,6 +649,8 @@ def patch_pipe(
             pipe.unet,
             torch.load(unet_path),
             r=r,
+            scale=scale,
+            nonlin=nonlin,
             target_replace_module=unet_target_replace_module,
         )
 
@@ -606,6 +661,8 @@ def patch_pipe(
             torch.load(text_path),
             target_replace_module=text_target_replace_module,
             r=r,
+            scale=scale,
+            nonlin=nonlin,
         )
     if patch_ti:
         print("LoRA : Patching token input")
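For reference, a minimal usage sketch of the new arguments introduced by this diff (`scale`, `init`, `nonlin`). Only `inject_trainable_lora`, `tune_lora_scale`, and their parameters come from the patch; the model id, the choice of `nn.GELU` as the non-linearity, and the elided training loop are illustrative placeholders.

```python
# Illustrative sketch, not part of the patch.
import torch.nn as nn
from diffusers import StableDiffusionPipeline

from lora_diffusion import inject_trainable_lora, tune_lora_scale

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5"  # placeholder model id
)

# Inject trainable LoRA layers into the UNet using the new arguments:
# a scale on the low-rank update, Kaiming init for lora_down (the nn.Linear
# default), and a non-linearity between lora_down and lora_up.
unet_lora_params, _ = inject_trainable_lora(
    pipe.unet,
    r=4,
    scale=1.0,
    init="kaiming",
    nonlin=nn.GELU(),  # example choice; any nn.Module works
)

# ... optimize the returned LoRA parameters as usual ...

# At inference time the effective LoRA strength can still be adjusted;
# the legacy `alpha` argument keeps working, an explicit `scale` wins.
tune_lora_scale(pipe.unet, scale=0.7)
```

Note that `weight_apply_lora` with a `nonlin` folds `U @ nonlin(D)` into the base weight once, whereas the injected or monkeypatched `LoraInjectedLinear` modules apply the non-linearity to the activations on every forward pass, so the two paths are not equivalent when `nonlin` is non-linear.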