Skip to content

Commit 9ad29ef

Browse files
committed
for clarity
1 parent a1b4a71 commit 9ad29ef

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ def exists(val):
99
def default(val, d):
1010
return val if exists(val) else d
1111

12+
def l2norm(t):
13+
return F.normalize(t, p = 2, dim = -1)
14+
1215
def ema_inplace(moving_avg, new, decay):
1316
moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay))
1417

@@ -44,7 +47,7 @@ def kmeans(x, num_clusters, num_iters = 10, use_cosine_sim = False):
4447
new_means = new_means / bins[..., None]
4548

4649
if use_cosine_sim:
47-
new_means = F.normalize(new_means, dim = -1)
50+
new_means = l2norm(new_means)
4851

4952
means = torch.where(zero_mask[..., None], means, new_means)
5053

@@ -125,7 +128,7 @@ def __init__(
125128
self.decay = decay
126129

127130
if not kmeans_init:
128-
embed = F.normalize(torch.randn(codebook_size, dim), dim = -1)
131+
embed = l2norm(torch.randn(codebook_size, dim))
129132
else:
130133
embed = torch.zeros(codebook_size, dim)
131134

@@ -144,8 +147,8 @@ def init_embed_(self, data):
144147
def forward(self, x):
145148
shape, dtype = x.shape, x.dtype
146149
flatten = rearrange(x, '... d -> (...) d')
147-
flatten = F.normalize(flatten, dim = -1)
148-
embed = F.normalize(self.embed, dim = - 1)
150+
flatten = l2norm(flatten)
151+
embed = l2norm(self.embed)
149152

150153
if not self.initted:
151154
self.init_embed_(flatten)
@@ -164,7 +167,7 @@ def forward(self, x):
164167

165168
embed_sum = flatten.t() @ embed_onehot
166169
embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
167-
embed_normalized = F.normalize(embed_normalized, dim = -1)
170+
embed_normalized = l2norm(embed_normalized)
168171
embed_normalized = torch.where(zero_mask[..., None], embed, embed_normalized)
169172
ema_inplace(self.embed, embed_normalized, self.decay)
170173

0 commit comments

Comments
 (0)