@@ -104,9 +104,10 @@ def expire_codes_(self, batch_samples):
104104 return
105105
106106 expired_codes = self .cluster_size < self .threshold_ema_dead_code
107- if torch .any (expired_codes ):
108- batch_samples = rearrange (batch_samples , '... d -> (...) d' )
109- self .replace (batch_samples , mask = expired_codes )
107+ if not torch .any (expired_codes ):
108+ return
109+ batch_samples = rearrange (batch_samples , '... d -> (...) d' )
110+ self .replace (batch_samples , mask = expired_codes )
110111
111112 def forward (self , x ):
112113 shape , dtype = x .shape , x .dtype
@@ -163,6 +164,7 @@ def __init__(
163164 self .threshold_ema_dead_code = threshold_ema_dead_code
164165
165166 self .register_buffer ('initted' , torch .Tensor ([not kmeans_init ]))
167+ self .register_buffer ('cluster_size' , torch .zeros (codebook_size ))
166168 self .register_buffer ('embed' , embed )
167169
168170 def init_embed_ (self , data ):
@@ -185,9 +187,10 @@ def expire_codes_(self, batch_samples):
185187 return
186188
187189 expired_codes = self .cluster_size < self .threshold_ema_dead_code
188- if torch .any (expired_codes ):
189- batch_samples = rearrange (batch_samples , '... d -> (...) d' )
190- self .replace (batch_samples , mask = expired_codes )
190+ if not torch .any (expired_codes ):
191+ return
192+ batch_samples = rearrange (batch_samples , '... d -> (...) d' )
193+ self .replace (batch_samples , mask = expired_codes )
191194
192195 def forward (self , x ):
193196 shape , dtype = x .shape , x .dtype
@@ -207,6 +210,8 @@ def forward(self, x):
207210
208211 if self .training :
209212 bins = embed_onehot .sum (0 )
213+ ema_inplace (self .cluster_size , bins , self .decay )
214+
210215 zero_mask = (bins == 0 )
211216 bins = bins .masked_fill (zero_mask , 1. )
212217
0 commit comments