A JAX port of FLUX.1 models using flax.nnx.
Important: The current codebase is designed to stay consistent with the original implementation, with minimal modifications. While it works as expected, it may not be the most efficient implementation. I plan to release an updated version soon that better follows JAX conventions and best practices.
Only tested on GPU so far.
There is currently no quantization support and no torch-style CPU offloading.
PRs are welcome.
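Because the code has only been tested on GPU, it can help to confirm which backend JAX will use before launching `main.py`. The following is a minimal sketch that uses only standard JAX calls (`jax.default_backend`, `jax.devices`). The helper `describe_backend` is hypothetical and is not part of flux-flax.

```python
import jax


def describe_backend() -> str:
    """Return the name of JAX's default backend, e.g. "gpu", "cpu" or "tpu"."""
    return jax.default_backend()


if __name__ == "__main__":
    backend = describe_backend()
    print("JAX default backend:", backend)
    print("Visible devices:", jax.devices())
    if backend != "gpu":
        # flux-flax has only been tested on GPU; other backends are untested.
        print("Warning: not running on a GPU backend.")
```

If this reports `cpu` on a machine with an NVIDIA GPU, the installed `jaxlib` is most likely a CPU-only build.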
```shell
git clone https://github.com/lkwq007/flux-flax.git
cd flux-flax
mamba create -p ./env python=3.10
mamba activate ./env
pip install -r requirements.txt
```

For interactive sampling, run:
```shell
python main.py --name <name>
```

Or, to generate a single sample, run the following (not recommended, as JIT compilation takes time):
```shell
python main.py --name <name> \
  --height <height> --width <width> --nonloop \
  --prompt "<prompt>"
```
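The advice to prefer the interactive loop follows from how `jax.jit` works. A function is traced and compiled on its first call for a given input shape and dtype, and later calls with matching inputs reuse the cached executable. The sketch below uses a toy function that counts its own traces to show this. `toy_step` and `timed_call` are hypothetical and unrelated to flux-flax's actual sampler. The sketch also assumes, without checking `main.py`, that changing `--height`/`--width` changes array shapes and therefore forces a recompile.

```python
import time

import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces (and then compiles) toy_step


@jax.jit
def toy_step(x):
    """Toy stand-in for a sampling step; NOT the actual FLUX denoiser."""
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    return jnp.tanh(x @ x.T).sum()


def timed_call(x):
    start = time.perf_counter()
    toy_step(x).block_until_ready()  # wait for JAX's async dispatch to finish
    return time.perf_counter() - start


x = jnp.ones((256, 256))
t_first = timed_call(x)   # traces + compiles, then runs
t_second = timed_call(x)  # same shape/dtype: reuses the cached executable
print(f"first call {t_first:.4f}s, second call {t_second:.4f}s, traces={trace_count}")

# A different input shape (analogous to a different image size) forces a retrace.
y = jnp.ones((128, 256))
timed_call(y)
print("traces after new shape:", trace_count)
```

In a long-running interactive session, the compilation cost is paid once per input shape and then amortized over many samples. A `--nonloop` run pays that cost to produce a single image and then exits.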