Commit 381d574

make init param more intuitive, but leave n_embed backwards compatible
1 parent 284a671 commit 381d574

File tree: 4 files changed, +16 / -9 lines changed

README.md

Lines changed: 4 additions & 4 deletions

@@ -18,9 +18,9 @@ from vector_quantize_pytorch import VectorQuantize
 
 vq = VectorQuantize(
     dim = 256,
-    n_embed = 512,          # size of the dictionary
-    decay = 0.8,            # the exponential moving average decay, lower means the dictionary will change faster
-    commitment = 1.         # the weight on the commitment loss
+    codebook_size = 512,    # codebook size
+    decay = 0.8,            # the exponential moving average decay, lower means the dictionary will change faster
+    commitment = 1.         # the weight on the commitment loss
 )
 
 x = torch.randn(1, 1024, 256)

@@ -38,7 +38,7 @@ from vector_quantize_pytorch import ResidualVQ
 residual_vq = ResidualVQ(
     dim = 256,
     num_quantizers = 8,      # specify number of quantizers
-    n_embed = 1024,          # codebook size
+    codebook_size = 1024,    # codebook size
 )
 
 x = torch.randn(1, 1024, 256)
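
A hedged usage sketch (not part of this diff): with vq constructed as in the README snippet above, the forward call is assumed to return the quantized tensor, the codebook indices, and the commitment loss.

    quantized, indices, commit_loss = vq(x)   # quantized: (1, 1024, 256), indices: (1, 1024), commit_loss: scalar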

setup.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '0.2.0',
+  version = '0.2.1',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   author = 'Phil Wang',

vector_quantize_pytorch/residual_vq.py

Lines changed: 1 addition & 2 deletions

@@ -8,11 +8,10 @@ def __init__(
         self,
         *,
         num_quantizers,
-        n_embed,
         **kwargs
     ):
         super().__init__()
-        self.layers = nn.ModuleList([VectorQuantize(n_embed = n_embed, **kwargs) for _ in range(num_quantizers)])
+        self.layers = nn.ModuleList([VectorQuantize(**kwargs) for _ in range(num_quantizers)])
 
     def forward(self, x):
         quantized_out = 0.
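
A short hedged sketch of the calling convention this implies (assuming plain keyword pass-through): ResidualVQ no longer declares n_embed itself, so the codebook size is given with the new VectorQuantize keyword and forwarded unchanged to every internal quantizer.

    residual_vq = ResidualVQ(
        dim = 256,
        num_quantizers = 8,
        codebook_size = 1024,   # forwarded via **kwargs to each VectorQuantize layer
    )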

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 10 additions & 2 deletions

@@ -2,6 +2,12 @@
 from torch import nn
 import torch.nn.functional as F
 
+def exists(val):
+    return val is not None
+
+def default(val, d):
+    return val if exists(val) else d
+
 def ema_inplace(moving_avg, new, decay):
     moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay))
 
@@ -12,12 +18,14 @@ class VectorQuantize(nn.Module):
     def __init__(
         self,
         dim,
-        n_embed,
+        codebook_size,
         decay = 0.8,
         commitment = 1.,
-        eps = 1e-5
+        eps = 1e-5,
+        n_embed = None,
     ):
         super().__init__()
+        n_embed = default(n_embed, codebook_size)
 
         self.dim = dim
         self.n_embed = n_embed