Skip to content

Commit 5b562b8

Browse files
authored
Add Barspoon for multi-target prediction (v2.4.1 patch-1) (#158)
* add multi-target support; tests, fixes, and docs
* add multi-target support
* add multi-target statistics
* refactor
* refactor survival training/validation
* refactor survival training/validation
* Remove unused import from survival.py (removed unused import for add_at_risk_counts)
* update data.py to latest v2.4.1
* update tests to latest v2.4.1
* remove misc test
1 parent 61ce63c commit 5b562b8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+2883
-760
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ STAMP is an **end‑to‑end, weakly‑supervised deep‑learning pipeline** tha
1919
* 🎓 **Beginner‑friendly & expert‑ready**: Zero‑code CLI and YAML config for routine use; optional code‑level customization for advanced research.
2020
* 🧩 **Model‑rich**: Out‑of‑the‑box support for **20+ foundation models** at [tile level](getting-started.md#feature-extraction) (e.g., *Virchow‑v2*, *UNI‑v2*) and [slide level](getting-started.md#slide-level-encoding) (e.g., *TITAN*, *COBRA*).
2121
* 🔬 **Weakly‑supervised**: End‑to‑end MIL with Transformer aggregation for training, cross‑validation and external deployment; no pixel‑level labels required.
22-
* 🧮 **Multi-task learning**: Unified framework for **classification**, **regression**, and **cox-based survival analysis**.
22+
* 🧮 **Multi-task learning**: Unified framework for **classification**, **multi-target classification**, **regression**, and **Cox-based survival analysis**.
2323
* 📊 **Stats & results**: Built‑in metrics and patient‑level predictions, ready for analysis and reporting.
2424
* 🖼️ **Explainable**: Generates heatmaps and top‑tile exports out‑of‑the‑box for transparent model auditing and publication‑ready figures.
2525
* 🤝 **Collaborative by design**: Clinicians drive hypothesis & interpretation while engineers handle compute; STAMP’s modular CLI mirrors real‑world workflows and tracks every step for full reproducibility.

src/stamp/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66

77
import yaml
88

9-
from stamp.config import StampConfig
109
from stamp.modeling.config import (
1110
AdvancedConfig,
1211
MlpModelParams,
1312
ModelParams,
1413
VitModelParams,
1514
)
16-
from stamp.seed import Seed
15+
from stamp.utils.config import StampConfig
16+
from stamp.utils.seed import Seed
1717

1818
STAMP_FACTORY_SETTINGS = Path(__file__).with_name("config.yaml")
1919

src/stamp/config.yaml

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ preprocessing:
44
# Extractor to use for feature extraction. Possible options are "ctranspath",
55
# "uni", "conch", "chief-ctranspath", "conch1_5", "uni2", "dino-bloom",
66
# "gigapath", "h-optimus-0", "h-optimus-1", "virchow2", "virchow",
7-
# "virchow-full", "musk", "mstar", "plip"
7+
# "virchow-full", "musk", "mstar", "plip", "ticon"
88
# Some of them require requesting access from the respective authors beforehand.
99
extractor: "chief-ctranspath"
1010

@@ -76,6 +76,8 @@ crossval:
7676

7777
# Name of the column from the clini table to train on.
7878
ground_truth_label: "KRAS"
79+
# For multi-target classification you may specify a list of columns,
80+
# e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"]
7981

8082
# For survival (should be status and follow-up days columns in clini table)
8183
# status_label: "event"
@@ -133,6 +135,8 @@ training:
133135

134136
# Name of the column from the clini table to train on.
135137
ground_truth_label: "KRAS"
138+
# For multi-target classification you may specify a list of columns,
139+
# e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"]
136140

137141
# For survival (should be status and follow-up days columns in clini table)
138142
# status_label: "event"
@@ -175,6 +179,8 @@ deployment:
175179

176180
# Name of the column from the clini table to compare predictions to.
177181
ground_truth_label: "KRAS"
182+
# For multi-target classification you may specify a list of columns,
183+
# e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"]
178184

179185
# For survival (should be status and follow-up days columns in clini table)
180186
# status_label: "event"
@@ -200,6 +206,8 @@ statistics:
200206

201207
# Name of the target label.
202208
ground_truth_label: "KRAS"
209+
# For multi-target classification you may specify a list of columns,
210+
# e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"]
203211

204212
# A lot of the statistics are computed "one-vs-all", i.e. there needs to be
205213
# a positive class to calculate the statistics for.
@@ -319,7 +327,7 @@ advanced_config:
319327
max_lr: 1e-4
320328
div_factor: 25.
321329
# Select a model regardless of task
322-
model_name: "vit" # or mlp, trans_mil
330+
model_name: "vit" # or mlp, trans_mil, barspoon
323331

324332
model_params:
325333
vit: # Vision Transformer
@@ -338,3 +346,15 @@ advanced_config:
338346
dim_hidden: 512
339347
num_layers: 2
340348
dropout: 0.25
349+
350+
# NOTE: Only the `barspoon` model supports multi-target classification
351+
# (i.e. `ground_truth_label` can be a list of column names). Other
352+
# models expect a single target column.
353+
barspoon: # Encoder-Decoder Transformer for multi-target classification
354+
d_model: 512
355+
num_encoder_heads: 8
356+
num_decoder_heads: 8
357+
num_encoder_layers: 2
358+
num_decoder_layers: 2
359+
dim_feedforward: 2048
360+
positional_encoding: true

src/stamp/encoding/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def init_slide_encoder_(
7373
selected_encoder = encoder
7474

7575
case _ as unreachable:
76-
assert_never(unreachable) # type: ignore
76+
assert_never(unreachable)
7777

7878
selected_encoder.encode_slides_(
7979
output_dir=output_dir,

src/stamp/encoding/encoder/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from abc import ABC, abstractmethod
55
from pathlib import Path
66
from tempfile import NamedTemporaryFile
7+
from typing import cast
78

89
import h5py
910
import numpy as np
@@ -12,11 +13,11 @@
1213
from tqdm import tqdm
1314

1415
import stamp
15-
from stamp.cache import get_processing_code_hash
1616
from stamp.encoding.config import EncoderName
1717
from stamp.modeling.data import CoordsInfo, get_coords, read_table
1818
from stamp.preprocessing.config import ExtractorName
1919
from stamp.types import DeviceLikeType, PandasLabel
20+
from stamp.utils.cache import get_processing_code_hash
2021

2122
__author__ = "Juan Pablo Ricapito"
2223
__copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito"
@@ -183,7 +184,8 @@ def _read_h5(
183184
elif not h5_path.endswith(".h5"):
184185
raise ValueError(f"File is not of type .h5: {os.path.basename(h5_path)}")
185186
with h5py.File(h5_path, "r") as f:
186-
feats: Tensor = torch.tensor(f["feats"][:], dtype=self.precision) # type: ignore
187+
feats_ds = cast(h5py.Dataset, f["feats"])
188+
feats: Tensor = torch.tensor(feats_ds[:], dtype=self.precision)
187189
coords: CoordsInfo = get_coords(f)
188190
extractor: str = f.attrs.get("extractor", "")
189191
if extractor == "":

src/stamp/encoding/encoder/chief.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
from numpy import ndarray
1111
from tqdm import tqdm
1212

13-
from stamp.cache import STAMP_CACHE_DIR, file_digest, get_processing_code_hash
1413
from stamp.encoding.config import EncoderName
1514
from stamp.encoding.encoder import Encoder
1615
from stamp.preprocessing.config import ExtractorName
1716
from stamp.types import DeviceLikeType, PandasLabel
17+
from stamp.utils.cache import STAMP_CACHE_DIR, file_digest, get_processing_code_hash
1818

1919
__author__ = "Juan Pablo Ricapito"
2020
__copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito"

src/stamp/encoding/encoder/eagle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99
from torch import Tensor
1010
from tqdm import tqdm
1111

12-
from stamp.cache import get_processing_code_hash
1312
from stamp.encoding.config import EncoderName
1413
from stamp.encoding.encoder import Encoder
1514
from stamp.encoding.encoder.chief import CHIEF
1615
from stamp.modeling.data import CoordsInfo
1716
from stamp.preprocessing.config import ExtractorName
1817
from stamp.types import DeviceLikeType, PandasLabel
18+
from stamp.utils.cache import get_processing_code_hash
1919

2020
__author__ = "Juan Pablo Ricapito"
2121
__copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito"

src/stamp/encoding/encoder/gigapath.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
from gigapath import slide_encoder
1010
from tqdm import tqdm
1111

12-
from stamp.cache import get_processing_code_hash
1312
from stamp.encoding.config import EncoderName
1413
from stamp.encoding.encoder import Encoder
1514
from stamp.modeling.data import CoordsInfo
1615
from stamp.preprocessing.config import ExtractorName
1716
from stamp.types import PandasLabel, SlideMPP
17+
from stamp.utils.cache import get_processing_code_hash
1818

1919
__author__ = "Juan Pablo Ricapito"
2020
__copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito"

src/stamp/encoding/encoder/madeleine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import torch
44
from numpy import ndarray
55

6-
from stamp.cache import STAMP_CACHE_DIR
76
from stamp.encoding.config import EncoderName
87
from stamp.encoding.encoder import Encoder
98
from stamp.preprocessing.config import ExtractorName
9+
from stamp.utils.cache import STAMP_CACHE_DIR
1010

1111
try:
1212
from madeleine.models.factory import create_model_from_pretrained

src/stamp/encoding/encoder/titan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010
from tqdm import tqdm
1111
from transformers import AutoModel
1212

13-
from stamp.cache import get_processing_code_hash
1413
from stamp.encoding.config import EncoderName
1514
from stamp.encoding.encoder import Encoder
1615
from stamp.modeling.data import CoordsInfo
1716
from stamp.preprocessing.config import ExtractorName
1817
from stamp.types import DeviceLikeType, Microns, PandasLabel, SlideMPP
18+
from stamp.utils.cache import get_processing_code_hash
1919

2020
__author__ = "Juan Pablo Ricapito"
2121
__copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito"

0 commit comments

Comments
 (0)