@@ -9,6 +9,9 @@ def exists(val):
 def default(val, d):
     return val if exists(val) else d
 
+def l2norm(t):
+    return F.normalize(t, p = 2, dim = -1)
+
 def ema_inplace(moving_avg, new, decay):
     moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay))
 
@@ -44,7 +47,7 @@ def kmeans(x, num_clusters, num_iters = 10, use_cosine_sim = False):
         new_means = new_means / bins[..., None]
 
         if use_cosine_sim:
-            new_means = F.normalize(new_means, dim = -1)
+            new_means = l2norm(new_means)
 
         means = torch.where(zero_mask[..., None], means, new_means)
 
@@ -125,7 +128,7 @@ def __init__(
         self.decay = decay
 
         if not kmeans_init:
-            embed = F.normalize(torch.randn(codebook_size, dim), dim = -1)
+            embed = l2norm(torch.randn(codebook_size, dim))
         else:
             embed = torch.zeros(codebook_size, dim)
 
@@ -144,8 +147,8 @@ def init_embed_(self, data):
     def forward(self, x):
         shape, dtype = x.shape, x.dtype
         flatten = rearrange(x, '... d -> (...) d')
-        flatten = F.normalize(flatten, dim = -1)
-        embed = F.normalize(self.embed, dim = -1)
+        flatten = l2norm(flatten)
+        embed = l2norm(self.embed)
 
         if not self.initted:
             self.init_embed_(flatten)
@@ -164,7 +167,7 @@ def forward(self, x):
 
         embed_sum = flatten.t() @ embed_onehot
         embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
-        embed_normalized = F.normalize(embed_normalized, dim = -1)
+        embed_normalized = l2norm(embed_normalized)
         embed_normalized = torch.where(zero_mask[..., None], embed, embed_normalized)
         ema_inplace(self.embed, embed_normalized, self.decay)
 
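The change above factors the repeated `F.normalize(..., dim = -1)` calls into a single `l2norm` helper, so every normalization site uses the same p = 2, last-dimension convention. Below is a minimal, self-contained sketch of the helper and the property it enforces; the shape values (512, 256) are illustrative assumptions, not taken from the commit.

import torch
import torch.nn.functional as F

def l2norm(t):
    # unit-normalize the last dimension; same as F.normalize(t, p = 2, dim = -1)
    return F.normalize(t, p = 2, dim = -1)

# illustrative codebook shape (assumed values, not from the commit)
codebook_size, dim = 512, 256
embed = l2norm(torch.randn(codebook_size, dim))

# every code vector now has (approximately) unit L2 norm
assert torch.allclose(embed.norm(dim = -1), torch.ones(codebook_size), atol = 1e-6)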