Skip to content

Commit c7498f4

Browse files
committed
Implement masking
1 parent 03ae47b commit c7498f4

File tree

4 files changed

+86
-39
lines changed

4 files changed

+86
-39
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "stamp"
3-
version = "2.0.0-dev7"
3+
version = "2.0.0-dev8"
44
authors = [
55
{ name = "Omar El Nahhas", email = "omar.el_nahhas@tu-dresden.de" },
66
{ name = "Marko van Treeck", email = "markovantreeck@gmail.com" },

src/stamp/modeling/lightning_model.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import lightning
77
import numpy as np
8+
import torch
89
from jaxtyping import Float
910
from packaging.version import Version
1011
from torch import Tensor, nn, optim
@@ -71,7 +72,7 @@ def __init__(
7172
# Check if version is compatible.
7273
# This should only happen when the model is loaded,
7374
# otherwise the default value will make these checks pass.
74-
if stamp_version < Version("2.0.0.dev1"):
75+
if stamp_version < Version("2.0.0.dev8"):
7576
# Update this as we change our model in incompatible ways!
7677
raise ValueError(
7778
f"model has been built with stamp version {stamp_version} "
@@ -112,9 +113,14 @@ def _step(
112113
) -> Loss:
113114
_ = batch_idx # unused
114115

115-
bags, _, targets = batch
116+
bags, bag_sizes, targets = batch
116117

117-
logits = self.vision_transformer(bags)
118+
max_possible_bag_size = bags.size(1)
119+
mask = torch.arange(max_possible_bag_size).type_as(bag_sizes).unsqueeze(
120+
0
121+
).repeat(len(bags), 1) >= bag_sizes.unsqueeze(1)
122+
123+
logits = self.vision_transformer(bags, mask=mask)
118124

119125
loss = nn.functional.cross_entropy(
120126
logits, targets.type_as(logits), weight=self.class_weights.type_as(logits)

src/stamp/modeling/vision_transformer.py

Lines changed: 56 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@
22
In parts from https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py
33
"""
44

5-
# TODO implement masking
6-
7-
from typing import Iterable, cast
5+
from collections.abc import Iterable
6+
from typing import cast
87

98
import torch
109
from beartype import beartype
1110
from einops import repeat
12-
from jaxtyping import Float, jaxtyped
11+
from jaxtyping import Bool, Float, jaxtyped
1312
from torch import Tensor, nn
1413

1514

@@ -42,10 +41,31 @@ def __init__(
4241

4342
@jaxtyped(typechecker=beartype)
4443
def forward(
45-
self, x: Float[Tensor, "batch sequence proj_feature"]
44+
self,
45+
x: Float[Tensor, "batch sequence proj_feature"],
46+
*,
47+
attn_mask: Bool[Tensor, "batch sequence sequence"] | None,
4648
) -> Float[Tensor, "batch sequence proj_feature"]:
49+
"""
50+
Args:
51+
attn_mask:
52+
Which of the features to ignore during self-attention.
53+
`attn_mask[b,q,k] == False` means that
54+
query `q` of batch `b` can attend to key `k`.
55+
If `attn_mask` is `None`, all tokens can attend to all others.
56+
"""
4757
x = self.norm(x)
48-
attn_output, _ = self.mhsa(x, x, x, need_weights=False)
58+
attn_output, _ = self.mhsa(
59+
x,
60+
x,
61+
x,
62+
need_weights=False,
63+
attn_mask=(
64+
attn_mask.repeat(self.mhsa.num_heads, 1, 1)
65+
if attn_mask is not None
66+
else None
67+
),
68+
)
4969
return attn_output
5070

5171

@@ -83,10 +103,13 @@ def __init__(
83103

84104
@jaxtyped(typechecker=beartype)
85105
def forward(
86-
self, x: Float[Tensor, "batch sequence proj_feature"]
106+
self,
107+
x: Float[Tensor, "batch sequence proj_feature"],
108+
*,
109+
attn_mask: Bool[Tensor, "batch sequence sequence"] | None,
87110
) -> Float[Tensor, "batch sequence proj_feature"]:
88111
for attn, ff in cast(Iterable[tuple[nn.Module, nn.Module]], self.layers):
89-
x_attn = attn(x)
112+
x_attn = attn(x, attn_mask=attn_mask)
90113
x = x_attn + x
91114
x = ff(x) + x
92115

@@ -127,18 +150,36 @@ def __init__(
127150

128151
@jaxtyped(typechecker=beartype)
129152
def forward(
130-
self, bags: Float[Tensor, "batch tile feature"]
153+
self,
154+
bags: Float[Tensor, "batch tile feature"],
155+
*,
156+
mask: Bool[Tensor, "batch tile"] | None,
131157
) -> Float[Tensor, "batch logit"]:
132158
batch_size, _n_tiles, _n_features = bags.shape
133159

134-
# map input sequence to latent space of TransMIL
160+
# Map input sequence to latent space of TransMIL
135161
bags = self.project_features(bags)
136162

163+
# Prepend a class token to every bag,
164+
# include it in the mask.
165+
# TODO should the tiles be able to refer to the class token? Test!
137166
cls_tokens = repeat(self.class_token, "d -> b 1 d", b=batch_size)
138-
bags = torch.cat((cls_tokens, bags), dim=1)
139-
140-
bags = self.transformer(bags)
141-
142-
bags = bags[:, 0] # only take class token
167+
bags = torch.cat([cls_tokens, bags], dim=1)
168+
if mask is not None:
169+
mask_with_class_token = torch.cat(
170+
[torch.zeros(mask.shape[0], 1).type_as(mask), mask], dim=1
171+
)
172+
square_attn_mask = torch.einsum(
173+
"bq,bk->bqk", mask_with_class_token, mask_with_class_token
174+
)
175+
# Don't allow other tiles to reference the class token
176+
square_attn_mask[:, 1:, 0] = True
177+
178+
bags = self.transformer(bags, attn_mask=square_attn_mask)
179+
else:
180+
bags = self.transformer(bags, attn_mask=None)
181+
182+
# Only take class token
183+
bags = bags[:, 0]
143184

144185
return self.mlp_head(bags)

uv.lock

Lines changed: 20 additions & 20 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments (0)