This repository is dedicated to my exploration of JAX and FLAX, two powerful libraries for high-performance machine learning and deep learning research.
- JAX: A high-performance numerical computing library that enables automatic differentiation, GPU/TPU acceleration, and Just-In-Time (JIT) compilation.
- FLAX: A flexible and extensible neural network library for JAX, designed to support research and production workloads.
- Learn the fundamentals of JAX, including JIT compilation, vectorization (`vmap`), and automatic differentiation (`grad`).
- Understand how to build and train neural networks using FLAX.
- Experiment with different architectures and optimization techniques.
- Introduction to JAX: Basics of JAX, including `grad`, `jit`, and `vmap`.
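A quick sketch of the three core transformations (the function `f` is just an example):

```python
import jax
import jax.numpy as jnp

def f(x):
    return x ** 2 + 3.0 * x

# grad: automatic differentiation of a scalar-valued function
df = jax.grad(f)       # df(x) = 2x + 3
print(df(2.0))         # 7.0

# jit: compile with XLA for fast repeated calls
fast_f = jax.jit(f)
print(fast_f(2.0))     # 10.0

# vmap: vectorize over a batch dimension without writing a loop
batched_df = jax.vmap(df)
print(batched_df(jnp.array([0.0, 1.0, 2.0])))  # [3. 5. 7.]
```

These transformations compose freely, e.g. `jax.jit(jax.vmap(jax.grad(f)))`.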
- JAX-FLAX YouTube Playlist by Aleksa Gordić - The AI Epiphany (I am following this playlist)
- JAX Documentation
- FLAX Documentation
- JAX/FLAX Examples
This project is for educational purposes. Feel free to use and modify the code as needed!
Happy learning with JAX & FLAX! 🚀
