Description
Describe the issue:
When ZarrTrace is used in pymc.sample(), entire sampled arrays are unnecessarily loaded into memory during finalisation, after sampling itself has finished. In certain workloads this exhausts the available memory and crashes the process.
Digging deeper, I found the culprit in pymc/backends/zarr.py, in ZarrTrace.split_warmup(), specifically these lines:
Lines 795 to 803 in 212ed1c
```python
warmup_array = warmup_group.array(
    name=name,
    data=array[warmup_idx],
    chunks=array.chunks,
    dtype=dtype,
    fill_value=fill_value,
    object_codec=object_codec,
    compressor=self.compressor,
)
```
In line 797, `array` is a `zarr.Array`, whose `__getitem__()` returns a `numpy.ndarray`. In other words, `array[warmup_idx]` reads the entire selection into memory at once instead of copying it chunk by chunk.
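One way to keep memory bounded, sketched here under assumptions and not as PyMC's actual fix, is to create the destination array without `data=` (for example via `warmup_group.create(...)` with an explicit `shape`, assuming the zarr v2 group API) and then fill it block by block. The helper `copy_in_blocks` below is hypothetical. It assumes the draws being split lie along the first axis of both arrays (adapt the slicing if the draw axis is elsewhere) and that both objects support basic slice reads and slice assignment, as `zarr.Array` does. Setting `block_len` to `array.chunks[0]` would make each write cover whole chunks.

```python
def copy_in_blocks(src, dst, start, stop, block_len):
    """Copy src[start:stop] into dst[0:stop - start], one block at a time.

    Hypothetical helper: at most `block_len` leading-axis entries are held
    in memory at any moment. Works with objects that support slice reads
    and slice assignment on the first axis (e.g. zarr.Array,
    numpy.ndarray, or a pre-sized Python list).
    """
    if block_len <= 0:
        raise ValueError("block_len must be positive")
    for block_start in range(start, stop, block_len):
        block_stop = min(block_start + block_len, stop)
        # Reading a slice of a zarr.Array decodes only the chunks it touches,
        # so only this block is materialised before being written out.
        dst[block_start - start : block_stop - start] = src[block_start:block_stop]
    return dst


# Usage with plain lists standing in for on-disk arrays:
source = list(range(10))
target = [None] * 4
copy_in_blocks(source, target, 2, 6, 3)  # target becomes [2, 3, 4, 5]
```

The point of the design is that peak memory scales with `block_len` rather than with the number of draws, which is what `data=array[warmup_idx]` currently does.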
Reproducible code example:
```python
import numpy as np
import zarr
import zarr.storage

import pymc as pm
from pymc.backends.zarr import ZarrTrace

n1, n2 = 300, 300
coords = {"coord1": range(n1), "coord2": range(n2)}

with pm.Model(coords=coords) as model:
    mu1 = pm.Data("mu1", np.zeros(shape=n1), dims="coord1")
    mu2 = pm.Data("mu2", np.zeros(shape=n2), dims="coord2")
    x1 = pm.Normal("x1", mu=mu1, sigma=1.0, dims="coord1")
    x2 = pm.Normal("x2", mu=mu2, sigma=2.0, dims="coord2")
    # z has 300 x 300 = 90,000 entries per draw
    z = pm.Deterministic("z", x1[:, None] * x2[None, :])

pm.model_to_graphviz(model, graph_attr={"size": "3"})

store = zarr.storage.DirectoryStore("zarr")
zt = ZarrTrace(store)

with model:
    idata = pm.sample(
        draws=5_000,
        tune=1_000,
        trace=zt,
        return_inferencedata=False,  # no post-processing needed
        discard_tuned_samples=False,  # no post-processing needed
        compute_convergence_checks=False,  # no post-processing needed
    )
```
Error message:
The process runs out of memory.
PyMC version information:
0+untagged.1002.g440ca46
0+untagged.1077.g212ed1c (HEAD commit of main at the time of writing)
Context for the issue:
One would use Zarr to store samples on disk (or other storage) in order to (1) keep them in case of interruption and/or (2) use less main memory. As of now, use case (1) requires manual recovery, and sampling cannot be resumed, while use case (2) is effectively prevented by this bug. This makes ZarrTrace of little practical use in its current implementation. I know that some work is in progress to address this.