A flexible and efficient implementation of Flash Attention 2.0 for JAX, supporting multiple backends (GPU/TPU/CPU) and platforms (Triton/Pallas/JAX).
Updated Mar 4, 2025 - Python
A FlashAttention backwards-over-backwards ⚡🔙🔙
Compute ZK-friendly hashes (MiMC and Poseidon) of arbitrary input over a variety of elliptic curves.
dLLM training implementation in pure JAX/Flax (without PyTorch) for Google TPUs (v4/v5e/v6e). #TPUSprint #TRC
Benchmarking the JAX Pallas implementation of a custom RNN against alternatives
Packet-Switched Attention for stable 2-bit quantized MoE inference, with variance-aware routing and Protocol C benchmarks.
Flash Attention from first principles on TPU using JAX Pallas.
Repository holding the core components used to build a Pallas Systems website.
术 (Shu) — The first GPU-accelerated MSM for the Pallas curve. Part of the HanFei 韩非 series.
Lean 4 formalization of the Pasta curves (Pallas and Vesta) for Zcash's Halo 2 — primality proofs and IsElliptic instances
SuperNova (Pasta) proof generator & verifier with CI and frozen fixtures