Commit e10ebf0
Fix Prefix Finetuning for Group Query Attention (GQA) (#825)
Resolves: #819
## Problem
PrefixTuning currently fails on modern architectures that use
**Grouped-Query Attention (GQA)** (e.g., Llama 3.1), raising shape
mismatch errors in the attention forward pass.
Issues:
- The layer assumes `num_attention_heads` equals the number of KV heads, which does not hold for GQA.
- It computes the per-head dimension by dividing `hidden_size` by the KV head count, which only works for standard MHA (illustrated below).
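As a concrete illustration, here is the arithmetic with Llama-3.1-8B-style config values (hidden_size 4096, 32 attention heads, 8 KV heads); the variable names are for illustration only and do not come from the library:

```python
# Illustrative values matching a Llama-3.1-8B-style config (assumed for this example).
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8   # GQA: fewer KV heads than query heads

# The attention layer stores its KV states with this per-head dimension:
expected_head_dim = hidden_size // num_attention_heads   # 128

# Dividing by the KV head count only matches under MHA, where both head counts coincide:
buggy_head_dim = hidden_size // num_key_value_heads      # 512

assert buggy_head_dim != expected_head_dim  # prefix tensors no longer line up with the KV states
```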
## Solution
This PR updates `PrefixTuningLayer.add_adapter` to properly support GQA
and similar mechanisms:
- **Correct head count:** Prefer `config.num_key_value_heads` when
available, falling back to `num_attention_heads`.
- **Robust per-head dim:**
- Use `config.d_kv` if defined (e.g., T5).
- Else compute as `hidden_size // num_attention_heads`.
This ensures the prefix tensors align with the model's internal KV states across **MHA,
GQA, and MQA**, fixing Llama 3.1 while preserving compatibility with
existing models; the resolution logic is sketched below.
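A minimal sketch of the head-count and per-head-dimension resolution described above. This is not the exact code from the PR: the function name is made up, and the attribute names follow common Hugging Face config conventions.

```python
def resolve_prefix_kv_shape(config):
    """Return (n_kv_heads, head_dim) to use for the prefix key/value tensors."""
    # Prefer the dedicated KV head count (GQA/MQA); fall back to the query head count (MHA).
    n_kv_heads = getattr(config, "num_key_value_heads", None) or config.num_attention_heads

    # Per-head dimension: use d_kv when the config defines it (e.g., T5);
    # otherwise derive it from the *query* head count, not the KV head count.
    d_kv = getattr(config, "d_kv", None)
    head_dim = d_kv if d_kv is not None else config.hidden_size // config.num_attention_heads

    return n_kv_heads, head_dim
```

With the Llama-3.1-8B-style values from the earlier illustration, this yields `(8, 128)`, matching the per-head shape of the model's KV states.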
1 file changed: +7 −2 lines changed