Skip to content

Commit 0e8aa76

Browse files
committed
release orthogonal regularization loss, from https://arxiv.org/abs/2112.00384
1 parent aac3aa7 commit 0e8aa76

File tree

3 files changed

+50
-6
lines changed

3 files changed

+50
-6
lines changed

README.md

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,31 @@ x = torch.randn(1, 1024, 256)
126126
quantized, indices, commit_loss = vq(x)
127127
```
128128

129+
## Orthogonal regularization loss
130+
131+
VQ-VAE / VQ-GAN are quickly gaining popularity. A <a href="https://arxiv.org/abs/2112.00384">recent paper</a> proposes that, when using vector quantization on images, enforcing the codebook to be orthogonal leads to translation equivariance of the discretized codes, which in turn yields large improvements in downstream text-to-image generation tasks.
132+
133+
You can use this feature by simply setting `orthogonal_reg_weight` to a value greater than `0`, in which case the orthogonal regularization loss, scaled by that weight, is added to the auxiliary loss returned by the module during training.
134+
135+
```python
136+
import torch
137+
from vector_quantize_pytorch import VectorQuantize
138+
139+
vq = VectorQuantize(
140+
dim = 256,
141+
codebook_size = 256,
142+
orthogonal_reg_weight = 10 # in paper, they recommended a value of 10
143+
)
144+
145+
x = torch.randn(1, 1024, 256)
146+
quantized, indices, loss = vq(x)
147+
148+
# loss now includes the orthogonal regularization loss, scaled by orthogonal_reg_weight
149+
```
150+
151+
129152
## Todo
130153

131-
- [ ] add orthogonality loss on codebook, from https://arxiv.org/abs/2112.00384
132154
- [ ] allow for multi-headed codebooks, from https://openreview.net/forum?id=GxjCYmQAody
133155

134156
## Citations

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vector_quantize_pytorch',
55
packages = find_packages(),
6-
version = '0.3.11',
6+
version = '0.4.0',
77
license='MIT',
88
description = 'Vector Quantization - Pytorch',
99
author = 'Phil Wang',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,16 @@ def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False):
5757

5858
return means, bins
5959

60+
# regularization losses
61+
62+
def orthgonal_loss_fn(t):
    """Orthogonal regularization loss, eq (2) from https://arxiv.org/abs/2112.00384.

    Builds the pairwise similarity matrix of the normalized rows of `t`,
    subtracts the identity matrix, and returns the sum of the squared
    entries divided by n ** 2, where n = t.shape[0].

    NOTE: the function name is spelled 'orthgonal' (sic). It is kept as-is
    because `forward` calls it by this exact spelling.
    """
    num_codes = t.shape[0]

    # presumably `l2norm` scales each row to unit length, so the pairwise dot
    # products below are cosine similarities -- confirm against the helper
    unit_codes = l2norm(t)
    similarity = einsum('i d, j d -> i j', unit_codes, unit_codes)

    # the identity is the target: each code should match only itself and
    # have zero similarity with every other code
    target = torch.eye(num_codes, device = t.device)
    deviation = similarity - target

    return deviation.pow(2).sum() / (num_codes ** 2)
69+
6070
# distance types
6171

6272
class EuclideanCodebook(nn.Module):
@@ -244,6 +254,7 @@ def __init__(
244254
codebook_dim = None,
245255
decay = 0.8,
246256
commitment = 1.,
257+
orthogonal_reg_weight = 0.,
247258
eps = 1e-5,
248259
kmeans_init = False,
249260
kmeans_iters = 10,
@@ -263,6 +274,7 @@ def __init__(
263274

264275
self.eps = eps
265276
self.commitment = commitment
277+
self.orthogonal_reg_weight = orthogonal_reg_weight
266278

267279
codebook_class = EuclideanCodebook if not use_cosine_sim \
268280
else CosineSimCodebook
@@ -285,6 +297,8 @@ def codebook(self):
285297
return self._codebook.embed
286298

287299
def forward(self, x):
300+
device, codebook_size = x.device, self.codebook_size
301+
288302
need_transpose = not self.channel_last
289303

290304
if need_transpose:
@@ -295,14 +309,22 @@ def forward(self, x):
295309
quantize, embed_ind = self._codebook(x)
296310

297311
if self.training:
298-
commit_loss = F.mse_loss(quantize.detach(), x) * self.commitment
299312
quantize = x + (quantize - x).detach()
300-
else:
301-
commit_loss = torch.tensor([0.], device = x.device)
313+
314+
loss = torch.tensor([0.], device = device)
315+
316+
if self.training:
317+
if self.commitment > 0:
318+
commit_loss = F.mse_loss(quantize.detach(), x)
319+
loss = loss + commit_loss * self.commitment
320+
321+
if self.orthogonal_reg_weight > 0:
322+
orthogonal_reg_loss = orthgonal_loss_fn(self.codebook)
323+
loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight
302324

303325
quantize = self.project_out(quantize)
304326

305327
if need_transpose:
306328
quantize = rearrange(quantize, 'b d n -> b n d')
307329

308-
return quantize, embed_ind, commit_loss
330+
return quantize, embed_ind, loss

0 commit comments

Comments
 (0)