Commit e10ebf0

Fix Prefix Finetuning for Group Query Attention (GQA) (#825)
Resolves: #819

## Problem

PrefixTuning currently fails on modern architectures with **Grouped-Query Attention (GQA)** (e.g., Llama 3.1), raising shape mismatches in the attention forward pass. Issues:

- Assumes `num_attention_heads` equals the number of KV heads (invalid for GQA).
- Computes the per-head dim by dividing `hidden_size` by the number of KV heads, which only works for standard MHA.

## Solution

This PR updates `PrefixTuningLayer.add_adapter` to properly support GQA and similar attention mechanisms (see the sketch below):

- **Correct head count:** prefer `config.num_key_value_heads` when available, falling back to `num_attention_heads`.
- **Robust per-head dim:**
  - use `config.d_kv` if defined (e.g., T5);
  - else compute it as `hidden_size // num_attention_heads`.

This ensures prefix tensors align with the internal KV states across **MHA, GQA, and MQA**, fixing Llama 3.1 while preserving compatibility with existing models.
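For illustration, here is a minimal sketch (not part of the commit; the config values are illustrative Llama-3.1-style assumptions) of how the updated logic resolves the KV head count and per-head dimension for a GQA model:

```python
# Minimal sketch: how the new head count and per-head dimension resolve under GQA.
# The config values below are illustrative assumptions, not taken from the diff.
from dataclasses import dataclass


@dataclass
class DummyConfig:
    hidden_size: int = 4096
    num_attention_heads: int = 32   # query heads
    num_key_value_heads: int = 8    # GQA: fewer KV heads than query heads


config = DummyConfig()

# Prefer the dedicated KV head count when the architecture defines one (GQA/MQA),
# otherwise fall back to the full attention head count (standard MHA).
num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)

# Per-head dimension: use config.d_kv if the model defines it (e.g., T5),
# else derive it from hidden_size and the query head count.
head_dim = getattr(config, "d_kv", None)
if head_dim is None:
    head_dim = config.hidden_size // config.num_attention_heads

print(num_kv_heads, head_dim)  # -> 8 128: prefix KV tensors use 8 heads of dim 128
```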
1 parent 6e9a467 commit e10ebf0

File tree

1 file changed: +7 −2 lines changed


src/adapters/methods/prefix_tuning.py

Lines changed: 7 additions & 2 deletions
@@ -362,12 +362,17 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
             location_key=used_location_key,
         )
         if prefix_tuning_config is not None:
+            num_kv_heads = getattr(self.model_config, "num_key_value_heads", self.model_config.num_attention_heads)
+            head_dim = getattr(self.model_config, "d_kv", None)
+
+            if head_dim is None:
+                head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads
             prefix_id = self.pool.indicate_prefix(
                 adapter_name,
                 self.location_key,
-                n_heads=self.model_config.num_attention_heads,
+                n_heads=num_kv_heads,
                 input_size=self.model_config.hidden_size,
-                n_embd_per_head=getattr(self.model_config, "d_kv", None),  # this is currently specific to T5-3B
+                n_embd_per_head=head_dim,
             )
             self.prefixes[adapter_name] = prefix_id
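For context, a hedged usage sketch of how the fix surfaces to users: attaching a prefix-tuning adapter to a GQA model with the library's standard API (the model id and adapter name below are illustrative placeholders, not part of the commit):

```python
# Usage sketch: standard adapters-library workflow; model id and adapter name are
# illustrative placeholders, not part of this commit.
from transformers import AutoModelForCausalLM
import adapters
from adapters import PrefixTuningConfig

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")  # GQA architecture
adapters.init(model)  # enable adapter support on the plain transformers model

# With this fix, the prefix key/value tensors are sized with num_key_value_heads
# instead of num_attention_heads, so shapes match in the attention forward pass.
model.add_adapter("gqa_prefix", config=PrefixTuningConfig())
model.train_adapter("gqa_prefix")
```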
