|
2 | 2 | In parts from https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py |
3 | 3 | """ |
4 | 4 |
|
5 | | -# TODO implement masking |
6 | | - |
7 | | -from typing import Iterable, cast |
| 5 | +from collections.abc import Iterable |
| 6 | +from typing import cast |
8 | 7 |
|
9 | 8 | import torch |
10 | 9 | from beartype import beartype |
11 | 10 | from einops import repeat |
12 | | -from jaxtyping import Float, jaxtyped |
| 11 | +from jaxtyping import Bool, Float, jaxtyped |
13 | 12 | from torch import Tensor, nn |
14 | 13 |
|
15 | 14 |
|
@@ -42,10 +41,31 @@ def __init__( |
42 | 41 |
|
43 | 42 | @jaxtyped(typechecker=beartype) |
44 | 43 | def forward( |
45 | | - self, x: Float[Tensor, "batch sequence proj_feature"] |
| 44 | + self, |
| 45 | + x: Float[Tensor, "batch sequence proj_feature"], |
| 46 | + *, |
| 47 | + attn_mask: Bool[Tensor, "batch sequence sequence"] | None, |
46 | 48 | ) -> Float[Tensor, "batch sequence proj_feature"]: |
| 49 | + """ |
| 50 | + Args: |
| 51 | + attn_mask: |
| 52 | + Which of the features to ignore during self-attention. |
| 53 | + `attn_mask[b,q,k] == False` means that |
| 54 | + query `q` of batch `b` can attend to key `k`. |
| 55 | + If `attn_mask` is `None`, all tokens can attend to all others. |
| 56 | + """ |
47 | 57 | x = self.norm(x) |
48 | | - attn_output, _ = self.mhsa(x, x, x, need_weights=False) |
| 58 | + attn_output, _ = self.mhsa( |
| 59 | + x, |
| 60 | + x, |
| 61 | + x, |
| 62 | + need_weights=False, |
| 63 | + attn_mask=( |
| 64 | + attn_mask.repeat(self.mhsa.num_heads, 1, 1) |
| 65 | + if attn_mask is not None |
| 66 | + else None |
| 67 | + ), |
| 68 | + ) |
49 | 69 | return attn_output |
50 | 70 |
|
51 | 71 |
|
@@ -83,10 +103,13 @@ def __init__( |
83 | 103 |
|
    @jaxtyped(typechecker=beartype)
    def forward(
        self,
        x: Float[Tensor, "batch sequence proj_feature"],
        *,
        attn_mask: Bool[Tensor, "batch sequence sequence"] | None,
    ) -> Float[Tensor, "batch sequence proj_feature"]:
        """
        Apply the stacked transformer layers with residual connections.

        Args:
            x: Input sequence of projected features.
            attn_mask:
                Boolean self-attention mask forwarded unchanged to every
                attention layer; `True` entries are blocked from attending
                (see the attention block's docstring).  `None` disables
                masking entirely.
        """
        # `self.layers` is presumably an nn.ModuleList of
        # (attention, feed-forward) pairs -- TODO confirm against __init__.
        # The `cast` only informs the type checker; it has no runtime effect.
        for attn, ff in cast(Iterable[tuple[nn.Module, nn.Module]], self.layers):
            # Pre-norm residual blocks: attention first, then feed-forward,
            # each added back onto its input.
            x_attn = attn(x, attn_mask=attn_mask)
            x = x_attn + x
            x = ff(x) + x
92 | 115 |
|
@@ -127,18 +150,36 @@ def __init__( |
127 | 150 |
|
128 | 151 | @jaxtyped(typechecker=beartype) |
129 | 152 | def forward( |
130 | | - self, bags: Float[Tensor, "batch tile feature"] |
| 153 | + self, |
| 154 | + bags: Float[Tensor, "batch tile feature"], |
| 155 | + *, |
| 156 | + mask: Bool[Tensor, "batch tile"] | None, |
131 | 157 | ) -> Float[Tensor, "batch logit"]: |
132 | 158 | batch_size, _n_tiles, _n_features = bags.shape |
133 | 159 |
|
134 | | - # map input sequence to latent space of TransMIL |
| 160 | + # Map input sequence to latent space of TransMIL |
135 | 161 | bags = self.project_features(bags) |
136 | 162 |
|
| 163 | + # Prepend a class token to every bag, |
| 164 | + # include it in the mask. |
| 165 | + # TODO should the tiles be able to refer to the class token? Test! |
137 | 166 | cls_tokens = repeat(self.class_token, "d -> b 1 d", b=batch_size) |
138 | | - bags = torch.cat((cls_tokens, bags), dim=1) |
139 | | - |
140 | | - bags = self.transformer(bags) |
141 | | - |
142 | | - bags = bags[:, 0] # only take class token |
| 167 | + bags = torch.cat([cls_tokens, bags], dim=1) |
| 168 | + if mask is not None: |
| 169 | + mask_with_class_token = torch.cat( |
| 170 | + [torch.zeros(mask.shape[0], 1).type_as(mask), mask], dim=1 |
| 171 | + ) |
| 172 | + square_attn_mask = torch.einsum( |
| 173 | + "bq,bk->bqk", mask_with_class_token, mask_with_class_token |
| 174 | + ) |
| 175 | + # Don't allow other tiles to reference the class token |
| 176 | + square_attn_mask[:, 1:, 0] = True |
| 177 | + |
| 178 | + bags = self.transformer(bags, attn_mask=square_attn_mask) |
| 179 | + else: |
| 180 | + bags = self.transformer(bags, attn_mask=None) |
| 181 | + |
| 182 | + # Only take class token |
| 183 | + bags = bags[:, 0] |
143 | 184 |
|
144 | 185 | return self.mlp_head(bags) |
0 commit comments