Skip to content

Commit ff96f8c

Browse files
authored
Merge pull request #128 from KatherLab/dev/logit-heatmap
Adjust heatmap function
2 parents d48e8b9 + 7aff019 commit ff96f8c

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

src/stamp/heatmaps/__init__.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from matplotlib.axes import Axes
1313
from matplotlib.figure import Figure
1414
from matplotlib.patches import Patch
15+
from packaging.version import Version
1516
from PIL import Image
1617
from torch import Tensor
1718
from torch.func import jacrev # pyright: ignore[reportPrivateImportUsage]
@@ -22,8 +23,6 @@
2223
from stamp.preprocessing import supported_extensions
2324
from stamp.preprocessing.tiling import get_slide_mpp_
2425
from stamp.types import DeviceLikeType, Microns, SlideMPP, TilePixels
25-
from packaging.version import Version
26-
2726

2827
_logger = logging.getLogger("stamp")
2928

@@ -33,25 +32,26 @@ def _gradcam_per_category(
3332
feats: Float[Tensor, "tile feat"],
3433
coords: Float[Tensor, "tile 2"],
3534
) -> Float[Tensor, "tile category"]:
36-
feat = -1 # feats dimension
35+
feat_dim = -1
3736

38-
return (
37+
cam = (
3938
(
4039
feats
4140
* jacrev(
42-
lambda bags: torch.softmax(
43-
model.forward(
44-
bags=bags.unsqueeze(0),
45-
coords=coords.unsqueeze(0),
46-
mask=None,
47-
),
48-
dim=1,
41+
lambda bags: model.forward(
42+
bags=bags.unsqueeze(0),
43+
coords=coords.unsqueeze(0),
44+
mask=None,
4945
).squeeze(0)
5046
)(feats)
5147
)
52-
.mean(feat) # type: ignore
48+
.mean(feat_dim) # type: ignore
5349
.abs()
54-
).permute(-1, -2)
50+
)
51+
52+
cam = torch.softmax(cam, dim=-1)
53+
54+
return cam.permute(-1, -2)
5555

5656

5757
def _vals_to_im(

0 commit comments

Comments
 (0)