@@ -17,7 +17,7 @@
 import warnings
 from dataclasses import asdict, dataclass, field
 from enum import Enum
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -262,6 +262,12 @@ def _create_new_module(self, lora_config, adapter_name, target):
             embedding_kwargs.pop("fan_in_fan_out", None)
             in_features, out_features = target.num_embeddings, target.embedding_dim
             new_module = Embedding(adapter_name, in_features, out_features, **embedding_kwargs)
+        elif isinstance(target, torch.nn.Conv2d):
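+            # note: nn.Conv2d stores its weight as (out_channels, in_channels // groups, kH, kW),
+            # so the base layer's geometry is read straight off the weight tensor (groups=1 assumed)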
+            out_channels, in_channels = target.weight.size()[:2]
+            kernel_size = target.weight.size()[2:]
+            stride = target.stride
+            padding = target.padding
+            new_module = Conv2d(adapter_name, in_channels, out_channels, kernel_size, stride, padding, **kwargs)
         else:
             if isinstance(target, torch.nn.Linear):
                 in_features, out_features = target.in_features, target.out_features
@@ -303,7 +309,15 @@ def _find_and_replace(self, adapter_name):
                     is_target_modules_in_base_model = True
                 parent, target, target_name = _get_submodules(self.model, key)
 
-                if isinstance(target, LoraLayer):
+                if isinstance(target, LoraLayer) and isinstance(target, torch.nn.Conv2d):
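+                    # target is already a LoRA Conv2d layer (it subclasses both nn.Conv2d and
+                    # LoraLayer), so the extra adapter is attached via the conv-specific update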
+                    target.update_layer_conv2d(
+                        adapter_name,
+                        lora_config.r,
+                        lora_config.lora_alpha,
+                        lora_config.lora_dropout,
+                        lora_config.init_lora_weights,
+                    )
+                elif isinstance(target, LoraLayer):
                     target.update_layer(
                         adapter_name,
                         lora_config.r,
@@ -489,11 +503,7 @@ def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
 
 
 class LoraLayer:
-    def __init__(
-        self,
-        in_features: int,
-        out_features: int,
-    ):
+    def __init__(self, in_features: int, out_features: int, **kwargs):
         self.r = {}
         self.lora_alpha = {}
         self.scaling = {}
@@ -508,6 +518,7 @@ def __init__(
         self.disable_adapters = False
         self.in_features = in_features
         self.out_features = out_features
+        self.kwargs = kwargs
 
     def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
         self.r[adapter_name] = r
@@ -527,6 +538,31 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig
             self.reset_lora_parameters(adapter_name)
         self.to(self.weight.device)
 
+    def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
+        self.r[adapter_name] = r
+        self.lora_alpha[adapter_name] = lora_alpha
+        if lora_dropout > 0.0:
+            lora_dropout_layer = nn.Dropout(p=lora_dropout)
+        else:
+            lora_dropout_layer = nn.Identity()
+
+        self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
+        # Actual trainable parameters
+        if r > 0:
+            kernel_size = self.kwargs["kernel_size"]
+            stride = self.kwargs["stride"]
+            padding = self.kwargs["padding"]
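+            # low-rank factorization for convolutions: lora_A maps in_features -> r with the base
+            # layer's kernel/stride/padding, lora_B maps r -> out_features with a 1x1 kernel, so
+            # lora_B(lora_A(x)) has the same output shape as the frozen base conv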
+            self.lora_A.update(
+                nn.ModuleDict({adapter_name: nn.Conv2d(self.in_features, r, kernel_size, stride, padding, bias=False)})
+            )
+            self.lora_B.update(
+                nn.ModuleDict({adapter_name: nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False)})
+            )
+            self.scaling[adapter_name] = lora_alpha / r
+        if init_lora_weights:
+            self.reset_lora_parameters(adapter_name)
+        self.to(self.weight.device)
+
     def update_layer_embedding(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
         self.r[adapter_name] = r
         self.lora_alpha[adapter_name] = lora_alpha
@@ -728,6 +764,148 @@ def forward(self, x: torch.Tensor):
         return nn.Embedding.forward(self, x)
 
 
+class Conv2d(nn.Conv2d, LoraLayer):
+    # Lora implemented in a conv2d layer
+    def __init__(
+        self,
+        adapter_name: str,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int]],
+        stride: Union[int, Tuple[int]] = 1,
+        padding: Union[int, Tuple[int]] = 0,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        **kwargs,
+    ):
+        init_lora_weights = kwargs.pop("init_lora_weights", True)
+
+        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride, padding)
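+        # the base conv is built with default dilation=1, groups=1 and bias=True; its pretrained
+        # weight (and bias, if any) are copied over afterwards when the module is swapped in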
+        LoraLayer.__init__(
+            self,
+            in_features=in_channels,
+            out_features=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+        )
+        # Freezing the pre-trained weight matrix
+        self.weight.requires_grad = False
+
+        nn.Conv2d.reset_parameters(self)
+        self.update_layer_conv2d(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
+        self.active_adapter = adapter_name
+
+    def merge(self):
+        if self.active_adapter not in self.lora_A.keys():
+            return
+        if self.merged:
+            warnings.warn("Already merged. Nothing to do.")
+            return
+        if self.r[self.active_adapter] > 0:
+            # https://github.com/bmaltais/kohya_ss/blob/feb6728762a8f463d15ba936d189d4c3abfaa1ab/networks/lora.py#L117
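+            # the product of the two LoRA convs is folded into the frozen kernel: for a 1x1 conv
+            # this is a plain (out, r) @ (r, in) matrix product reshaped back to 4D; otherwise
+            # F.conv2d(A.permute(1, 0, 2, 3), B) composes the two kernels directly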
+            if self.weight.size()[2:4] == (1, 1):
+                # conv2d 1x1
+                self.weight.data += (
+                    self.lora_B[self.active_adapter].weight.squeeze(3).squeeze(2)
+                    @ self.lora_A[self.active_adapter].weight.squeeze(3).squeeze(2)
+                ).unsqueeze(2).unsqueeze(3) * self.scaling[self.active_adapter]
+            else:
+                # conv2d 3x3
+                self.weight.data += (
+                    F.conv2d(
+                        self.lora_A[self.active_adapter].weight.permute(1, 0, 2, 3),
+                        self.lora_B[self.active_adapter].weight,
+                    ).permute(1, 0, 2, 3)
+                    * self.scaling[self.active_adapter]
+                )
+            self.merged = True
+
+    def unmerge(self):
+        if self.active_adapter not in self.lora_A.keys():
+            return
+        if not self.merged:
+            warnings.warn("Already unmerged. Nothing to do.")
+            return
+        if self.r[self.active_adapter] > 0:
+            if self.weight.size()[2:4] == (1, 1):
+                # conv2d 1x1
+                self.weight.data -= (
+                    self.lora_B[self.active_adapter].weight.squeeze(3).squeeze(2)
+                    @ self.lora_A[self.active_adapter].weight.squeeze(3).squeeze(2)
+                ).unsqueeze(2).unsqueeze(3) * self.scaling[self.active_adapter]
+            else:
+                # conv2d 3x3
+                # subtract the same scaled delta that merge() added
+                self.weight.data -= (
+                    F.conv2d(
+                        self.lora_A[self.active_adapter].weight.permute(1, 0, 2, 3),
+                        self.lora_B[self.active_adapter].weight,
+                    ).permute(1, 0, 2, 3)
+                    * self.scaling[self.active_adapter]
+                )
+            self.merged = False
+
+    def forward(self, x: torch.Tensor):
+        previous_dtype = x.dtype
+
+        if self.active_adapter not in self.lora_A.keys():
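+            # no LoRA weights exist for the requested adapter: behave exactly like the frozen base conv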
+            return F.conv2d(
+                x,
+                self.weight,
+                bias=self.bias,
+                stride=self.stride,
+                padding=self.padding,
+                dilation=self.dilation,
+                groups=self.groups,
+            )
+        if self.disable_adapters:
+            if self.r[self.active_adapter] > 0 and self.merged:
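+                # adapters are disabled but a delta was merged earlier: restore the original kernel first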
+                self.unmerge()
+            result = F.conv2d(
+                x,
+                self.weight,
+                bias=self.bias,
+                stride=self.stride,
+                padding=self.padding,
+                dilation=self.dilation,
+                groups=self.groups,
+            )
+        elif self.r[self.active_adapter] > 0 and not self.merged:
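+            # unmerged LoRA path: frozen base conv output plus the scaled low-rank update
+            # lora_B(lora_A(dropout(x))) * scaling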
+            result = F.conv2d(
+                x,
+                self.weight,
+                bias=self.bias,
+                stride=self.stride,
+                padding=self.padding,
+                dilation=self.dilation,
+                groups=self.groups,
+            )
+
+            x = x.to(self.lora_A[self.active_adapter].weight.dtype)
+
+            result += (
+                self.lora_B[self.active_adapter](
+                    self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
+                )
+                * self.scaling[self.active_adapter]
+            )
+        else:
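+            # r == 0 for this adapter, or the delta is already merged into self.weight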
+            result = F.conv2d(
+                x,
+                self.weight,
+                bias=self.bias,
+                stride=self.stride,
+                padding=self.padding,
+                dilation=self.dilation,
+                groups=self.groups,
+            )
+
+        result = result.to(previous_dtype)
+
+        return result
+
+
 if is_bnb_available():
 
     class Linear8bitLt(bnb.nn.Linear8bitLt, LoraLayer):
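
A minimal usage sketch (not part of the diff; the base model and the module names "conv1" / "conv2" are placeholder assumptions): once a conv layer name is listed in target_modules, it is wrapped by the new Conv2d LoRA layer the same way a linear layer would be.

    from peft import LoraConfig, get_peft_model

    config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.1,
        target_modules=["conv1", "conv2"],  # names of nn.Conv2d modules in the base model
    )
    lora_model = get_peft_model(model, config)  # targeted convs become the Conv2d(nn.Conv2d, LoraLayer) wrapper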