@@ -178,7 +178,6 @@ def __init__(
         self,
         # ↓ this part is for pretrained weights
         in_features: int,
-        out_features: int,
         # ↓ the remaining part is for LoRA
         head_size: int,
         n_head: int,
@@ -199,7 +198,6 @@ def __init__(
 
         Args:
             in_features: number of input features of the pretrained weights
-            out_features: number of output features of the pretrained weights
             head_size: size of a single attention head
             n_head: number of attention heads
             n_query_groups: number of query groups (see diagram in `litgpt/config.py`)
@@ -214,6 +212,7 @@ def __init__(
                 and `value` but keep `key` without weight updates we should pass `[True, False, True]`
         """
         super(LoRALinear, self).__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+        out_features = head_size * (n_head + 2 * n_query_groups)
         self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
         self.head_size = head_size
         self.n_head = n_head
@@ -229,18 +228,19 @@ def __init__(
         # ⚬ out_features: 384 (3 * embedding_size)
         # ⚬ r: 2
         # ⚬ enable_lora: [True, False, True]
+        self._all_qkv_shapes = (
+            # if `head_size` is explicitly specified in the config, `n_embd` (or `in_features`)
+            # might not be equal to `head_size * n_head`, thus we use it directly here
+            head_size * n_head,
+            head_size * n_query_groups,
+            head_size * n_query_groups,
+        )
         if r > 0 and any(enable_lora):
             self.lora_A = nn.Parameter(torch.empty((r * sum(enable_lora), in_features)))  # (4, 128)
-            enable_q, enable_k, enable_v = enable_lora
             # qkv_shapes will be used to split a tensor with weights correctly
-            qkv_shapes = (
-                # if `head_size` is explicitly specified in the config, `n_embd` (or `in_features`)
-                # might not be equal to `head_size * n_head`, thus we use it directly here
-                head_size * n_head * enable_q,
-                head_size * n_query_groups * enable_k,
-                head_size * n_query_groups * enable_v,
-            )
-            self.qkv_shapes = [s for s in qkv_shapes if s]
+            self.qkv_shapes = [
+                s for s, e in zip(self._all_qkv_shapes, enable_lora) if e
+            ]
             self.lora_B = nn.Parameter(torch.empty(sum(self.qkv_shapes), r))  # (256, 2))
             # Notes about shapes above
             # - self.lora_A has shape (4, 128): 4 because rank is 2 and LoRA is applied only to two matrices;
@@ -266,15 +266,13 @@ def lora_ind(self) -> torch.Tensor:
         """Lazy creation of a buffer with LoRA indices to overcome the limitation when FSDP with meta device is used."""
         # Indices are needed to properly pad weight updates with zeros.
         if not hasattr(self, "_lora_ind"):
-            enable_q, enable_k, enable_v = self.enable_lora
-            kv_embd_size = self.linear.in_features // (self.n_head // self.n_query_groups)
+            off = 0
             lora_ind = []
-            if enable_q:
-                lora_ind.extend(range(0, self.linear.in_features))
-            if enable_k:
-                lora_ind.extend(range(self.linear.in_features, self.linear.in_features + kv_embd_size))
-            if enable_v:
-                lora_ind.extend(range(self.linear.in_features + kv_embd_size, self.linear.out_features))
+            for enable, size in zip(self.enable_lora, self._all_qkv_shapes):
+                if enable:
+                    lora_ind.extend(range(off, off + size))
+                off += size
+            assert len(lora_ind) == sum(self.qkv_shapes)  # Sanity check
             self.register_buffer(
                 "_lora_ind", torch.tensor(lora_ind, device=self.linear.weight.device), persistent=False
             )
@@ -527,10 +525,8 @@ class CausalSelfAttention(BaseCausalSelfAttention):
     def __init__(self, config: Config, block_idx: int) -> None:
         super().__init__(config, block_idx)
         # key, query, value projections for all heads, but in a batch
-        shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
         self.qkv = LoRAQKVLinear(
             in_features=config.n_embd,
-            out_features=shape,
             r=config.lora_r,
             lora_alpha=config.lora_alpha,
             lora_dropout=config.lora_dropout,
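
The shape and index arithmetic introduced above can be checked in isolation. The standalone sketch below is not litgpt code and the config numbers (head_size=64, n_head=8, n_query_groups=2) are made up purely for illustration; it mirrors how the refactor derives `out_features` from `head_size`, `n_head`, and `n_query_groups`, filters `qkv_shapes` by `enable_lora`, and builds the `lora_ind` offsets with the new loop.

# Illustrative grouped-query attention config (hypothetical values, not from the diff).
head_size = 64
n_head = 8
n_query_groups = 2
enable_lora = [True, False, True]  # LoRA on query and value, not on key

# out_features is now derived inside the layer instead of being passed in.
out_features = head_size * (n_head + 2 * n_query_groups)  # 64 * (8 + 4) = 768

# Slice sizes of the fused qkv output, mirroring `self._all_qkv_shapes`.
all_qkv_shapes = (
    head_size * n_head,          # query: 512
    head_size * n_query_groups,  # key:   128
    head_size * n_query_groups,  # value: 128
)

# Only the enabled projections contribute rows to lora_B.
qkv_shapes = [s for s, e in zip(all_qkv_shapes, enable_lora) if e]  # [512, 128]

# Indices used to pad the LoRA update into the fused qkv weight,
# mirroring the loop in the new `lora_ind` property.
off, lora_ind = 0, []
for enable, size in zip(enable_lora, all_qkv_shapes):
    if enable:
        lora_ind.extend(range(off, off + size))
    off += size

assert len(lora_ind) == sum(qkv_shapes)   # 640 rows receive LoRA updates
assert max(lora_ind) == out_features - 1  # the value slice ends the fused output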