Conversation

@ronamit commented Sep 29, 2025

Add batch_converter hook + multiprocessing support to experimental.pytorch.AnnLoader

This PR solves two major AnnLoader limitations:

  1. Batch-level transformation: Add batch_converter parameter for advanced post-processing
  2. Multiprocessing support: Enable num_workers > 0 (previously crashed due to unpicklable AnnCollectionView)

Motivation

AnnLoader currently offers only element-wise converters via the convert mapping (e.g. convert["X"]), which is insufficient for:

  • Returning dict instead of AnnCollectionView
  • On-the-fly augmentations using both .X and .obs
  • PyTorch Lightning integration with specific batch signatures
  • Adding derived fields or metadata to batches

Additionally, num_workers > 0 has never worked, preventing parallel data loading in production workflows.

Implementation

1. batch_converter parameter

AnnLoader(adata, batch_size=128, batch_converter=my_fn, num_workers=4)

An optional callable applied to each batch before it is returned to the user. Fully backward compatible.
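For example, a custom converter can build on the batch_dict_converter helper described in section 3 below and attach a derived field (a minimal sketch; add_log1p and the "x_log1p" key are illustrative names, not part of this PR, and .X is assumed to be float-typed):

import torch
from anndata.experimental.pytorch import AnnLoader, batch_dict_converter

def add_log1p(batch):
    # Start from the dict produced by the batch_dict_converter helper added in this PR,
    # then add a derived field before the batch reaches the training loop.
    out = batch_dict_converter(batch)
    out["x_log1p"] = torch.log1p(out["x"])  # assumes out["x"] is a float tensor
    return out

loader = AnnLoader(adata, batch_size=128, batch_converter=add_log1p, num_workers=4)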

2. Multiprocessing support

  • __getstate__/__setstate__ hooks enable AnnCollectionView pickling across worker processes
  • Worker-side batch conversion via custom collate_fn
  • Per-worker HDF5 file handles following h5py best practices
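In outline, the pickling pattern looks roughly like the sketch below (illustrative only; _LazyH5Backing and its attributes are hypothetical names, not the actual classes touched by this PR):

class _LazyH5Backing:
    """Holds a file path instead of an open handle so the object can cross process boundaries."""

    def __init__(self, path):
        self.path = path
        self._file = None  # opened lazily, once per worker process

    def __getstate__(self):
        # Drop the unpicklable h5py handle before the object is sent to a worker.
        state = self.__dict__.copy()
        state["_file"] = None
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)

    @property
    def file(self):
        # Each worker re-opens its own handle, per h5py best practice.
        if self._file is None:
            import h5py
            self._file = h5py.File(self.path, "r")
        return self._file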

3. Helper converter

batch_dict_converter converts an AnnCollectionView into a dict[str, Tensor], with an "x" key for .X and one key per .obs column.

Usage

from anndata.experimental.pytorch import AnnLoader, batch_dict_converter

# Single-threaded
loader = AnnLoader(adata, batch_size=256, batch_converter=batch_dict_converter, num_workers=0)

# Multiprocessing (now possible)
loader = AnnLoader(adata, batch_size=256, batch_converter=batch_dict_converter, num_workers=4)

for batch in loader:
    x = batch["x"]          # torch.Tensor
    obs_fields = batch["cell_type"]  # obs columns as tensors
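Because batches arrive as plain dicts, they slot directly into frameworks that expect a fixed batch signature, such as PyTorch Lightning (a sketch; LitClassifier, the layer sizes, and the assumption that "cell_type" is already numerically encoded are all illustrative, not part of this PR):

import torch
import lightning as L  # or pytorch_lightning, depending on the installed package

class LitClassifier(L.LightningModule):
    def __init__(self, n_vars, n_classes):
        super().__init__()
        self.net = torch.nn.Linear(n_vars, n_classes)

    def training_step(self, batch, batch_idx):
        # batch is the dict produced by batch_dict_converter via batch_converter
        logits = self.net(batch["x"])
        # assumes "cell_type" is already an integer-encoded label tensor
        return torch.nn.functional.cross_entropy(logits, batch["cell_type"].long())

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

# L.Trainer(max_epochs=1).fit(LitClassifier(adata.n_vars, n_classes), train_dataloaders=loader)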

Performance

Scenario          master    this PR
Single-threaded   133 ms    134 ms (+0.8%)
Multiprocessing   Crashes   Works

Negligible overhead for single-threaded use. Primary benefit is enabling parallel data loading for I/O-bound workflows.

Testing

  • Comprehensive unit tests for both single and multi-threaded modes
  • Validates backward compatibility (no converter → AnnCollectionView)
  • Documents why multiprocessing needs this PR's changes, in contrast to a vanilla DataLoader
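A backward-compatibility check along these lines (an illustrative sketch, not the exact test added here) confirms the default behaviour is unchanged:

import pytest

torch = pytest.importorskip("torch")

def test_no_converter_keeps_view_type(adata):  # `adata` assumed to be a small AnnData fixture
    from anndata.experimental.pytorch import AnnLoader

    loader = AnnLoader(adata, batch_size=8)
    batch = next(iter(loader))
    # Without batch_converter, batches are still AnnCollectionView objects.
    assert type(batch).__name__ == "AnnCollectionView"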

Backward compatible: the optional parameter defaults to None; no breaking changes.

ronamit and others added 15 commits September 29, 2025 11:04
- Add batch_converter parameter for advanced batch-level post-processing
- Enable multiprocessing (num_workers>0) via AnnCollectionView pickling
- Implement worker-side batch conversion for true parallelism
- Add comprehensive tests for both single and multi-threaded modes
- Include helper batch_dict_converter for common dict format
- All tests pass, pre-commit hooks clean, backward compatible

Solves two major AnnLoader limitations in a unified implementation:
1. Batch-level transformation (vs element-wise convert)
2. Multiprocessing support (was broken due to unpicklable AnnCollectionView)

Enables production ML workflows with PyTorch Lightning integration,
data augmentation, balanced sampling, and parallel data loading.
…pport

- Fix AnnLoader docstring to remove incorrect multiprocessing limitation
- Add batch_dict_converter to API documentation
- Add release notes fragment for PR scverse#2135
- Document new batch_converter parameter and multiprocessing capabilities

The batch_converter parameter now works seamlessly with both single-threaded
and multi-threaded data loading, enabling faster PyTorch training workflows.
- Move tests from src/anndata/tests/pytorch/ to tests/pytorch/
- Follow standard anndata test organization pattern
- Add __init__.py to make pytorch test package discoverable
- Tests for batch_converter parameter and multiprocessing support
- Add conditional torch imports using find_spec() pattern in converters.py
- Make batch_dict_converter import conditional with helpful error message
- Add pytest.importorskip('torch') to PyTorch test files
- Fix linting warnings by using += instead of .append() for __all__

This resolves CI test collection errors when torch is not available while
maintaining full functionality when torch is installed.
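The guarded import follows the usual find_spec() pattern, roughly as sketched below (illustrative; the actual module path and error message in the PR may differ):

from importlib.util import find_spec

__all__ = []

if find_spec("torch"):
    from .converters import batch_dict_converter  # actual import path may differ

    __all__ += ["batch_dict_converter"]
else:
    def batch_dict_converter(*args, **kwargs):
        msg = "batch_dict_converter requires torch to be installed"
        raise ImportError(msg)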
The CI tests were failing due to DeprecationWarning: 'oneOf' deprecated - use 'one_of'
warnings from pyparsing used by matplotlib. This warning is triggered when scanpy
imports matplotlib during test execution.

Added warning filters to pytest configuration to ignore this specific deprecation
warning from the dependency chain, allowing tests to pass while preserving
warnings from anndata's own code.

Fixes the following failing tests:
- test_scanpy_pbmc68k[zarr3/zarr2]
- test_scanpy_krumsiek11[zarr3/zarr2]
- test_read_partial_adata[zarr2/zarr3]
After fixing the 'oneOf' deprecation warnings, CI revealed another pyparsing
deprecation warning: 'parseString' deprecated - use 'parse_string'. This is
also coming from matplotlib's font configuration parsing.

Added additional warning filter to suppress this specific deprecation warning
from the dependency chain, ensuring all pyparsing-related deprecation warnings
from matplotlib are properly handled in CI.
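The commits above add these filters globally in the pytest configuration; the same two filters, expressed per test module with pytest's filterwarnings marker API, look roughly like this sketch (the global configuration entries in the PR may be worded differently):

import pytest

pytestmark = [
    # pyparsing deprecations reach the test suite via matplotlib/scanpy, not anndata itself
    pytest.mark.filterwarnings("ignore:'oneOf' deprecated - use 'one_of':DeprecationWarning"),
    pytest.mark.filterwarnings("ignore:'parseString' deprecated - use 'parse_string':DeprecationWarning"),
]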