
Commit dd5e4a3

allow for passing in image feature maps, taking care of transposes

1 parent 1c29b92

3 files changed: +20 −9 lines

README.md

Lines changed: 4 additions & 4 deletions
@@ -139,16 +139,16 @@ from vector_quantize_pytorch import VectorQuantize
 vq = VectorQuantize(
     dim = 256,
     codebook_size = 256,
-    orthogonal_reg_weight = 10  # in paper, they recommended a value of 10
+    accept_image_fmap = True,   # set this true to be able to pass in an image feature map
+    orthogonal_reg_weight = 10, # in paper, they recommended a value of 10
 )
 
-x = torch.randn(1, 1024, 256)
-quantized, indices, loss = vq(x)
+img_fmap = torch.randn(1, 256, 32, 32)
+quantized, indices, loss = vq(img_fmap) # (1, 256, 32, 32), (1, 32, 32), (1,)
 
 # loss now contains the orthogonal regularization loss with the weight as assigned
 ```
 
-
 ## Todo
 
 - [ ] allow for multi-headed codebooks, from https://openreview.net/forum?id=GxjCYmQAody
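
For reference, here is the updated README example as a self-contained script; this assumes torch and vector-quantize-pytorch at this commit's version (0.4.3) are installed, and the shape comments mirror those in the diff:

```python
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 256,
    accept_image_fmap = True,   # set this true to be able to pass in an image feature map
    orthogonal_reg_weight = 10, # in paper, they recommended a value of 10
)

img_fmap = torch.randn(1, 256, 32, 32)  # (batch, dim, height, width)

# quantized: (1, 256, 32, 32), indices: (1, 32, 32), loss: (1,)
quantized, indices, loss = vq(img_fmap)
```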

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '0.4.2',
+  version = '0.4.3',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   author = 'Phil Wang',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 15 additions & 4 deletions
@@ -261,6 +261,7 @@ def __init__(
         use_cosine_sim = False,
         threshold_ema_dead_code = 0,
         channel_last = True,
+        accept_image_fmap = False,
         commitment = 1. # deprecate in next version, turn off by default
     ):
         super().__init__()
@@ -291,19 +292,25 @@ def __init__(
         )
 
         self.codebook_size = codebook_size
+
+        self.accept_image_fmap = accept_image_fmap
         self.channel_last = channel_last
 
     @property
     def codebook(self):
         return self._codebook.embed
 
     def forward(self, x):
-        device, codebook_size = x.device, self.codebook_size
+        shape, device, codebook_size = x.shape, x.device, self.codebook_size
+
+        need_transpose = not self.channel_last and not self.accept_image_fmap
 
-        need_transpose = not self.channel_last
+        if self.accept_image_fmap:
+            height, width = x.shape[-2:]
+            x = rearrange(x, 'b c h w -> b (h w) c')
 
         if need_transpose:
-            x = rearrange(x, 'b n d -> b d n')
+            x = rearrange(x, 'b d n -> b n d')
 
         x = self.project_in(x)
 
@@ -326,6 +333,10 @@ def forward(self, x):
         quantize = self.project_out(quantize)
 
         if need_transpose:
-            quantize = rearrange(quantize, 'b d n -> b n d')
+            quantize = rearrange(quantize, 'b n d -> b d n')
+
+        if self.accept_image_fmap:
+            quantize = rearrange(quantize, 'b (h w) c -> b c h w', h = height, w = width)
+            embed_ind = rearrange(embed_ind, 'b (h w) -> b h w', h = height, w = width)
 
         return quantize, embed_ind, loss
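
The feature-map path wraps the existing sequence quantizer in an einops round trip: the (height, width) grid is flattened into the sequence axis on the way in, and both the quantized output and the indices are unflattened on the way out. (The flipped `b n d -> b d n` patterns are behaviorally identical permutations; the renaming makes the axis labels match the actual channel-first layout.) A minimal sketch of the round trip in isolation, with illustrative tensor sizes:

```python
import torch
from einops import rearrange

fmap = torch.randn(2, 256, 32, 32)             # (batch, channels, height, width)
height, width = fmap.shape[-2:]

# flatten the spatial grid into a channel-last sequence of height * width vectors
seq = rearrange(fmap, 'b c h w -> b (h w) c')  # (2, 1024, 256)

# ... quantization operates on this (batch, seq, dim) layout ...

# restore the image layout, as forward() now does for both quantize and embed_ind
restored = rearrange(seq, 'b (h w) c -> b c h w', h = height, w = width)
assert torch.equal(restored, fmap)
```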
