Description
What happened?
While debugging something else, I encountered an edge case in which the result of parallel matrix multiplication is wrong. The error is detailed below and appears only when run with two MPI tasks; I am not sure what makes this case special.
I first ran into this issue with QR decomposition. You can try:
import heat as ht
comm = ht.comm
A = ht.random.randn(comm.size, comm.size, dtype=ht.float32, split=0)
QR = ht.linalg.qr(A)
matmul_success_loc = ht.allclose(QR.Q @ QR.R, A)
matmul_success_glob = ht.allclose((QR.Q.resplit(None) @ QR.R.resplit(None)).resplit(QR.Q.split), A)
print(matmul_success_loc, matmul_success_glob)
which will print False, True.
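As a single-process baseline (no MPI splits involved), the reconstruction check can be sketched with plain PyTorch, which heat uses as its backend; `torch.linalg.qr` stands in here for `ht.linalg.qr`, so this only confirms that the serial reference computation is sound, not anything about the split case:

```python
import torch

n = 4
A = torch.randn(n, n, dtype=torch.float32)
Q, R = torch.linalg.qr(A)
# The reconstruction Q @ R should match A up to float32 tolerance
print(torch.allclose(Q @ R, A, atol=1e-5))
```

This consistently prints True, which is why the split-dependent False above points at the distributed matmul rather than at the QR factors themselves.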
I am a bit lost here; in particular, the matrix multiplication function is very long and hard to follow. I am reporting the bug because it seems serious, in the hope that somebody can give me some pointers.
Code snippet triggering the error
import heat as ht
split = 0
shape = (4, 3)
A = ht.ones(shape, split=split)
B = ht.ones(shape[::-1], split=split)
C = A @ B
C_glob = (A.resplit(None) @ B.resplit(None)).resplit(C.split)
print(C)
print(C_glob)
Error message or erroneous outcome
Output when run with two tasks:
DNDarray([[3., 3., 3., 3.],
[6., 6., 6., 6.],
[3., 3., 3., 3.],
[6., 6., 6., 6.]], dtype=ht.float32, device=cpu:0, split=0)
DNDarray([[3., 3., 3., 3.],
[3., 3., 3., 3.],
[3., 3., 3., 3.],
[3., 3., 3., 3.]], dtype=ht.float32, device=cpu:0, split=0)
Version
main (development branch)
Python version
3.13
PyTorch version
2.9
MPI version
OpenMPI 5.0.8
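For reference, the expected result of the toy snippet can be checked serially with plain PyTorch (no MPI involved): A is 4×3 ones and B is 3×4 ones, so every entry of C is the dot product of a row of three ones with a column of three ones, i.e. exactly 3. This matches the resplit(None) output above and shows the rows of 6 in the split result are wrong:

```python
import torch

shape = (4, 3)
A = torch.ones(shape)
B = torch.ones(shape[::-1])  # 3 x 4
C = A @ B
# Every entry should be 3.0 (sum of three 1*1 products)
print(C)
print(torch.all(C == 3.0).item())
```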