
Support GQA in nnx dot_product_attention #5177

@samanklesaria

Description


JAX's jax.nn.dot_product_attention supports grouped-query attention (GQA) by allowing the number of key/value heads to differ from the number of query heads (the query head count must be a multiple of the key/value head count). Flax's nnx.dot_product_attention, which differs from JAX's because it can sow the attention weights into a Module, does not currently support this. We should add it.

Example:

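```python
import numpy as np
import jax
import jax.numpy as jnp
from flax import nnx

B = 2  # batch size
S = 5  # key/value sequence length
T = 4  # query sequence length
N = 6  # number of query heads
K = 3  # number of key/value heads (N is a multiple of K)
H = 7  # head dimension

query = jnp.array(np.random.randn(B, T, N, H))
key = jnp.array(np.random.randn(B, S, K, H))
value = key

jax.nn.dot_product_attention(query, key, value)  # works: 6 query heads share 3 key/value heads

nnx.dot_product_attention(query, key, value)  # fails: requires matching query/key head counts
```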
