This package provides a set of benchmark scripts for profiling JAX performance on a varying number of CPU cores. JAX does not provide control over the number of cores it uses, so a common trick is to restrict the process to a subset of cores with taskset.
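To illustrate the taskset trick, here is a minimal Python sketch, not taken from this package, that builds taskset command lines for an increasing number of cores. The helper names `taskset_command` and `run_pinned` and the script `bench.py` are hypothetical. The sketch assumes Linux with taskset from util-linux, whose `--cpu-list` (`-c`) option accepts CPU indices such as `0-3`, and it assumes the CPUs are numbered contiguously from 0.

```python
import os
import shlex
import subprocess


def taskset_command(n_cores: int, command: list[str]) -> list[str]:
    """Hypothetical helper: prefix `command` with taskset so that the
    process may only run on CPU cores 0 to n_cores - 1."""
    if n_cores < 1:
        raise ValueError("n_cores must be at least 1")
    cpu_list = "0" if n_cores == 1 else f"0-{n_cores - 1}"
    return ["taskset", "--cpu-list", cpu_list, *command]


def run_pinned(n_cores: int, command: list[str]) -> None:
    """Hypothetical helper: run `command` restricted to the first
    n_cores cores (Linux only, requires taskset)."""
    subprocess.run(taskset_command(n_cores, command), check=True)


if __name__ == "__main__":
    # Print (rather than run) one command per core count;
    # "bench.py" is a placeholder for your own JAX script.
    max_cores = os.cpu_count() or 1
    for n in range(1, max_cores + 1):
        print(shlex.join(taskset_command(n, ["python3", "bench.py"])))
```

Because threads spawned by a process inherit its CPU affinity mask, any worker threads JAX starts are confined to the selected cores, which is what makes this approach work without JAX-specific settings.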
The benchmarks can be run by installing the package with pip and then running the benchmark-jax command:
python3 -m pip install git+https://github.com/ComPWA/jax-mini-benchmark@main
benchmark-jax
The resulting plot is written to jax-benchmark-$HOSTNAME.svg. If you do not want the plot to be shown directly, for instance when you run this command from a script, add the --no-show flag:
benchmark-jax --no-show
We recommend working with a virtual environment. If you have installed Miniconda, the project can easily be set up as follows:
git clone https://github.com/ComPWA/jax-mini-benchmark
cd jax-mini-benchmark
conda env create
conda activate jax-mini-benchmark
pre-commit install # optional, but recommended
See ComPWA's Help developing for more info.