BUG: ZarrTrace unnecessarily loads all data into memory after sampling is done #8104

@avm19

Describe the issue:

When ZarrTrace is used in pymc.sample(), entire sampled arrays are unnecessarily loaded into memory during the finalisation stage, after the actual sampling is done. In certain workloads this exhausts the available memory and crashes the process.

Digging deeper, I found that the culprit is ZarrTrace.split_warmup() in pymc/backends/zarr.py, specifically these lines:

(pymc/pymc/backends/zarr.py, lines 795 to 803 at commit 212ed1c)

warmup_array = warmup_group.array(
    name=name,
    data=array[warmup_idx],
    chunks=array.chunks,
    dtype=dtype,
    fill_value=fill_value,
    object_codec=object_codec,
    compressor=self.compressor,
)

In line 797 (data=array[warmup_idx]), array is a zarr.Array, whose __getitem__() returns a numpy.ndarray; that is, the entire selection is read into memory at once before it is written to the new warmup array.
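To illustrate the alternative, here is a minimal sketch (not PyMC's code or its eventual fix) of copying a selection of draws in fixed-size blocks, so that only one block is held in memory at a time. It assumes a (chain, draw, ...) array layout and that both source and destination support NumPy-style slicing and slice assignment, which numpy.ndarray and zarr.Array both do. The helper copy_draws_in_blocks and its block_size parameter are hypothetical.

```python
import numpy as np


def copy_draws_in_blocks(src, dst, start, stop, block_size=100):
    """Hypothetical helper: copy src[:, start:stop] into dst[:, :stop - start].

    Assumes a (chain, draw, ...) layout. Reading src[:, a:b] materialises only
    that block in memory (for a zarr.Array it returns a numpy.ndarray), so peak
    memory is bounded by one block rather than by the whole selection.
    """
    for block_start in range(start, stop, block_size):
        block_stop = min(block_start + block_size, stop)
        block = src[:, block_start:block_stop]  # only this block is in memory
        dst[:, block_start - start:block_stop - start] = block


# Usage with in-memory NumPy arrays standing in for on-disk zarr arrays.
src = np.arange(2 * 10 * 3, dtype=float).reshape(2, 10, 3)  # 2 chains, 10 draws
warmup = np.empty((2, 4, 3))
copy_draws_in_blocks(src, warmup, start=0, stop=4, block_size=3)
```

With real zarr arrays, choosing block_size as a multiple of the chunk length along the draw axis would likely mean each chunk is read and written only once; that is a suggestion, not something this issue specifies.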

Reproducible code example:

import numpy as np
import zarr
import zarr.storage
import pymc as pm
from pymc.backends.zarr import ZarrTrace

n1, n2 = 300, 300
coords = {"coord1": range(n1), "coord2": range(n2)}
with pm.Model(coords=coords) as model:
    mu1 = pm.Data("mu1", np.zeros(shape=n1), dims="coord1")
    mu2 = pm.Data("mu2", np.zeros(shape=n2), dims="coord2")
    x1 = pm.Normal("x1", mu=mu1, sigma=1.0, dims="coord1")
    x2 = pm.Normal("x2", mu=mu2, sigma=2.0, dims="coord2")
    z = pm.Deterministic("z", x1[:, None] * x2[None, :])
pm.model_to_graphviz(model, graph_attr={"size": "3"})

store = zarr.storage.DirectoryStore("zarr")
zt = ZarrTrace(store)
with model:
    idata = pm.sample(
        draws=5_000,
        tune=1_000,
        trace=zt,
        return_inferencedata=False,  # no post-processing needed
        discard_tuned_samples=False,  # no post-processing needed
        compute_convergence_checks=False,  # no post-processing needed
    )

Error message:

The process runs out of memory.
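A back-of-the-envelope estimate shows why the example above can exhaust memory if a selection of z is read in one go. This sketch assumes float64 storage (PyTensor's default floatX) and 4 chains (the pm.sample default on a machine with at least four cores); the chain count is an assumption about the reporter's setup, and the helper array_nbytes is hypothetical.

```python
def array_nbytes(n_chains, n_draws, event_shape, itemsize=8):
    """Bytes needed to hold a (chain, draw, *event_shape) array in memory.

    itemsize=8 assumes float64, PyTensor's default floatX.
    """
    n_elements = n_chains * n_draws
    for dim in event_shape:
        n_elements *= dim
    return n_elements * itemsize


# Assumed: 4 chains; the Deterministic z has shape (300, 300) per draw.
z_posterior = array_nbytes(4, 5_000, (300, 300))  # 14_400_000_000 bytes
z_warmup = array_nbytes(4, 1_000, (300, 300))  # 2_880_000_000 bytes
print(f"z posterior: {z_posterior / 1e9:.2f} GB")
print(f"z warmup:    {z_warmup / 1e9:.2f} GB")
```

Under these assumptions, the posterior draws of z alone take about 14.4 GB and its warmup draws about 2.9 GB, which is more than many machines can hold in RAM at once.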

PyMC version information:

0+untagged.1002.g440ca46
0+untagged.1077.g212ed1c (HEAD commit of main at the time of writing)

Context for the issue:

One would use Zarr to store samples on disk (or other storage) in order to (1) have them in case of interruption and/or (2) use less main memory. As of now, use case (1) requires manual recovery, and sampling cannot be resumed, while use case (2) is effectively prevented by this bug. This makes ZarrTrace of little practical use in its current implementation. I know that some work is in progress to address this.
