Skip to content

Commit bc413bf

Browse files
committed
Add tests for feature extraction
1 parent 55890b1 commit bc413bf

File tree

9 files changed

+142
-14
lines changed

9 files changed

+142
-14
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ all = ["stamp[dinobloom,conch,ctranspath,uni,virchow2]"]
7676

7777
[dependency-groups]
7878
dev = [
79+
"huggingface-hub>=0.27.1",
7980
"ipykernel>=6.29.5",
8081
"pyright>=1.1.389,!=1.1.391",
8182
"pytest>=8.3.4",
@@ -88,4 +89,4 @@ build-backend = "hatchling.build"
8889

8990
[tool.hatch.metadata]
9091
# To allow referencing git repos in dependencies
91-
allow-direct-references = true
92+
allow-direct-references = true

src/stamp/__main__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def _run_cli(args: argparse.Namespace) -> None:
7272
tile_size_px=config.preprocessing.tile_size_px,
7373
extractor=config.preprocessing.extractor,
7474
max_workers=config.preprocessing.max_workers,
75-
accelerator=config.preprocessing.accelerator,
75+
device=config.preprocessing.device,
7676
brightness_cutoff=config.preprocessing.brightness_cutoff,
7777
canny_cutoff=config.preprocessing.canny_cutoff,
7878
)

src/stamp/cache.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616

1717
def download_file(*, url: str, file_name: str, sha256sum: str) -> Path:
18+
"""Downloads a file, or loads it from cache if it has been downloaded before"""
1819
outfile_path = STAMP_CACHE_DIR / file_name
1920
if outfile_path.is_file():
2021
with open(outfile_path, "rb") as weight_file:

src/stamp/config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ preprocessing:
77
extractor: "ctranspath"
88

99
# Device to run feature extraction on ("cpu", "cuda", "cuda:0", etc.)
10-
accelerator: "cuda"
10+
device: "cuda"
1111

1212
# Optional settings:
1313

src/stamp/preprocessing/__init__.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import hashlib
22
import logging
3-
from collections.abc import Callable
3+
from collections.abc import Callable, Iterator
44
from functools import cache
55
from pathlib import Path
66
from random import shuffle
77
from tempfile import NamedTemporaryFile
8-
from typing import Iterator, assert_never
8+
from typing import assert_never
99

1010
import h5py
1111
import numpy as np
@@ -18,6 +18,7 @@
1818
from torch.utils.data import DataLoader, IterableDataset
1919
from tqdm import tqdm
2020

