The benchmark results show that a single invocation of the scan function takes about 15 microseconds. For the function I had (a GRU), this is relatively high: a hand-crafted function using intrinsics in native C++ achieves 3 microseconds, and an ONNX function with the same scan function design achieves 8 microseconds.
I suspect that one of the inefficiencies may be that PJRT_LoadedExecutable_Execute necessarily allocates new buffers for every invocation, along with overhead from the threading model, though I cannot prove this.
I am also happy for the scan function's new carry return variable to mutate the input carry variable in-place. However, I do not know if this is possible. After rewriting scan_fn to take Refs, it seems that jax.export.export does not support this.
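For the in-place carry idea, one thing that might be worth trying is buffer donation on the jitted step instead of Refs. The sketch below is only an illustration under stated assumptions: `scan_fn` here is a toy stand-in for the GRU step, and `HIDDEN`, `make_step` and `export_step` are hypothetical names. `donate_argnums` is a real `jax.jit` parameter, but whether the donation survives `jax.export.export` and is honored by `PJRT_LoadedExecutable_Execute` on the CPU plugin is not verified here, so the emitted module text would need to be inspected.

```python
import jax
import jax.numpy as jnp
from jax import export

HIDDEN = 4  # hypothetical carry size, chosen only for illustration


def scan_fn(carry, x):
    # Toy stand-in for a GRU step: the new carry has the same shape and
    # dtype as the input carry, so its buffer is a candidate for reuse.
    new_carry = jnp.tanh(carry + x)
    return new_carry, new_carry.sum()


def make_step():
    # donate_argnums=(0,) tells XLA it may reuse the carry's input buffer
    # for an output. The caller must not read the old carry afterwards.
    return jax.jit(scan_fn, donate_argnums=(0,))


def export_step():
    # Export the donated step and return the StableHLO text so the
    # argument attributes can be checked for aliasing/donation markers.
    spec = jax.ShapeDtypeStruct((HIDDEN,), jnp.float32)
    exported = export.export(make_step())(spec, spec)
    return exported.mlir_module()
```

Calling `make_step()` and feeding each returned carry back in as the next input carry gives the loop shape that donation is meant for. Printing the result of `export_step()` shows whether the exported module carries any aliasing information on the first argument.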
How can I run inference on this scan function faster?
Is there any intuition for how the runtimes scale as the computation gets more expensive? (These figures are for a GRU with very few parameters.) Maybe this method just has a high baseline cost but scales better to larger computations and data?
Are there alternatives that may be better suited for this task?
I am using the PJRT C API to run inference on a scan function.
My recipe for doing so involves:
pjrt_c_api_cpu_plugin.so from the XLA repo.