Profiling large jax scripts #32581
Hi, I have previously tried a couple of the profilers documented on the website. As you said, for large scripts the output can get very hard to digest. Although it is not the most trivial option, I found that the NVIDIA Nsight profiler works well for my use case (on NVIDIA GPUs). There are a couple of ways to add custom annotations to a function or a portion of the script using `import nvtx`:

```python
import nvtx

# as a decorator
@nvtx.annotate("my_func range", color="red")
def my_func():
    do_something()

# or as a context manager
with nvtx.annotate("for_loop", color="green"):
    do_something_else()

# or with explicit start/end markers
rng = nvtx.start_range(message="my_message", color="blue")
# ... do something ... #
nvtx.end_range(rng)
```

There are a number of traces you can enable optionally (extra traces may cause some slowdown); you specify them when you call your script on the command line.
The official documentation has an exhaustive list of options. This command will create a report file. Note: this also works for MPI jobs, and you can merge multiple reports into one. As noted in the JAX docs, jitted functions are opaque to any type of profiler, so you cannot add an annotation to a subset of a jitted function, but you can still see what the hardware is doing (resource usage, memory transfers, device copies, etc.). Jitted functions usually show up in the XLA traces under their names, but that part gets messy really quickly, since internal JAX functions are also jitted. If you have a rough idea of which part of the code is the bottleneck, you can start putting more annotations around it. I had shared multiple Nsight screenshots in #29470 before. I would also be interested in other people's experiences with other profilers for large scripts. Hope this helps!
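The exact command line was elided above; as a hedged sketch of what a typical Nsight Systems invocation looks like (the trace list, report name, and script name here are assumptions, not from the original post; see `nsys profile --help` for the full option list):

```shell
# Hypothetical invocation: -t/--trace selects which traces to record
# (more traces means more overhead); -o names the output report file.
nsys profile -t nvtx,cuda,osrt -o my_report python my_script.py
```

The resulting report can then be opened in the Nsight Systems GUI, where the nvtx annotations show up as named ranges on the timeline.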
Hi,
Thank you, maintainers, for this awesome project.
I have developed my own library of reinforcement learning agents in JAX; it works incredibly well and lets me run experiments end-to-end on the GPU. However, given my rather messy understanding of JAX's inner workings, I am sure I am still underperforming relative to what JAX could achieve. As a result, I would like to profile my whole codebase to identify the performance bottlenecks.
For a Python program I would use cProfile; for JAX programs, however, I am a bit puzzled. I tried the https://docs.jax.dev/en/latest/profiling.html tutorial but ran into several issues. TensorBoard seems to crash for large traces, so I rescoped my analysis to a single sub-function:
I do get a trace for this, but I do not really know what to do with it, or whether it worked in the first place:
Is this the expected output? If so, how should I proceed to identify bottlenecks from there? The trace seems to show mostly JAX-internal functions rather than my own; is it not possible to see mine?
I should also mention that, for debugging/profiling, I run entirely on CPU. Should I switch to a machine with a GPU to profile properly? Sorry if the question is too broad; overall, I would really love to learn the best practices for writing and profiling JAX programs.
If the JAX profiler is not the right tool for this, could someone point me towards the proper tool for this type of analysis?
PS: I have skimmed similar discussions on this topic, but most seemed to run into the same problems I have without reaching a resolution, so I really think some insight on this subject would benefit the community.