Hi JAX team (cc @jburnim), I'm testing custom Pallas RDMA kernels with TPU Interpret Mode (https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/g3doc/debugging.md#tpu-interpret-mode). My kernels reuse recv semaphores across multiple pipeline stages/algorithm steps, using capacity/ready semaphores to protect against overrun in the circular buffer of recv semaphores. I've noticed different behavior between the DMA execution modes:
Looking at `jax/_src/pallas/mosaic/interpret.py`, I noticed a TODO around line 576 (lines 576 to 580 at commit c850bef). Currently, `dmas_by_sem` is indexed only by semaphore ID, which (IIUC) might cause DMAs from different devices/sections to share the same queue when semaphores are reused. This seems to lead to out-of-order DMA execution in `on_wait` mode. I'm wondering whether that is WAI (working as intended). Questions:
Using distinct semaphores (i.e., more semaphores than there are steps) eliminates the issues, but this significantly increases semaphore usage. Any guidance would be appreciated! Thanks!
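To make the reuse pattern concrete, here is a minimal pure-Python sketch (not Pallas code) of a ring of recv semaphores. It assumes step `s` uses slot `s % num_sems`, and that the capacity/ready protocol only lets step `s` start its RDMA once step `s - num_sems` has been received and consumed. `slot_for_step`, `shared_slots`, and `may_start` are hypothetical helpers for illustration only.

```python
from collections import defaultdict


def slot_for_step(step: int, num_sems: int) -> int:
    """Recv semaphore slot used by `step` in a ring of `num_sems` semaphores."""
    return step % num_sems


def shared_slots(num_steps: int, num_sems: int) -> dict[int, list[int]]:
    """Map each slot to the steps that reuse it (only slots used more than once)."""
    users = defaultdict(list)
    for step in range(num_steps):
        users[slot_for_step(step, num_sems)].append(step)
    return {slot: steps for slot, steps in users.items() if len(steps) > 1}


def may_start(step: int, num_sems: int, consumed: set[int]) -> bool:
    """Assumed capacity rule (not a Pallas API): step s may reuse its slot only
    after step s - num_sems has been fully received *and* consumed."""
    prev = step - num_sems
    return prev < 0 or prev in consumed
```

For example, `shared_slots(6, 2)` gives `{0: [0, 2, 4], 1: [1, 3, 5]}`, while `shared_slots(6, 8)` is empty. The latter is the "more semaphores than steps" configuration, which avoids reuse entirely at the cost of more semaphores.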
Reply:
A data race detected by TPU Interpret Mode should never be a false positive. (There are no known issues here, but it is possible there is a bug that is permitting false positives.)
I suspect that some later stage/step's RDMA is signaling a semaphore while an earlier stage/step is still waiting for an earlier RDMA to signal the same semaphore, and this is leading to a real race.
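Here is a minimal sketch of why this is a real race rather than a false positive. It assumes a recv semaphore behaves like a plain counter, which is a simplification; this is not Pallas or TPU code. The hypothetical `step0_reads` delivers two RDMAs that share one recv semaphore and returns what the step-0 consumer reads once its wait succeeds.

```python
def step0_reads(arrival_order: list[int]) -> object:
    """Two RDMAs (steps 0 and 1) share one recv semaphore. Each writes its own
    buffer slot and then signals. The step-0 consumer waits for one signal and
    then reads slot 0. Returns what it reads."""
    recv_sem = 0  # toy counter standing in for the shared recv semaphore
    slots: dict[int, object] = {0: "stale", 1: "stale"}

    arrived = iter(arrival_order)
    # The consumer blocks until the counter is nonzero, i.e. until *any* one
    # RDMA has landed and signaled; it cannot tell which one it was.
    while recv_sem == 0:
        step = next(arrived)
        slots[step] = f"payload from step {step}"
        recv_sem += 1
    recv_sem -= 1    # the wait consumes one signal
    return slots[0]  # the step-0 consumer now reads its slot
```

`step0_reads([0, 1])` returns `"payload from step 0"`, but `step0_reads([1, 0])` returns `"stale"`: the step-0 wait was satisfied by step 1's signal before step 0's data had landed.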
If a second RDMA is started before an earlier RDMA using the same send/receive semaphores has completed, Pallas permits the second RDMA to complete and signal the semaphores before the first RDMA. But this will only happen in TPU Interpret Mode with `dma_execution_mode="on_wait"`, which is why this kind of race is…
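The difference between the two modes can be illustrated with a toy simulator. This is a hypothetical model, not the logic in `interpret.py`. It assumes `"eager"` executes a DMA as soon as it is started, and `"on_wait"` defers it until some device waits on its semaphore, keeping pending DMAs in a queue keyed only by semaphore ID. When a wait runs, this model pops the most recently started DMA, which is just one of the orderings the reply says Pallas permits.

```python
from collections import defaultdict


class ToyDmaEngine:
    """Hypothetical DMA engine; semaphore counts are omitted for brevity."""

    def __init__(self, mode: str):
        assert mode in ("eager", "on_wait")
        self.mode = mode
        self.pending = defaultdict(list)  # semaphore id -> started, unexecuted DMAs
        self.completed: list[str] = []    # DMA names in completion order

    def start(self, name: str, sem_id: int) -> None:
        if self.mode == "eager":
            self.completed.append(name)   # copy + signal happen immediately
        else:
            self.pending[sem_id].append(name)

    def wait(self, sem_id: int) -> None:
        if self.mode == "on_wait" and self.pending[sem_id]:
            # Any pending DMA on this semaphore may be executed; this model
            # picks the most recently started one to expose the hazard.
            self.completed.append(self.pending[sem_id].pop())


def completion_order(mode: str) -> list[str]:
    eng = ToyDmaEngine(mode)
    eng.start("rdma_step0", sem_id=0)  # both steps reuse recv semaphore 0
    eng.start("rdma_step1", sem_id=0)  # started before step 0 was waited on
    eng.wait(0)                        # step 0's wait
    eng.wait(0)                        # step 1's wait
    return eng.completed
```

In this model, `"eager"` completes `rdma_step0` first, while `"on_wait"` completes `rdma_step1` first. Giving each step its own semaphore (`sem_id=0` and `sem_id=1`) leaves each wait with only its own DMA to execute, so completion order matches start order in both modes, consistent with the observation that distinct semaphores make the issue disappear.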