1212from matplotlib .axes import Axes
1313from matplotlib .figure import Figure
1414from matplotlib .patches import Patch
15+ from packaging .version import Version
1516from PIL import Image
1617from torch import Tensor
1718from torch .func import jacrev # pyright: ignore[reportPrivateImportUsage]
2223from stamp .preprocessing import supported_extensions
2324from stamp .preprocessing .tiling import get_slide_mpp_
2425from stamp .types import DeviceLikeType , Microns , SlideMPP , TilePixels
25- from packaging .version import Version
26-
2726
2827_logger = logging .getLogger ("stamp" )
2928
@@ -33,25 +32,26 @@ def _gradcam_per_category(
3332 feats : Float [Tensor , "tile feat" ],
3433 coords : Float [Tensor , "tile 2" ],
3534) -> Float [Tensor , "tile category" ]:
36- feat = - 1 # feats dimension
35+ feat_dim = - 1
3736
38- return (
37+ cam = (
3938 (
4039 feats
4140 * jacrev (
42- lambda bags : torch .softmax (
43- model .forward (
44- bags = bags .unsqueeze (0 ),
45- coords = coords .unsqueeze (0 ),
46- mask = None ,
47- ),
48- dim = 1 ,
41+ lambda bags : model .forward (
42+ bags = bags .unsqueeze (0 ),
43+ coords = coords .unsqueeze (0 ),
44+ mask = None ,
4945 ).squeeze (0 )
5046 )(feats )
5147 )
52- .mean (feat ) # type: ignore
48+ .mean (feat_dim ) # type: ignore
5349 .abs ()
54- ).permute (- 1 , - 2 )
50+ )
51+
52+ cam = torch .softmax (cam , dim = - 1 )
53+
54+ return cam .permute (- 1 , - 2 )
5555
5656
5757def _vals_to_im (
0 commit comments