
A learning repository for JAX and FLAX, exploring automatic differentiation, JIT compilation, neural networks, and optimization techniques. Includes tutorials, notebooks, and hands-on experiments.


imdebamrita/JAX-FLAX


JAX-FLAX Learning Repository


This repository is dedicated to my exploration of JAX and FLAX, two powerful libraries for high-performance machine learning and deep learning research.

Overview

  • JAX: A high-performance numerical computing library with a NumPy-like API that provides automatic differentiation, GPU/TPU acceleration, and Just-In-Time (JIT) compilation through XLA.
  • FLAX: A flexible and extensible neural network library for JAX, designed to support research and production workloads.

Goals

  • Learn the fundamentals of JAX, including JIT compilation (jit), vectorization (vmap), and automatic differentiation (grad).
  • Understand how to build and train neural networks using FLAX.
  • Experiment with different architectures and optimization techniques.
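To make the training and optimization goals concrete, the sketch below fits a linear model to synthetic data with plain gradient descent, using `jax.value_and_grad` and `jax.jit`. The data, learning rate, and step count are assumptions for illustration; in practice an optimizer library such as Optax usually supplies the update rule, and this hand-written update stands in for it.

```python
import jax
import jax.numpy as jnp

# Synthetic data (assumed for illustration): y = 3x + 1 plus small Gaussian noise
key = jax.random.PRNGKey(0)
k_x, k_noise = jax.random.split(key)
x = jax.random.uniform(k_x, (100,))
y = 3.0 * x + 1.0 + 0.01 * jax.random.normal(k_noise, (100,))

LEARNING_RATE = 0.5  # hypothetical value, small enough to be stable for this problem


def loss_fn(params, x, y):
    # Mean squared error of the linear model w * x + b
    w, b = params
    pred = w * x + b
    return jnp.mean((pred - y) ** 2)


@jax.jit
def train_step(params, x, y):
    # One step of plain gradient descent: p <- p - lr * dL/dp for every leaf
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


params = (jnp.array(0.0), jnp.array(0.0))  # (w, b) initialised at zero
for step in range(500):
    params, loss = train_step(params, x, y)

w, b = params
print(f"w={float(w):.2f}, b={float(b):.2f}, loss={float(loss):.5f}")
```

Keeping `params` as a pytree (here a tuple) means the same `tree_map` update works unchanged for the nested parameter dictionaries that FLAX models produce.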

Getting Started

Notebooks

  • Introduction to JAX: Basics of JAX, including grad, jit, and vmap.
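The three transformations covered in that notebook can be sketched in a few lines. This is a minimal, hypothetical example (the function `f` is chosen only for illustration), not code copied from the notebook.

```python
import jax
import jax.numpy as jnp


def f(x):
    # A simple scalar function: f(x) = x^3 + 2x
    return x ** 3 + 2.0 * x


# grad: returns a new function computing f'(x) = 3x^2 + 2
df = jax.grad(f)
print(df(2.0))  # 14.0

# jit: compiles f with XLA; the result matches the un-jitted version
f_jit = jax.jit(f)
print(f_jit(2.0))  # 12.0

# vmap: maps the scalar derivative over a batch without a Python loop
xs = jnp.array([0.0, 1.0, 2.0])
print(jax.vmap(df)(xs))  # [ 2.  5. 14.]
```

The transformations compose, so `jax.jit(jax.vmap(jax.grad(f)))` is also valid and compiles the batched derivative into a single XLA computation.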

Resources

License

This project is for educational purposes. Feel free to use and modify the code as needed!


Happy learning with JAX & FLAX! 🚀
