Skip to content

Commit 707f3e3

Browse files
committed
Add feature precision variation transform
1 parent e778c81 commit 707f3e3

File tree

12 files changed

+104
-13
lines changed

12 files changed

+104
-13
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "stamp"
3-
version = "2.0.0-dev6"
3+
version = "2.0.0-dev7"
44
authors = [
55
{ name = "Omar El Nahhas", email = "omar.el_nahhas@tu-dresden.de" },
66
{ name = "Marko van Treeck", email = "markovantreeck@gmail.com" },

src/stamp/__main__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def _run_cli(args: argparse.Namespace) -> None:
106106
max_epochs=config.training.max_epochs,
107107
patience=config.training.patience,
108108
accelerator=config.training.accelerator,
109+
use_vary_precision_transform=config.training.use_vary_precision_transform,
109110
)
110111

111112
case "deploy":
@@ -156,11 +157,13 @@ def _run_cli(args: argparse.Namespace) -> None:
156157
# Dataset and -loader parameters
157158
bag_size=config.crossval.bag_size,
158159
num_workers=config.crossval.num_workers,
159-
# crossval paramenters
160+
# Crossval parameters
160161
batch_size=config.crossval.batch_size,
161162
max_epochs=config.crossval.max_epochs,
162163
patience=config.crossval.patience,
163164
accelerator=config.crossval.accelerator,
165+
# Experimental Features
166+
use_vary_precision_transform=config.crossval.use_vary_precision_transform,
164167
)
165168

166169
case "statistics":

src/stamp/config.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,14 @@ training:
104104
# If unspecified, they will be inferred from the table itself.
105105
#categories: ["mutated", "wild type"]
106106

107+
# Experimental features:
108+
109+
# Please try uncommenting the settings below
110+
# and report if they improve / reduce model performance!
111+
112+
# Change the precision of features during training
113+
#use_vary_precision_transform: true
114+
107115

108116
deployment:
109117
output_dir: "/path/to/save/files/to"

src/stamp/modeling/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ class TrainConfig(BaseModel):
3434
patience: int = 16
3535
accelerator: str = "gpu" if torch.cuda.is_available() else "cpu"
3636

37+
# Experimental features
38+
use_vary_precision_transform: bool = False
39+
3740

3841
class CrossvalConfig(TrainConfig):
3942
n_splits: int = Field(5, ge=2)

src/stamp/modeling/crossval.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from stamp.modeling.deploy import _predict, _to_prediction_df
2323
from stamp.modeling.lightning_model import LitVisionTransformer
2424
from stamp.modeling.train import setup_model_for_training, train_model_
25+
from stamp.modeling.transforms import VaryPrecisionTransform
2526

2627
__author__ = "Marko van Treeck"
2728
__copyright__ = "Copyright (C) 2024 Marko van Treeck"
@@ -57,6 +58,8 @@ def categorical_crossval_(
5758
max_epochs: int,
5859
patience: int,
5960
accelerator: str | Accelerator,
61+
# Experimental features
62+
use_vary_precision_transform: bool,
6063
) -> None:
6164
patient_to_ground_truth: Final[dict[PatientId, GroundTruth]] = (
6265
patient_to_ground_truth_from_clini_table_(
@@ -149,6 +152,11 @@ def categorical_crossval_(
149152
}
150153
)
151154
),
155+
train_transform=(
156+
VaryPrecisionTransform(min_fraction_bits=1)
157+
if use_vary_precision_transform
158+
else None
159+
),
152160
)
153161
model = train_model_(
154162
output_dir=split_dir,
@@ -203,4 +211,4 @@ def _get_splits(
203211
)
204212
]
205213
)
206-
return splits
214+
return splits

src/stamp/modeling/data.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Helper classes to manage pytorch data."""
22

33
import logging
4-
from collections.abc import Iterable, Mapping, Sequence
4+
from collections.abc import Callable, Iterable, Mapping, Sequence
55
from dataclasses import KW_ONLY, dataclass
66
from itertools import groupby
77
from pathlib import Path
@@ -63,6 +63,7 @@ def dataloader_from_patient_data(
6363
batch_size: int,
6464
shuffle: bool,
6565
num_workers: int,
66+
transform: Callable[[Tensor], Tensor] | None,
6667
) -> tuple[DataLoader[tuple[Bags, BagSizes, EncodedTargets]], Sequence[Category]]:
6768
"""Creates a dataloader from patient data, encoding the ground truths.
6869
@@ -81,6 +82,7 @@ def dataloader_from_patient_data(
8182
bags=[patient.feature_files for patient in patient_data],
8283
bag_size=bag_size,
8384
ground_truths=one_hot,
85+
transform=transform,
8486
)
8587

