
Commit 1c29b92

get ready to deprecate and turn off commitment loss in next version
1 parent 49f1ff7

File tree

3 files changed: +8 −7 lines


README.md

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ vq = VectorQuantize(
     dim = 256,
     codebook_size = 512,       # codebook size
     decay = 0.8,               # the exponential moving average decay, lower means the dictionary will change faster
-    commitment = 1.            # the weight on the commitment loss
+    commitment_weight = 1.     # the weight on the commitment loss
 )
 
 x = torch.randn(1, 1024, 256)
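After this change, the README's usage example runs with the new keyword. A minimal sketch, assuming the forward pass returns a (quantized, indices, loss) triple as documented in the README of this version:

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,       # codebook size
    decay = 0.8,               # exponential moving average decay for the codebook
    commitment_weight = 1.     # renamed from `commitment`
)

x = torch.randn(1, 1024, 256)
quantized, indices, loss = vq(x)   # loss already includes the weighted commitment term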

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '0.4.1',
+  version = '0.4.2',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   author = 'Phil Wang',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 6 additions & 5 deletions
@@ -253,14 +253,15 @@ def __init__(
         n_embed = None,
         codebook_dim = None,
         decay = 0.8,
-        commitment = 1.,
         orthogonal_reg_weight = 0.,
+        commitment_weight = None,
         eps = 1e-5,
         kmeans_init = False,
         kmeans_iters = 10,
         use_cosine_sim = False,
         threshold_ema_dead_code = 0,
-        channel_last = True
+        channel_last = True,
+        commitment = 1. # deprecate in next version, turn off by default
     ):
         super().__init__()
         n_embed = default(n_embed, codebook_size)
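Keeping the old `commitment` keyword at the end of the signature means call sites written against 0.4.1 keep working during the deprecation window. A sketch of both call styles, with hypothetical weight values:

# old spelling, still accepted but slated for deprecation
vq_old = VectorQuantize(dim = 256, codebook_size = 512, commitment = 0.25)

# new spelling, takes precedence when given
vq_new = VectorQuantize(dim = 256, codebook_size = 512, commitment_weight = 0.25)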
@@ -273,7 +274,7 @@ def __init__(
             else nn.Identity()
 
         self.eps = eps
-        self.commitment = commitment
+        self.commitment_weight = default(commitment_weight, commitment)
         self.orthogonal_reg_weight = orthogonal_reg_weight
 
         codebook_class = EuclideanCodebook if not use_cosine_sim \
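The shim works because `commitment_weight` defaults to None, so the `default` helper falls through to the legacy `commitment` value. A sketch of that helper, assuming the usual definition used across this codebase rather than quoting this file:

def exists(val):
    return val is not None

def default(val, d):
    # return val when explicitly provided, otherwise fall back to d
    return val if exists(val) else d

# commitment_weight = None, commitment = 1.  ->  weight resolves to 1.
# commitment_weight = 0.25                   ->  weight resolves to 0.25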
@@ -314,9 +315,9 @@ def forward(self, x):
         loss = torch.tensor([0.], device = device, requires_grad = self.training)
 
         if self.training:
-            if self.commitment > 0:
+            if self.commitment_weight > 0:
                 commit_loss = F.mse_loss(quantize.detach(), x)
-                loss = loss + commit_loss * self.commitment
+                loss = loss + commit_loss * self.commitment_weight
 
             if self.orthogonal_reg_weight > 0:
                 orthogonal_reg_loss = orthgonal_loss_fn(self.codebook)
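The guarded term is the standard VQ-VAE commitment loss: detaching `quantize` blocks gradients into the codebook, so the penalty only pulls the encoder output toward its assigned code. A standalone sketch with hypothetical tensors, not this module's internals:

import torch
import torch.nn.functional as F

x = torch.randn(1, 1024, 256, requires_grad = True)   # encoder output
quantize = torch.randn(1, 1024, 256)                   # nearest codebook entries

commitment_weight = 1.
commit_loss = F.mse_loss(quantize.detach(), x)   # gradient reaches x only
loss = commit_loss * commitment_weight
loss.backward()
assert x.grad is not None   # the codebook side receives no gradient from this term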
