Has anybody investigated providing an array-agnostic way to leverage the `torch.compile` and `jax.jit` decorators in array-api-extra?
This might be useful for array-API-consuming libraries such as SciPy or scikit-learn. For array API namespaces without JIT compiler support, `xpx.compile` would just be a no-op decorator. For torch and JAX, dispatching to an actual JIT compiler could unlock significant speed-ups and memory-usage improvements.
However, those decorators accept many kwargs with seemingly very little overlap:
- https://docs.jax.dev/en/latest/_autosummary/jax.jit.html
- https://docs.pytorch.org/docs/stable/generated/torch.compile.html
Maybe `xpx.compile` could accept arbitrary kwargs scoped by the name of the underlying namespace, without attempting to map common compiler semantics onto each other:
```python
import array_api_extra as xpx  # xpx.compile is the API proposed here; it does not exist yet

@xpx.compile(
    torch=dict(options={"triton.cudagraphs": True}, fullgraph=True),
    jax=dict(static_argnames=["n"]),
)
def some_array_function(array, n):
    ...
```

I don't have enough experience to tell whether calling those decorators with their default arguments is useful in practice.