8688
return (
@@ -133,6 +135,8 @@ class BagDataset(Dataset[tuple[_Bag, BagSize, _EncodedTarget]]):
133135
ground_truths: Bool[Tensor, "index category_is_hot"]
134136
"""The ground truth for each bag, one-hot encoded."""
135137

138+
transform: Callable[[Tensor], Tensor] | None
139+
136140
def __post_init__(self) -> None:
137141
if len(self.bags) != len(self.ground_truths):
138142
raise ValueError(
@@ -152,8 +156,11 @@ def __getitem__(self, index: int) -> tuple[_Bag, BagSize, _EncodedTarget]:
152156
)
153157
feats = torch.concat(feats).float()
154158

159+
if self.transform is not None:
160+
feats = self.transform(feats)
161+
155162
# Sample a subset, if required
156-
if self.bag_size:
163+
if self.bag_size is not None:
157164
return (
158165
*_to_fixed_size_bag(feats, bag_size=self.bag_size),
159166
self.ground_truths[index],
@@ -166,7 +173,7 @@ def __getitem__(self, index: int) -> tuple[_Bag, BagSize, _EncodedTarget]:
166173
)
167174

168175

169-
def _to_fixed_size_bag(bag: _Bag, bag_size: BagSize = 512) -> tuple[_Bag, BagSize]:
176+
def _to_fixed_size_bag(bag: _Bag, bag_size: BagSize) -> tuple[_Bag, BagSize]:
170177
"""Samples a fixed-size bag of tiles from an arbitrary one.
171178
172179
If the original bag did not have enough tiles,

src/stamp/modeling/deploy.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def _predict(
126126
batch_size=1,
127127
shuffle=False,
128128
num_workers=num_workers,
129+
transform=None,
129130
)
130131

131132
trainer = lightning.Trainer(
@@ -136,9 +137,7 @@ def _predict(
136137
predictions = torch.concat(
137138
cast(
138139
list[torch.Tensor],
139-
trainer.predict(
140-
cast(lightning.LightningModule, torch.compile(model)), test_dl
141-
),
140+
trainer.predict(model, test_dl),
142141
)
143142
)
144143

src/stamp/modeling/train.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import shutil
2-
from collections.abc import Mapping, Sequence
2+
from collections.abc import Callable, Mapping, Sequence
33
from pathlib import Path
44
from typing import cast
55

@@ -32,6 +32,7 @@
3232
EncodedTargets,
3333
LitVisionTransformer,
3434
)
35+
from stamp.modeling.transforms import VaryPrecisionTransform
3536

3637
__author__ = "Marko van Treeck"
3738
__copyright__ = "Copyright (C) 2024 Marko van Treeck"
@@ -56,6 +57,8 @@ def train_categorical_model_(
5657
max_epochs: int,
5758
patience: int,
5859
accelerator: str | Accelerator,
60+
# Experimental features
61+
use_vary_precision_transform: bool,
5962
) -> None:
6063
"""Trains a model.
6164
@@ -119,6 +122,11 @@ def train_categorical_model_(
119122
clini_table=clini_table,
120123
slide_table=slide_table,
121124
feature_dir=feature_dir,
125+
train_transform=(
126+
VaryPrecisionTransform(min_fraction_bits=1)
127+
if use_vary_precision_transform
128+
else None
129+
),
122130
)
123131
train_model_(
124132
output_dir=output_dir,
@@ -187,6 +195,7 @@ def setup_model_for_training(
187195
bag_size: int,
188196
batch_size: int,
189197
num_workers: int,
198+
train_transform: Callable[[torch.Tensor], torch.Tensor] | None,
190199
# Metadata, has no effect on model training
191200
ground_truth_label: PandasLabel,
192201
clini_table: Path,
@@ -225,6 +234,7 @@ def setup_model_for_training(
225234
batch_size=batch_size,
226235
shuffle=True,
227236
num_workers=num_workers,
237+
transform=train_transform,
228238
)
229239
del categories # Let's not accidentally reuse the original categories
230240
valid_dl, _ = dataloader_from_patient_data(
@@ -234,6 +244,7 @@ def setup_model_for_training(
234244
batch_size=1,
235245
shuffle=False,
236246
num_workers=num_workers,
247+
transform=None,
237248
)
238249
if overlap := set(train_patients) & set(valid_patients):
239250
raise RuntimeError(

src/stamp/modeling/transforms.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import torch
2+
from jaxtyping import Float
3+
4+
5+
def vary_precision(
    data: torch.Tensor, *, min_fraction_bits: int
) -> torch.Tensor:
    """Randomly reduce the precision of each of the tensor's values.

    For every element independently, a random number of the least
    significant fraction (mantissa) bits is zeroed out, leaving at least
    ``min_fraction_bits`` fraction bits intact.  Sign and exponent bits
    are never touched, so each output value is a truncation of the
    corresponding input value.

    Args:
        data: Floating point tensor of any shape; ``float32``,
            ``float16`` and ``bfloat16`` are supported.
        min_fraction_bits: Minimum number of fraction bits every element
            keeps.  Must satisfy ``1 <= min_fraction_bits <= F`` where
            ``F`` is the dtype's fraction-bit count.

    Returns:
        A tensor of the same shape and dtype with per-element precision
        randomly reduced.

    Raises:
        ValueError: If ``min_fraction_bits`` is out of range for the dtype.
        NotImplementedError: If ``data`` has an unsupported dtype.
    """
    if min_fraction_bits < 1:
        raise ValueError("min_fraction_bits has to be at least 1")

    # Fraction (mantissa) width and a same-width integer dtype
    # used to reinterpret the float bits for masking.
    if data.dtype == torch.float32:
        fraction_bits = 23
        mask_dtype = torch.int32
    elif data.dtype == torch.float16:
        fraction_bits = 10
        mask_dtype = torch.int16
    elif data.dtype == torch.bfloat16:
        fraction_bits = 7
        mask_dtype = torch.int16
    else:
        raise NotImplementedError(
            f"precision variation not implemented for {data.dtype}"
        )

    if min_fraction_bits > fraction_bits:
        # Without this check the randint() below fails with an opaque
        # RuntimeError (empty sampling range).
        raise ValueError(
            f"min_fraction_bits must be at most {fraction_bits} for {data.dtype}"
        )

    # Per-element count of low fraction bits to clear.  The upper bound is
    # inclusive of `fraction_bits - min_fraction_bits` (randint's `high` is
    # exclusive) so precision can actually be reduced all the way down to
    # `min_fraction_bits` bits; the previous exclusive bound made that
    # minimum unreachable and crashed when min_fraction_bits == fraction_bits.
    no_of_bits_to_mask = torch.randint(
        0, fraction_bits - min_fraction_bits + 1, data.shape
    )
    # ~0 is all ones; shifting it left zeroes the low `no_of_bits_to_mask`
    # bits, producing a keep-mask like ...111000.
    mask = (~0 << no_of_bits_to_mask).to(dtype=mask_dtype, device=data.device)
    # Reinterpret the floats as integers, clear the low fraction bits,
    # and reinterpret back.
    return (data.view(mask_dtype) & mask).view(data.dtype)
31+
32+
class VaryPrecisionTransform:
    """Callable transform that randomly degrades the precision of its input.

    Thin wrapper around ``vary_precision`` so the precision setting can be
    configured once and the resulting object handed to a dataset as a
    per-batch transform.
    """

    def __init__(self, *, min_fraction_bits: int = 1) -> None:
        # Minimum number of fraction (mantissa) bits every value keeps.
        self.min_fraction_bits = min_fraction_bits

    def __call__(self, batch: torch.Tensor) -> torch.Tensor:
        """Return `batch` with each value's precision randomly reduced."""
        return vary_precision(data=batch, min_fraction_bits=self.min_fraction_bits)

tests/test_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ def test_dataset(
7676
dim_feats: int = 34,
7777
batch_size: int = 2,
7878
) -> None:
79-
8079
ds = BagDataset(
8180
bags=[
8281
[_make_feature_file(torch.rand((12, dim_feats)))],
@@ -85,6 +84,7 @@ def test_dataset(
8584
],
8685
bag_size=bag_size,
8786
ground_truths=torch.rand(3, 4) > 0.5,
87+
transform=None,
8888
)
8989

9090
assert len(ds) == 3

0 commit comments

Comments
 (0)