11"""Helper classes to manage pytorch data."""
22
33import logging
4- from collections .abc import Iterable , Mapping , Sequence
4+ from collections .abc import Callable , Iterable , Mapping , Sequence
55from dataclasses import KW_ONLY , dataclass
66from itertools import groupby
77from pathlib import Path
@@ -63,6 +63,7 @@ def dataloader_from_patient_data(
6363 batch_size : int ,
6464 shuffle : bool ,
6565 num_workers : int ,
66+ transform : Callable [[Tensor ], Tensor ] | None ,
6667) -> tuple [DataLoader [tuple [Bags , BagSizes , EncodedTargets ]], Sequence [Category ]]:
6768 """Creates a dataloader from patient data, encoding the ground truths.
6869
@@ -81,6 +82,7 @@ def dataloader_from_patient_data(
8182 bags = [patient .feature_files for patient in patient_data ],
8283 bag_size = bag_size ,
8384 ground_truths = one_hot ,
85+ transform = transform ,
8486 )
8587
8688 return (
@@ -133,6 +135,8 @@ class BagDataset(Dataset[tuple[_Bag, BagSize, _EncodedTarget]]):
133135 ground_truths : Bool [Tensor , "index category_is_hot" ]
134136 """The ground truth for each bag, one-hot encoded."""
135137
138+ transform : Callable [[Tensor ], Tensor ] | None
139+
136140 def __post_init__ (self ) -> None :
137141 if len (self .bags ) != len (self .ground_truths ):
138142 raise ValueError (
@@ -152,8 +156,11 @@ def __getitem__(self, index: int) -> tuple[_Bag, BagSize, _EncodedTarget]:
152156 )
153157 feats = torch .concat (feats ).float ()
154158
159+ if self .transform is not None :
160+ feats = self .transform (feats )
161+
155162 # Sample a subset, if required
156- if self .bag_size :
163+ if self .bag_size is not None :
157164 return (
158165 * _to_fixed_size_bag (feats , bag_size = self .bag_size ),
159166 self .ground_truths [index ],
@@ -166,7 +173,7 @@ def __getitem__(self, index: int) -> tuple[_Bag, BagSize, _EncodedTarget]:
166173 )
167174
168175
169- def _to_fixed_size_bag (bag : _Bag , bag_size : BagSize = 512 ) -> tuple [_Bag , BagSize ]:
176+ def _to_fixed_size_bag (bag : _Bag , bag_size : BagSize ) -> tuple [_Bag , BagSize ]:
170177 """Samples a fixed-size bag of tiles from an arbitrary one.
171178
172179 If the original bag did not have enough tiles,
0 commit comments