Generic, queue-based, multi-GPU PyTorch inference pipeline for large volumetric data (3D, or 5D with T, C, Z, Y, X axes).
- Block + patch based tiled volume processing with optional overlap / halo (see the tiling sketch below)
- Multi-threaded CPU stages (prep / writer) feeding multi-GPU inference workers
- Seam handling via trimming or weighted blending
- Optional global or per-block percentile normalization
- Queue & system monitoring (JSON metrics export)
- Model registry for plug‑and‑play custom architectures
- Optional AMP, TF32, and `torch.compile` acceleration
- Backed by `tensorstore` for flexible IO backends
User (editable):

```bash
pip install -e .
```

Development (linting, docs, tests):

```bash
python3 -m pip install --upgrade pip
pip install -e . --group dev
```

Optional extras (e.g. the UNet dependency):

```bash
pip install -e .[denoise-net]
```
```bash
python -m aind_torch_utils.run \
  --in-spec '{"driver": "zarr", "kvstore": "s3://my-bucket/in.zarr/0"}' \
  --out-spec '{"driver":"zarr","kvstore":{"driver":"s3","bucket":"my-bucket","path":"out.zarr"},"path":"0","metadata":{"shape":[1,1,1024,1024,1024],"chunks":[1,1,256,256,256],"dtype":"<u2"},"create":true,"delete_existing":true}' \
  --model-type denoise-net \
  --weights /data/BM4DNet-20250905-169-0.0073.pth \
  --t 0 --c 0 \
  --patch 64 64 64 \
  --overlap 12 \
  --block 256 256 256 \
  --batch 64 \
  --devices cuda:0 cuda:1 \
  --seam-mode trim \
  --trim-voxels 6 \
  --halo 8 \
  --max-inflight-batches 64 \
  --norm-lower 0.5 --norm-upper 99.9 \
  --min-blend-weight 0.05 \
  --prep-workers 4 \
  --writer-workers 4 \
  --metrics-json metrics.json \
  --metrics-interval 0.5
```
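The `--norm-lower 0.5 --norm-upper 99.9` flags above select percentile bounds for normalization. As a rough sketch of what percentile normalization typically computes (illustrative only; the package's exact formula and output scaling may differ):

```python
import numpy as np

def percentile_normalize(vol: np.ndarray, lower: float = 0.5, upper: float = 99.9) -> np.ndarray:
    """Map the [lower, upper] percentile intensity range to [0, 1], clipping outliers."""
    lo, hi = np.percentile(vol, [lower, upper])
    return np.clip((vol.astype(np.float32) - lo) / max(hi - lo, 1e-8), 0.0, 1.0)
```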
Specs can also be provided as paths to JSON files:

```bash
python -m aind_torch_utils.run \
  --in-spec "data/in_spec.json" \
  --out-spec "data/out_spec.json" \
  --model-type denoise-net \
  --weights /data/BM4DNet-20250905-169-0.0073.pth
```
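For reference, `data/in_spec.json` could hold the same tensorstore spec passed inline in the first example (assuming the spec files contain plain tensorstore JSON, which matches the inline usage above):

```json
{
  "driver": "zarr",
  "kvstore": "s3://my-bucket/in.zarr/0"
}
```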
```python
from aind_torch_utils.config import InferenceConfig
from aind_torch_utils.run import run
from aind_torch_utils.model_registry import ModelRegistry
from aind_torch_utils.utils import open_ts_spec

in_spec = {
    "driver": "zarr",
    "kvstore": "s3://bucket/path/input.zarr/0",
}

out_spec = {
    "driver": "zarr",
    "kvstore": {
        "driver": "s3",
        "bucket": "bucket",
        "path": "ns/out.zarr",
    },
    "path": "0",
    "metadata": {
        "shape": [1, 1, 256, 1024, 1024],
        "chunks": [1, 1, 256, 256, 256],
        "dtype": "<u2",
    },
    "create": True,
    "delete_existing": True,
}

input_store = open_ts_spec(in_spec)
output_store = open_ts_spec(out_spec)

cfg = InferenceConfig(
    patch=(64, 64, 64),
    overlap=10,
    block=(256, 256, 256),
    batch_size=16,
    devices=["cuda:0", "cuda:1"],
    seam_mode="trim",
    trim_voxels=5,
)

model = ModelRegistry.load_model("denoise-net", weights_path="weights.pth")

run(
    model=model,
    input_store=input_store,
    output_store=output_store,
    cfg=cfg,
    metrics_json="metrics.json",
    num_prep_workers=4,
    num_writer_workers=2,
)
```
Configuration note: the example above only sets a subset of the available fields. For the complete list of parameters, validation rules, and detailed descriptions, open `src/aind_torch_utils/config.py` and review the `InferenceConfig` class docstring and `Field` metadata.
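If `InferenceConfig` is a pydantic v2 model (the `Field` metadata mentioned above suggests pydantic; adjust for other versions), the fields can also be listed programmatically:

```python
from aind_torch_utils.config import InferenceConfig

# Print every field with its default and description (pydantic v2 API).
for name, field in InferenceConfig.model_fields.items():
    print(f"{name}: default={field.default!r} :: {field.description}")
```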
If `--metrics-json` is provided, the exported metrics include:

- Queue depths over time (prep & writer queues)
- System metrics (CPU %, RAM, and GPU if implemented)

Use these to diagnose stalls: for example, a GPU sitting idle while the prep queue is empty suggests increasing the number of prep workers or reducing IO latency.
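The exported JSON's exact schema is defined by the package's monitor; a quick, schema-agnostic way to see what was recorded before plotting anything:

```python
import json

with open("metrics.json") as f:
    metrics = json.load(f)

# Inspect the top-level structure first; field names are not guaranteed here.
if isinstance(metrics, dict):
    print("keys:", sorted(metrics))
elif isinstance(metrics, list) and metrics:
    print(f"{len(metrics)} samples; first sample: {metrics[0]}")
```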
Add your model directly to `src/aind_torch_utils/models.py` so it is automatically available when the package is imported.
Steps:

- Open `src/aind_torch_utils/models.py`.
- Define your `nn.Module` subclass.
- Add a loader function that returns an instance (optionally loading weights) and decorate it with `@ModelRegistry.register("your-name")`.
Snippet (showing how to append under the existing UNet registration):

```python
# ... existing imports and UNet registration ...
from typing import Optional

import torch
from torch import nn


class MySimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv3d(1, 1, 3, padding=1)

    def forward(self, x):  # x: (N, C=1, Z, Y, X)
        return self.conv(x)


@ModelRegistry.register("simple-net")
def load_simple_net(weights_path: Optional[str] = None) -> nn.Module:
    model = MySimpleNet()
    if weights_path:
        sd = torch.load(weights_path, map_location="cpu")
        model.load_state_dict(sd)
    return model
```
You can now run:

```bash
python -m aind_torch_utils.run --model-type simple-net ...
```
From Python:

```python
from aind_torch_utils.model_registry import ModelRegistry

print(ModelRegistry.list_models())
```
- Run formatting: `black . && isort .`
- Run tests: `pytest -q`
- Add docstrings; keep the public API minimal.
MIT