Skip to content

Commit 95e6a10

Browse files
authored
Merge pull request #132 from KatherLab/dev/multitask
Dev/multitask - STAMP extended for regression and survival analysis (v2.4 update)
2 parents 9323558 + 7b3a27a commit 95e6a10

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+4755
-1518
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ STAMP is an **end‑to‑end, weakly‑supervised deep‑learning pipeline** tha
1919
* 🎓 **Beginner‑friendly & expert‑ready**: Zero‑code CLI and YAML config for routine use; optional code‑level customization for advanced research.
2020
* 🧩 **Model‑rich**: Out‑of‑the‑box support for **+20 foundation models** at [tile level](getting-started.md#feature-extraction) (e.g., *Virchow‑v2*, *UNI‑v2*) and [slide level](getting-started.md#slide-level-encoding) (e.g., *TITAN*, *COBRA*).
2121
* 🔬 **Weakly‑supervised**: End‑to‑end MIL with Transformer aggregation for training, cross‑validation and external deployment; no pixel‑level labels required.
22-
* 📊 **Stats & results**: Built‑in metrics (AUROC/AUPRC \+ 95% CI) and patient‑level predictions, ready for analysis and reporting.
22+
* 🧮 **Multi-task learning**: Unified framework for **classification**, **regression**, and **cox-based survival analysis**.
23+
* 📊 **Stats & results**: Built‑in metrics and patient‑level predictions, ready for analysis and reporting.
2324
* 🖼️ **Explainable**: Generates heatmaps and top‑tile exports out‑of‑the‑box for transparent model auditing and publication‑ready figures.
2425
* 🤝 **Collaborative by design**: Clinicians drive hypothesis & interpretation while engineers handle compute; STAMP’s modular CLI mirrors real‑world workflows and tracks every step for full reproducibility.
2526
* 📑 **Peer‑reviewed**: Protocol published in [*Nature Protocols*](https://www.nature.com/articles/s41596-024-01047-2) and validated across multiple tumor types and centers.

getting-started.md

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,3 +471,45 @@ heatmaps:
471471
```
472472

473473

474+
## Advanced configuration
475+
476+
Advanced experiment settings can be specified under the `advanced_config` section in your configuration file.
477+
This section lets you control global training parameters, model type, and the target task (classification, regression, or survival).
478+
479+
```yaml
480+
# stamp-test-experiment/config.yaml
481+
482+
advanced_config:
483+
seed: 42
484+
task: "classification" # or regression/survial
485+
max_epochs: 32
486+
patience: 16
487+
batch_size: 64
488+
# Only for tile-level training. Reducing its amount could affect
489+
# model performance. Reduces memory consumption. Default value works
490+
# fine for most cases.
491+
bag_size: 512
492+
#num_workers: 16 # Default chosen by cpu cores
493+
# One Cycle Learning Rate Scheduler parameters. Check docs for more info.
494+
# Determines the initial learning rate via initial_lr = max_lr/div_factor
495+
max_lr: 1e-4
496+
div_factor: 25.
497+
# Select a model regardless of task
498+
# Available models are: vit, trans_mil, mlp
499+
model_name: "vit"
500+
501+
model_params:
502+
vit: # Vision Transformer
503+
dim_model: 512
504+
dim_feedforward: 512
505+
n_heads: 8
506+
n_layers: 2
507+
dropout: 0.25
508+
use_alibi: false
509+
```
510+
511+
STAMP automatically adapts its **model architecture**, **loss function**, and **evaluation metrics** based on the task specified in the configuration file.
512+
513+
**Regression** tasks only require `ground_truth_label`.
514+
**Survival analysis** tasks require `time_label` (follow-up time) and `status_label` (event indicator).
515+
These requirements apply consistently across cross-validation, training, deployment, and statistics.

mcp/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import asyncio
22
import logging
33
import os
4-
from pathlib import Path
54
import platform
65
import subprocess
76
import tempfile
7+
from pathlib import Path
88
from typing import Annotated
99

1010
import torch

pyproject.toml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "stamp"
3-
version = "2.3.0"
3+
version = "2.4.0"
44
authors = [
55
{ name = "Omar El Nahhas", email = "omar.el_nahhas@tu-dresden.de" },
66
{ name = "Marko van Treeck", email = "markovantreeck@gmail.com" },
@@ -9,7 +9,8 @@ authors = [
99
{ name = "Laura Žigutytė", email = "laura.zigutyte@tu-dresden.de" },
1010
{ name = "Cornelius Kummer", email = "cornelius.kummer@tu-dresden.de" },
1111
{ name = "Juan Pablo Ricapito", email = "juan_pablo.ricapito@tu-dresden.de" },
12-
{ name = "Fabian Wolf", email = "fabian.wolf2@tu-dresden.de" }
12+
{ name = "Fabian Wolf", email = "fabian.wolf2@tu-dresden.de" },
13+
{ name = "Minh Duc Nguyen", email = "minh_duc.nguyen1@tu-dresden.de" }
1314
]
1415
description = "A protocol for Solid Tumor Associative Modeling in Pathology"
1516
readme = "README.md"
@@ -45,7 +46,8 @@ dependencies = [
4546
"torchvision>=0.22.1",
4647
"tqdm>=4.67.1",
4748
"timm>=1.0.19",
48-
"transformers>=4.55.0"
49+
"transformers>=4.55.0",
50+
"lifelines>=0.28.0",
4951
]
5052

5153
[project.optional-dependencies]
@@ -84,7 +86,6 @@ gigapath = [
8486
"monai",
8587
"scikit-image",
8688
"webdataset",
87-
"lifelines",
8889
"scikit-survival>=0.24.1",
8990
"fairscale",
9091
"wandb",

src/stamp/__main__.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def _run_cli(args: argparse.Namespace) -> None:
5353
# use default advanced config in case none is provided
5454
if config.advanced_config is None:
5555
config.advanced_config = AdvancedConfig(
56-
model_params=ModelParams(vit=VitModelParams(), mlp=MlpModelParams())
56+
model_params=ModelParams(vit=VitModelParams(), mlp=MlpModelParams()),
5757
)
5858

5959
# Set global random seed
@@ -65,7 +65,7 @@ def _run_cli(args: argparse.Namespace) -> None:
6565
raise RuntimeError("this case should be handled above")
6666

6767
case "config":
68-
print(yaml.dump(config.model_dump(mode="json")))
68+
print(yaml.dump(config.model_dump(mode="json", exclude_none=True)))
6969

7070
case "preprocess":
7171
from stamp.preprocessing import extract_
@@ -76,7 +76,7 @@ def _run_cli(args: argparse.Namespace) -> None:
7676
_add_file_handle_(_logger, output_dir=config.preprocessing.output_dir)
7777
_logger.info(
7878
"using the following configuration:\n"
79-
f"{yaml.dump(config.preprocessing.model_dump(mode='json'))}"
79+
f"{yaml.dump(config.preprocessing.model_dump(mode='json', exclude_none=True))}"
8080
)
8181
extract_(
8282
output_dir=config.preprocessing.output_dir,
@@ -104,7 +104,7 @@ def _run_cli(args: argparse.Namespace) -> None:
104104
_add_file_handle_(_logger, output_dir=config.slide_encoding.output_dir)
105105
_logger.info(
106106
"using the following configuration:\n"
107-
f"{yaml.dump(config.slide_encoding.model_dump(mode='json'))}"
107+
f"{yaml.dump(config.slide_encoding.model_dump(mode='json', exclude_none=True))}"
108108
)
109109
init_slide_encoder_(
110110
encoder=config.slide_encoding.encoder,
@@ -124,7 +124,7 @@ def _run_cli(args: argparse.Namespace) -> None:
124124
_add_file_handle_(_logger, output_dir=config.patient_encoding.output_dir)
125125
_logger.info(
126126
"using the following configuration:\n"
127-
f"{yaml.dump(config.patient_encoding.model_dump(mode='json'))}"
127+
f"{yaml.dump(config.patient_encoding.model_dump(mode='json', exclude_none=True))}"
128128
)
129129
init_patient_encoder_(
130130
encoder=config.patient_encoding.encoder,
@@ -147,9 +147,12 @@ def _run_cli(args: argparse.Namespace) -> None:
147147
_add_file_handle_(_logger, output_dir=config.training.output_dir)
148148
_logger.info(
149149
"using the following configuration:\n"
150-
f"{yaml.dump(config.training.model_dump(mode='json'))}"
150+
f"{yaml.dump(config.training.model_dump(mode='json', exclude_none=True))}"
151151
)
152152

153+
if config.training.task is None:
154+
raise ValueError("task must be set in training configuration")
155+
153156
train_categorical_model_(
154157
config=config.training, advanced=config.advanced_config
155158
)
@@ -163,19 +166,21 @@ def _run_cli(args: argparse.Namespace) -> None:
163166
_add_file_handle_(_logger, output_dir=config.deployment.output_dir)
164167
_logger.info(
165168
"using the following configuration:\n"
166-
f"{yaml.dump(config.deployment.model_dump(mode='json'))}"
169+
f"{yaml.dump(config.deployment.model_dump(mode='json', exclude_none=True))}"
167170
)
168171
deploy_categorical_model_(
169172
output_dir=config.deployment.output_dir,
170173
checkpoint_paths=config.deployment.checkpoint_paths,
171174
clini_table=config.deployment.clini_table,
172175
slide_table=config.deployment.slide_table,
173176
feature_dir=config.deployment.feature_dir,
174-
ground_truth_label=config.deployment.ground_truth_label,
175177
patient_label=config.deployment.patient_label,
176178
filename_label=config.deployment.filename_label,
177179
num_workers=config.deployment.num_workers,
178180
accelerator=config.deployment.accelerator,
181+
ground_truth_label=config.deployment.ground_truth_label,
182+
time_label=config.deployment.time_label,
183+
status_label=config.deployment.status_label,
179184
)
180185

181186
case "crossval":
@@ -184,10 +189,13 @@ def _run_cli(args: argparse.Namespace) -> None:
184189
if config.crossval is None:
185190
raise ValueError("no crossval configuration supplied")
186191

192+
if config.crossval.task is None:
193+
raise ValueError("task must be set in crossval configuration")
194+
187195
_add_file_handle_(_logger, output_dir=config.crossval.output_dir)
188196
_logger.info(
189197
"using the following configuration:\n"
190-
f"{yaml.dump(config.crossval.model_dump(mode='json'))}"
198+
f"{yaml.dump(config.crossval.model_dump(mode='json', exclude_none=True))}"
191199
)
192200

193201
categorical_crossval_(
@@ -204,13 +212,17 @@ def _run_cli(args: argparse.Namespace) -> None:
204212
_add_file_handle_(_logger, output_dir=config.statistics.output_dir)
205213
_logger.info(
206214
"using the following configuration:\n"
207-
f"{yaml.dump(config.statistics.model_dump(mode='json'))}"
215+
f"{yaml.dump(config.statistics.model_dump(mode='json', exclude_none=True))}"
208216
)
217+
209218
compute_stats_(
219+
task=config.statistics.task,
210220
output_dir=config.statistics.output_dir,
211221
pred_csvs=config.statistics.pred_csvs,
212222
ground_truth_label=config.statistics.ground_truth_label,
213223
true_class=config.statistics.true_class,
224+
time_label=config.statistics.time_label,
225+
status_label=config.statistics.status_label,
214226
)
215227

216228
case "heatmaps":

src/stamp/config.yaml

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,16 @@ crossval:
6868
# are ignored. NOTE: Don't forget to add the .h5 file extension.
6969
slide_table: "/path/of/slide.csv"
7070

71+
# Task to infer (classification, regression, survival)
72+
task: "classification"
73+
7174
# Name of the column from the clini table to train on.
7275
ground_truth_label: "KRAS"
7376

77+
# For survival (should be status and follow-up days columns in clini table)
78+
# status_label: "event"
79+
# time_label: "time"
80+
7481
# Optional settings:
7582
patient_label: "PATIENT"
7683
filename_label: "FILENAME"
@@ -118,9 +125,16 @@ training:
118125
# are ignored. NOTE: Don't forget to add the .h5 file extension.
119126
slide_table: "/path/of/slide.csv"
120127

128+
# Task to infer (classification, regression, survival)
129+
task: "classification"
130+
121131
# Name of the column from the clini table to train on.
122132
ground_truth_label: "KRAS"
123133

134+
# For survival (should be status and follow-up days columns in clini table)
135+
# status_label: "event"
136+
# time_label: "time"
137+
124138
# Optional settings:
125139

126140
# The categories occurring in the target label column of the clini table.
@@ -156,9 +170,16 @@ deployment:
156170
# paths are ignored. NOTE: Don't forget to add the .h5 file extension.
157171
slide_table: "/path/of/slide.csv"
158172

173+
# Task to infer (classification, regression, survival)
174+
task: "classification"
175+
159176
# Name of the column from the clini to compare predictions to.
160177
ground_truth_label: "KRAS"
161178

179+
# For survival (should be status and follow-up days columns in clini table)
180+
# status_label: "event"
181+
# time_label: "time"
182+
162183
patient_label: "PATIENT"
163184
filename_label: "FILENAME"
164185

@@ -174,13 +195,20 @@ deployment:
174195
statistics:
175196
output_dir: "/path/to/save/files/to"
176197

198+
# Task to infer (classification, regression, survival)
199+
task: "classification"
200+
177201
# Name of the target label.
178202
ground_truth_label: "KRAS"
179203

180204
# A lot of the statistics are computed "one-vs-all", i.e. there needs to be
181205
# a positive class to calculate the statistics for.
182206
true_class: "mutated"
183207

208+
# For survival (should be status and follow-up days columns in clini table)
209+
# status_label: "event"
210+
# time_label: "time"
211+
184212
# The patient predictions to generate the statistics from.
185213
# For a single deployment, it could look like this:
186214
pred_csvs:
@@ -277,8 +305,7 @@ patient_encoding:
277305

278306

279307
advanced_config:
280-
# Optional random seed
281-
# seed: 42
308+
seed: 42
282309
max_epochs: 32
283310
patience: 16
284311
batch_size: 64
@@ -291,12 +318,10 @@ advanced_config:
291318
# Determines the initial learning rate via initial_lr = max_lr/div_factor
292319
max_lr: 1e-4
293320
div_factor: 25.
294-
# Select a model. Not working yet, added for future support.
295-
# Now it uses a ViT for tile features and a MLP for patient features.
296-
#model_name: "vit"
321+
# Select a model regardless of task
322+
model_name: "vit" # or mlp, trans_mil
297323

298324
model_params:
299-
# Tile-level training models:
300325
vit: # Vision Transformer
301326
dim_model: 512
302327
dim_feedforward: 512
@@ -306,7 +331,9 @@ advanced_config:
306331
# Experimental feature: Use ALiBi positional embedding
307332
use_alibi: false
308333

309-
# Patient-level training models:
334+
trans_mil: # https://arxiv.org/abs/2106.00908
335+
dim_hidden: 512
336+
310337
mlp: # Multilayer Perceptron
311338
dim_hidden: 512
312339
num_layers: 2

src/stamp/encoding/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def init_slide_encoder_(
5454

5555
selected_encoder: Encoder = Gigapath()
5656

57-
case EncoderName.CHIEF:
57+
case EncoderName.CHIEF_CTRANSPATH:
5858
from stamp.encoding.encoder.chief import CHIEF
5959

6060
selected_encoder: Encoder = CHIEF()
@@ -140,7 +140,7 @@ def init_patient_encoder_(
140140

141141
selected_encoder: Encoder = Gigapath()
142142

143-
case EncoderName.CHIEF:
143+
case EncoderName.CHIEF_CTRANSPATH:
144144
from stamp.encoding.encoder.chief import CHIEF
145145

146146
selected_encoder: Encoder = CHIEF()

src/stamp/encoding/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
class EncoderName(StrEnum):
1010
COBRA = "cobra"
1111
EAGLE = "eagle"
12-
CHIEF = "chief"
12+
CHIEF_CTRANSPATH = "chief"
1313
TITAN = "titan"
1414
GIGAPATH = "gigapath"
1515
MADELEINE = "madeleine"

0 commit comments

Comments
 (0)