Description
JAX's `jax.nn.dot_product_attention` supports grouped query attention (GQA). The number of key/value heads may differ from the number of query heads, as long as the number of query heads is a multiple of it. Flax's `nnx.dot_product_attention` differs from JAX's because it can sow the attention weights into a Module, and it doesn't support GQA yet. We should add it.
Example:
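import numpy as np
import jax
import jax.numpy as jnp
from flax import nnx

B = 2  # batch size
S = 5  # key/value sequence length
T = 4  # query sequence length
N = 6  # number of query heads
K = 3  # number of key/value heads (N must be a multiple of K)
H = 7  # head dimension

query = jnp.array(np.random.randn(B, T, N, H))
key = jnp.array(np.random.randn(B, S, K, H))
value = key

jax.nn.dot_product_attention(query, key, value)  # works: the N query heads are grouped over the K key/value heads
nnx.dot_product_attention(query, key, value)  # fails: key must have the same number of heads as query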