
Commit db5dd6e

add measure number one for dead codebooks
1 parent d28d851 commit db5dd6e

3 files changed: +47 −4 lines changed

README.md

Lines changed: 33 additions & 0 deletions
@@ -68,6 +68,28 @@ x = torch.randn(1, 1024, 256)
 quantized, indices, commit_loss = residual_vq(x)
 ```
 
+## Increasing codebook usage
+
+This repository will contain a few techniques from various papers to combat "dead" codebook entries, which is a common problem when using vector quantizers.
+
+### Lower codebook dimension
+
+The <a href="https://openreview.net/forum?id=pfNyExj7z2">Improved VQGAN paper</a> proposes to keep the codebook in a lower dimension. The encoder values are projected down before quantization and projected back up to the higher dimension on output. You can set this with the `codebook_dim` hyperparameter.
+
+```python
+import torch
+from vector_quantize_pytorch import VectorQuantize
+
+vq = VectorQuantize(
+    dim = 256,
+    codebook_size = 256,
+    codebook_dim = 16      # paper proposes setting this to 32 or as low as 8 to increase codebook usage
+)
+
+x = torch.randn(1, 1024, 256)
+quantized, indices, commit_loss = vq(x)
+```
+
 ## Citations
 
 ```bibtex
@@ -91,3 +113,14 @@ quantized, indices, commit_loss = residual_vq(x)
     primaryClass = {cs.SD}
 }
 ```
+
+```bibtex
+@inproceedings{anonymous2022vectorquantized,
+    title     = {Vector-quantized Image Modeling with Improved {VQGAN}},
+    author    = {Anonymous},
+    booktitle = {Submitted to The Tenth International Conference on Learning Representations},
+    year      = {2022},
+    url       = {https://openreview.net/forum?id=pfNyExj7z2},
+    note      = {under review}
+}
+```
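The README section added above describes "dead" codebook entries. As a rough illustration, the indices returned by the quantizer can be inspected to see how many codes are actually being selected; the snippet below is a hypothetical diagnostic built on the README's documented return values, not part of the library:

```python
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(dim = 256, codebook_size = 256, codebook_dim = 16)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)

# count how many of the 256 codes were selected at least once for this batch;
# codes that are never (or almost never) hit are the "dead" entries
used = torch.unique(indices).numel()
print(f'codebook usage: {used / 256:.2%}')
```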

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '0.3.0',
+  version = '0.3.1',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   author = 'Phil Wang',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 13 additions & 3 deletions
@@ -52,19 +52,26 @@ def __init__(
         eps = 1e-5,
         n_embed = None,
         kmeans_init = False,
-        kmeans_iters = 10
+        kmeans_iters = 10,
+        codebook_dim = None
     ):
         super().__init__()
         n_embed = default(n_embed, codebook_size)
 
         self.dim = dim
         self.n_embed = n_embed
+
+        codebook_dim = default(codebook_dim, dim)
+        requires_projection = codebook_dim != dim
+        self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
+        self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
+
         self.decay = decay
         self.eps = eps
         self.commitment = commitment
 
         init_fn = torch.randn if not kmeans_init else torch.zeros
-        embed = init_fn(dim, n_embed)
+        embed = init_fn(codebook_dim, n_embed)
 
         self.kmeans_iters = kmeans_iters
         self.register_buffer('initted', torch.Tensor([not kmeans_init]))
@@ -83,11 +90,13 @@ def init_embed_(self, data):
         self.initted.data.copy_(torch.Tensor([True]))
 
     def forward(self, input):
+        input = self.project_in(input)
+
         if not self.initted:
             self.init_embed_(input)
 
         dtype = input.dtype
-        flatten = input.reshape(-1, self.dim)
+        flatten = rearrange(input, '... d -> (...) d')
         dist = (
             flatten.pow(2).sum(1, keepdim=True)
             - 2 * flatten @ self.embed
@@ -112,4 +121,5 @@ def forward(self, input):
         commit_loss = F.mse_loss(quantize.detach(), input) * self.commitment
         quantize = input + (quantize - input).detach()
 
+        quantize = self.project_out(quantize)
         return quantize, embed_ind, commit_loss
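The changes above wire an optional low-dimensional codebook into the quantizer via `project_in` / `project_out`. As a rough standalone sketch of that idea, with a hypothetical class name and a simplified distance computation rather than the library's implementation:

```python
import torch
from torch import nn

class LowDimCodebook(nn.Module):
    # illustrative only: project features down to the codebook dimension,
    # snap to the nearest code, then project back up to the model dimension
    def __init__(self, dim, codebook_size, codebook_dim):
        super().__init__()
        self.project_in = nn.Linear(dim, codebook_dim)
        self.project_out = nn.Linear(codebook_dim, dim)
        self.embed = nn.Parameter(torch.randn(codebook_size, codebook_dim))

    def forward(self, x):
        x = self.project_in(x)                            # (batch, seq, codebook_dim)
        dist = (x.pow(2).sum(-1, keepdim = True)          # squared distance to every code
                - 2 * x @ self.embed.t()
                + self.embed.pow(2).sum(-1))
        indices = dist.argmin(dim = -1)                   # nearest codebook entry per vector
        quantized = self.embed[indices]
        quantized = x + (quantized - x).detach()          # straight-through estimator
        return self.project_out(quantized), indices

x = torch.randn(1, 1024, 256)
quantized, indices = LowDimCodebook(dim = 256, codebook_size = 256, codebook_dim = 16)(x)
```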