21+
import stamp
2122
from stamp.preprocessing.config import ExtractorName
2223
from stamp.preprocessing.extractor import Extractor
2324
from stamp.preprocessing.tiling import (
@@ -120,7 +121,7 @@ def extract_(
120121
tile_size_px: TilePixels,
121122
tile_size_um: Microns,
122123
max_workers: int,
123-
accelerator: DeviceLikeType,
124+
device: DeviceLikeType,
124125
brightness_cutoff: int | None,
125126
canny_cutoff: float | None,
126127
) -> None:
@@ -161,7 +162,7 @@ def extract_(
161162
case _ as unreachable:
162163
assert_never(unreachable)
163164

164-
model = extractor.model.to(accelerator).eval()
165+
model = extractor.model.to(device).eval()
165166
extractor_id = f"{extractor.identifier}-{_get_preprocessing_code_hash()[:8]}"
166167

167168
logger.info(f"Using extractor {extractor.identifier}")
@@ -213,7 +214,7 @@ def extract_(
213214
feats, xs_um, ys_um = [], [], []
214215
for tiles, xs, ys in tqdm(dl, leave=False):
215216
with torch.inference_mode():
216-
feats.append(model(tiles.to(accelerator)).detach().half().cpu())
217+
feats.append(model(tiles.to(device)).detach().half().cpu())
217218
xs_um.append(xs.float())
218219
ys_um.append(ys.float())
219220
except Exception:
@@ -235,6 +236,7 @@ def extract_(
235236
h5_fp["coords"] = coords
236237
h5_fp["feats"] = torch.concat(feats).numpy()
237238

239+
h5_fp.attrs["stamp_version"] = stamp.__version__
238240
h5_fp.attrs["extractor"] = extractor_id
239241
h5_fp.attrs["unit"] = "um"
240242
h5_fp.attrs["tile_size"] = tile_size_um

src/stamp/preprocessing/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class PreprocessingConfig(BaseModel, arbitrary_types_allowed=True):
3131
tile_size_px: TilePixels = TilePixels(224)
3232
extractor: ExtractorName
3333
max_workers: int = 8
34-
accelerator: DeviceLikeType = "cuda" if torch.cuda.is_available() else "cpu"
34+
device: DeviceLikeType = "cuda" if torch.cuda.is_available() else "cpu"
3535

3636
# Background rejection
3737
brightness_cutoff: int | None = Field(240, gt=0, lt=255)

src/stamp/preprocessing/extractor/empty.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,17 @@ class _EmptyModel(torch.nn.Module):
2121
def forward(
2222
self, batch: Float[torch.Tensor, "batch channel height width"]
2323
) -> Float[torch.Tensor, "batch feature"]:
24-
return torch.zeros(batch.size(0)).type_as(batch)
24+
return torch.zeros(batch.size(0), 0).type_as(batch)
2525

2626

2727
def empty() -> Extractor:
2828
return Extractor(
2929
model=_EmptyModel(),
30-
transform=torchvision.transforms.functional.pil_to_tensor,
30+
transform=torchvision.transforms.Compose(
31+
[
32+
torchvision.transforms.PILToTensor(),
33+
torchvision.transforms.Lambda(lambda x: x.float()),
34+
]
35+
),
3136
identifier="empty",
3237
)

tests/test_feature_extractors.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import os
2+
import tempfile
3+
from pathlib import Path
4+
5+
import h5py
6+
import numpy as np
7+
import pytest
8+
import torch
9+
from huggingface_hub.errors import GatedRepoError
10+
11+
from stamp.cache import download_file
12+
from stamp.preprocessing import ExtractorName, Microns, TilePixels, extract_
13+
14+
15+
def test_if_feature_extraction_crashes(extractor=ExtractorName.CTRANSPATH) -> None:
16+
example_slide_path = download_file(
17+
url="https://github.com/KatherLab/STAMP/releases/download/2.0.0.dev14/TCGA-G4-6625-01Z-00-DX1.0fa26667-2581-4f96-a891-d78dbc3299b4.svs",
18+
file_name="TCGA-G4-6625-01Z-00-DX1.0fa26667-2581-4f96-a891-d78dbc3299b4.svs",
19+
sha256sum="9b7d2b0294524351bf29229c656cc886af028cb9e7463882289fac43c1347525",
20+
)
21+
with tempfile.TemporaryDirectory(prefix="stamp_test_preprocessing_") as tmp_dir:
22+
dir = Path(tmp_dir)
23+
wsi_dir = dir / "wsis"
24+
wsi_dir.mkdir()
25+
(wsi_dir / "slide.svs").symlink_to(example_slide_path)
26+
27+
try:
28+
extract_(
29+
wsi_dir=wsi_dir,
30+
output_dir=dir / "output",
31+
extractor=extractor,
32+
cache_dir=None,
33+
tile_size_px=TilePixels(224),
34+
tile_size_um=Microns(256.0),
35+
max_workers=min(os.cpu_count() or 1, 16),
36+
brightness_cutoff=224,
37+
canny_cutoff=0.02,
38+
device="cuda" if torch.cuda.is_available() else "cpu",
39+
)
40+
except ModuleNotFoundError:
41+
pytest.skip(f"dependencies for {extractor} not installed")
42+
except GatedRepoError:
43+
pytest.skip(f"cannot access gated repo for {extractor}")
44+
45+
# Check if the file has any contents
46+
with h5py.File(next((dir / "output").glob("*/*.h5"))) as h5_file:
47+
just_extracted_feats = np.array(h5_file["feats"][:]) # pyright: ignore[reportIndexIssue]
48+
49+
assert len(just_extracted_feats) > 0
50+
51+
52+
def test_if_conch_feature_extraction_crashes() -> None:
53+
test_if_feature_extraction_crashes(ExtractorName.CONCH)
54+
55+
56+
def test_if_uni_feature_extraction_crashes() -> None:
57+
test_if_feature_extraction_crashes(ExtractorName.UNI)
58+
59+
60+
def test_if_dino_bloom_feature_extraction_crashes() -> None:
61+
test_if_feature_extraction_crashes(ExtractorName.DINO_BLOOM)
62+
63+
64+
def test_if_virchow2_feature_extraction_crashes() -> None:
65+
test_if_feature_extraction_crashes(ExtractorName.VIRCHOW2)
66+
67+
68+
def test_if_empty_feature_extraction_crashes() -> None:
69+
test_if_feature_extraction_crashes(ExtractorName.EMPTY)
70+
71+
72+
def check_backward_compatability(extractor=ExtractorName.CTRANSPATH) -> None:
73+
example_slide_path = download_file(
74+
url="https://github.com/KatherLab/STAMP/releases/download/2.0.0.dev14/TCGA-G4-6625-01Z-00-DX1.0fa26667-2581-4f96-a891-d78dbc3299b4.svs",
75+
file_name="TCGA-G4-6625-01Z-00-DX1.0fa26667-2581-4f96-a891-d78dbc3299b4.svs",
76+
sha256sum="9b7d2b0294524351bf29229c656cc886af028cb9e7463882289fac43c1347525",
77+
)
78+
with tempfile.TemporaryDirectory(prefix="stamp_test_preprocessing_") as tmp_dir:
79+
dir = Path(tmp_dir)
80+
wsi_dir = dir / "wsis"
81+
wsi_dir.mkdir()
82+
(wsi_dir / "slide.svs").symlink_to(example_slide_path)
83+
84+
try:
85+
extract_(
86+
wsi_dir=wsi_dir,
87+
output_dir=dir / "output",
88+
extractor=extractor,
89+
cache_dir=None,
90+
tile_size_px=TilePixels(224),
91+
tile_size_um=Microns(256.0),
92+
max_workers=min(os.cpu_count() or 1, 16),
93+
brightness_cutoff=224,
94+
canny_cutoff=0.02,
95+
device="cuda" if torch.cuda.is_available() else "cpu",
96+
)
97+
except ModuleNotFoundError:
98+
pytest.skip(f"dependencies for {extractor} not installed")
99+
100+
reference_feature_path = download_file(
101+
url="https://github.com/KatherLab/STAMP/releases/download/2.0.0.dev14/ctranspath-TCGA-G4-6625-01Z-00-DX1.0fa26667-2581-4f96-a891-d78dbc3299b4.h5",
102+
file_name="ctranspath-TCGA-G4-6625-01Z-00-DX1.0fa26667-2581-4f96-a891-d78dbc3299b4.h5",
103+
sha256sum="f3f33b069c3ed860d2bdb7d65ca5db64936d7acee3ba1061a457a8cdb1bc67e3",
104+
)
105+
106+
with h5py.File(reference_feature_path) as h5_file:
107+
reference_feats = h5_file["feats"][:] # pyright: ignore[reportIndexIssue]
108+
reference_version = h5_file.attrs["stamp_version"]
109+
110+
with h5py.File(next((dir / "output").glob("*/*.h5"))) as h5_file:
111+
just_extracted_feats = h5_file["feats"][:] # pyright: ignore[reportIndexIssue]
112+
113+
assert torch.allclose(
114+
torch.tensor(just_extracted_feats), torch.tensor(reference_feats)
115+
), (
116+
f"extracted {extractor} features differ from those made with stamp version {reference_version}"
117+
)

uv.lock

Lines changed: 5 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)