File tree Expand file tree Collapse file tree 2 files changed +13
-2
lines changed Expand file tree Collapse file tree 2 files changed +13
-2
lines changed Original file line number Diff line number Diff line change 33setup (
44 name = 'vector_quantize_pytorch' ,
55 packages = find_packages (),
6- version = '0.3.7 ' ,
6+ version = '0.3.8 ' ,
77 license = 'MIT' ,
88 description = 'Vector Quantization - Pytorch' ,
99 author = 'Phil Wang' ,
Original file line number Diff line number Diff line change @@ -242,7 +242,8 @@ def __init__(
242242 kmeans_init = False ,
243243 kmeans_iters = 10 ,
244244 use_cosine_sim = False ,
245- threshold_ema_dead_code = 0
245+ threshold_ema_dead_code = 0 ,
246+ channel_last = True
246247 ):
247248 super ().__init__ ()
248249 n_embed = default (n_embed , codebook_size )
@@ -271,12 +272,18 @@ def __init__(
271272 )
272273
273274 self .codebook_size = codebook_size
275+ self .channel_last = channel_last
274276
275277 @property
276278 def codebook (self ):
277279 return self ._codebook .codebook
278280
279281 def forward (self , x ):
282+ need_transpose = not self .channel_last
283+
284+ if need_transpose :
285+ x = rearrange (x , 'b n d -> b d n' )
286+
280287 x = self .project_in (x )
281288
282289 quantize , embed_ind = self ._codebook (x )
@@ -288,4 +295,8 @@ def forward(self, x):
288295 quantize = x + (quantize - x ).detach ()
289296
290297 quantize = self .project_out (quantize )
298+
299+ if need_transpose :
300+ quantize = rearrange (quantize , 'b n d -> b d n' )
301+
291302 return quantize , embed_ind , commit_loss
You can’t perform that action at this time.
0 commit comments