File tree Expand file tree Collapse file tree 4 files changed +16
-9
lines changed Expand file tree Collapse file tree 4 files changed +16
-9
lines changed Original file line number Diff line number Diff line change @@ -18,9 +18,9 @@ from vector_quantize_pytorch import VectorQuantize
1818
1919vq = VectorQuantize(
2020 dim = 256 ,
21- n_embed = 512 , # size of the dictionary
22- decay = 0.8 , # the exponential moving average decay, lower means the dictionary will change faster
23- commitment = 1 . # the weight on the commitment loss
21+ codebook_size = 512 , # codebook size
22+ decay = 0.8 , # the exponential moving average decay, lower means the dictionary will change faster
23+ commitment = 1 . # the weight on the commitment loss
2424)
2525
2626x = torch.randn(1 , 1024 , 256 )
@@ -38,7 +38,7 @@ from vector_quantize_pytorch import ResidualVQ
3838residual_vq = ResidualVQ(
3939 dim = 256 ,
4040 num_quantizers = 8 , # specify number of quantizers
41- n_embed = 1024 , # codebook size
41+ codebook_size = 1024 , # codebook size
4242)
4343
4444x = torch.randn(1 , 1024 , 256 )
Original file line number Diff line number Diff line change 33setup (
44 name = 'vector_quantize_pytorch' ,
55 packages = find_packages (),
6- version = '0.2.0 ' ,
6+ version = '0.2.1 ' ,
77 license = 'MIT' ,
88 description = 'Vector Quantization - Pytorch' ,
99 author = 'Phil Wang' ,
Original file line number Diff line number Diff line change @@ -8,11 +8,10 @@ def __init__(
88 self ,
99 * ,
1010 num_quantizers ,
11- n_embed ,
1211 ** kwargs
1312 ):
1413 super ().__init__ ()
15- self .layers = nn .ModuleList ([VectorQuantize (n_embed = n_embed , ** kwargs ) for _ in range (num_quantizers )])
14+ self .layers = nn .ModuleList ([VectorQuantize (** kwargs ) for _ in range (num_quantizers )])
1615
1716 def forward (self , x ):
1817 quantized_out = 0.
Original file line number Diff line number Diff line change 22from torch import nn
33import torch .nn .functional as F
44
5+ def exists (val ):
6+ return val is not None
7+
8+ def default (val , d ):
9+ return val if exists (val ) else d
10+
511def ema_inplace (moving_avg , new , decay ):
612 moving_avg .data .mul_ (decay ).add_ (new , alpha = (1 - decay ))
713
@@ -12,12 +18,14 @@ class VectorQuantize(nn.Module):
1218 def __init__ (
1319 self ,
1420 dim ,
15- n_embed ,
21+ codebook_size ,
1622 decay = 0.8 ,
1723 commitment = 1. ,
18- eps = 1e-5
24+ eps = 1e-5 ,
25+ n_embed = None ,
1926 ):
2027 super ().__init__ ()
28+ n_embed = default (n_embed , codebook_size )
2129
2230 self .dim = dim
2331 self .n_embed = n_embed
You can't perform that action at this time.
0 commit comments