Solved by setting XLA flags, following #29031: enable cuBLASLt, allow the cuBLAS fallback, and disable CUDA command buffers.

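# Start clean so no stale flags leak in.
unset XLA_FLAGS
# Enable cuBLASLt for GEMMs, allow falling back to cuBLAS,
# and disable CUDA command buffers (empty value).
export XLA_FLAGS="--xla_gpu_enable_cublaslt=true \
                  --xla_gpu_cublas_fallback=true \
                  --xla_gpu_enable_command_buffer=''"
python - <<'PY'
import jax
import jax.numpy as jnp

# JIT-compile a plain matrix multiply so XLA emits a GEMM.
f = jax.jit(lambda x, y: x @ y)
a = jnp.ones((256, 256), jnp.float32)
b = a
# block_until_ready() forces execution, so a runtime failure surfaces here.
f(a, b).block_until_ready()
print("OK: cuBLASLt + fallback")
PY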
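
If you launch from Python instead of a shell (for example in a notebook), the same flags can be set through `os.environ` before JAX is imported. This is a minimal sketch, not part of the original fix: `set_xla_flags` and `keep_existing` are hypothetical names (not a JAX API), and the whitespace splitting assumes no existing flag value contains spaces.

```python
import os

# The three flags from the shell recipe. An empty value for
# xla_gpu_enable_command_buffer disables CUDA command buffers.
FLAGS = {
    "xla_gpu_enable_cublaslt": "true",
    "xla_gpu_cublas_fallback": "true",
    "xla_gpu_enable_command_buffer": "",
}


def set_xla_flags(flags, keep_existing=False):
    """Write `flags` into XLA_FLAGS (hypothetical helper, not a JAX API).

    keep_existing=False mirrors `unset XLA_FLAGS` in the shell recipe.
    keep_existing=True keeps unrelated flags already in XLA_FLAGS and
    replaces any flag that is also named in `flags`.
    Assumes no existing flag value contains whitespace.
    """
    existing = os.environ.get("XLA_FLAGS", "") if keep_existing else ""
    # Drop existing tokens whose flag name we are about to set.
    kept = [
        tok for tok in existing.split()
        if tok.lstrip("-").split("=", 1)[0] not in flags
    ]
    new = [f"--{name}={value}" for name, value in flags.items()]
    os.environ["XLA_FLAGS"] = " ".join(kept + new)
    return os.environ["XLA_FLAGS"]


if __name__ == "__main__":
    print(set_xla_flags(FLAGS))
```

Call it first, then `import jax` and rerun the 256x256 matmul check. XLA parses `XLA_FLAGS` once when the backend initializes, so changes made after JAX has already set up the GPU backend are generally not picked up.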

Answer selected by NullXeronier
Category: Q&A