@@ -67,7 +67,8 @@ def __init__(
         kmeans_init = False,
         kmeans_iters = 10,
         decay = 0.8,
-        eps = 1e-5
+        eps = 1e-5,
+        threshold_ema_dead_code = 2
     ):
         super().__init__()
         self.decay = decay
@@ -77,6 +78,7 @@ def __init__(
         self.codebook_size = codebook_size
         self.kmeans_iters = kmeans_iters
         self.eps = eps
+        self.threshold_ema_dead_code = threshold_ema_dead_code

         self.register_buffer('initted', torch.Tensor([not kmeans_init]))
         self.register_buffer('cluster_size', torch.zeros(codebook_size))
@@ -97,6 +99,15 @@ def replace(self, samples, mask):
         )
         self.embed.data.copy_(modified_codebook)

+    def expire_codes_(self, batch_samples):
+        if self.threshold_ema_dead_code == 0:
+            return
+
+        expired_codes = self.cluster_size < self.threshold_ema_dead_code
+        if torch.any(expired_codes):
+            batch_samples = rearrange(batch_samples, '... d -> (...) d')
+            self.replace(batch_samples, mask = expired_codes)
+
     def forward(self, x):
         shape, dtype = x.shape, x.dtype
         flatten = rearrange(x, '... d -> (...) d')
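Net effect of the hunk above: a code is declared dead once its exponential-moving-average cluster size falls below threshold_ema_dead_code, and dead codes are reseeded from vectors of the current batch through the existing replace method. A minimal standalone sketch of the mechanic, with sample_vectors as an assumed helper (its definition sits outside this diff):

import torch

def sample_vectors(samples, num):
    # assumed helper: draw `num` rows from the flattened batch (with replacement)
    indices = torch.randint(0, samples.shape[0], (num,), device = samples.device)
    return samples[indices]

cluster_size = torch.tensor([5.0, 0.1, 3.2, 0.0])   # EMA count per code
embed = torch.randn(4, 2)                           # (codebook_size, dim)
batch = torch.randn(64, 2)                          # flattened batch vectors

expired = cluster_size < 2                          # threshold_ema_dead_code = 2
embed = torch.where(expired[:, None], sample_vectors(batch, 4), embed)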
@@ -112,7 +123,7 @@ def forward(self, x):
         )

         embed_ind = dist.max(dim = -1).indices
-        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(x.dtype)
+        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
         embed_ind = embed_ind.view(*shape[:-1])
         quantize = F.embedding(embed_ind, self.embed)

@@ -123,6 +134,7 @@ def forward(self, x):
             cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum()
             embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
             self.embed.data.copy_(embed_normalized)
+            self.expire_codes_(x)

         return quantize, embed_ind
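For reference, the EMA update this hunk extends relies on two helpers that are out of frame in the diff; the definitions below are inferred from the call sites (an assumption, not part of the commit). The Laplace smoothing keeps every per-code count strictly positive, so the division that renormalizes embed_avg cannot hit zero even for codes the expiry step is about to reclaim.

def ema_inplace(moving_avg, new, decay):
    # in-place exponential moving average: buffer <- decay * buffer + (1 - decay) * new
    moving_avg.data.mul_(decay).add_(new, alpha = 1 - decay)

def laplace_smoothing(x, n_categories, eps = 1e-5):
    # additive smoothing so empty clusters still carry nonzero mass
    return (x + eps) / (x.sum() + n_categories * eps)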
@@ -134,7 +146,8 @@ def __init__(
         kmeans_init = False,
         kmeans_iters = 10,
         decay = 0.8,
-        eps = 1e-5
+        eps = 1e-5,
+        threshold_ema_dead_code = 2
     ):
         super().__init__()
         self.decay = decay
@@ -147,6 +160,7 @@ def __init__(
         self.codebook_size = codebook_size
         self.kmeans_iters = kmeans_iters
         self.eps = eps
+        self.threshold_ema_dead_code = threshold_ema_dead_code

         self.register_buffer('initted', torch.Tensor([not kmeans_init]))
         self.register_buffer('embed', embed)
@@ -166,6 +180,15 @@ def replace(self, samples, mask):
         )
         self.embed.data.copy_(modified_codebook)

+    def expire_codes_(self, batch_samples):
+        if self.threshold_ema_dead_code == 0:
+            return
+
+        expired_codes = self.cluster_size < self.threshold_ema_dead_code
+        if torch.any(expired_codes):
+            batch_samples = rearrange(batch_samples, '... d -> (...) d')
+            self.replace(batch_samples, mask = expired_codes)
+
     def forward(self, x):
         shape, dtype = x.shape, x.dtype
         flatten = rearrange(x, '... d -> (...) d')
@@ -193,6 +216,7 @@ def forward(self, x):
             embed_normalized = torch.where(zero_mask[..., None], embed,
                 embed_normalized)
             ema_inplace(self.embed, embed_normalized, self.decay)
+            self.expire_codes_(x)

         return quantize, embed_ind
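The cosine-similarity codebook receives the identical expiry hook. One caveat the diff leaves implicit: since that codebook keeps its vectors l2-normalized, the sampled batch vectors should be projected back onto the unit hypersphere before being written into embed inside replace (whose body is out of frame here). A plausible helper for that, assuming the repo follows the usual convention:

import torch.nn.functional as F

def l2norm(t):
    # unit-normalize along the feature dimension, keeping the
    # codebook on the hypersphere that cosine similarity assumes
    return F.normalize(t, p = 2, dim = -1)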
@@ -211,7 +235,7 @@ def __init__(
         kmeans_init = False,
         kmeans_iters = 10,
         use_cosine_sim = False,
-        max_codebook_misses_before_expiry = 0
+        threshold_ema_dead_code = 0
     ):
         super().__init__()
         n_embed = default(n_embed, codebook_size)
@@ -229,44 +253,23 @@ def __init__(
         codebook_class = EuclideanCodebook if not use_cosine_sim \
             else CosineSimCodebook

-        self._codebook = klass(
+        self._codebook = codebook_class(
            dim = codebook_dim,
            codebook_size = n_embed,
            kmeans_init = kmeans_init,
            kmeans_iters = kmeans_iters,
            decay = decay,
-           eps = eps
+           eps = eps,
+           threshold_ema_dead_code = threshold_ema_dead_code
         )

         self.codebook_size = codebook_size
-        self.max_codebook_misses_before_expiry = max_codebook_misses_before_expiry
-
-        if max_codebook_misses_before_expiry > 0:
-            codebook_misses = torch.zeros(codebook_size)
-            self.register_buffer('codebook_misses', codebook_misses)

     @property
     def codebook(self):
         return self._codebook.codebook

-    def expire_codes_(self, embed_ind, batch_samples):
-        if self.max_codebook_misses_before_expiry == 0:
-            return
-
-        embed_ind = rearrange(embed_ind, '... -> (...)')
-        misses = torch.bincount(embed_ind, minlength = self.codebook_size) == 0
-        self.codebook_misses += misses
-
-        expired_codes = self.codebook_misses >= self.max_codebook_misses_before_expiry
-        if not torch.any(expired_codes):
-            return
-
-        self.codebook_misses.masked_fill_(expired_codes, 0)
-        batch_samples = rearrange(batch_samples, '... d -> (...) d')
-        self._codebook.replace(batch_samples, mask = expired_codes)
-
     def forward(self, x):
-        dtype = x.dtype
         x = self.project_in(x)

         quantize, embed_ind = self._codebook(x)
@@ -276,7 +279,6 @@ def forward(self, x):
         if self.training:
             commit_loss = F.mse_loss(quantize.detach(), x) * self.commitment
             quantize = x + (quantize - x).detach()
-            self.expire_codes_(embed_ind, x)

         quantize = self.project_out(quantize)
         return quantize, embed_ind, commit_loss
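From the caller's side, the migration is a single keyword swap: max_codebook_misses_before_expiry is gone, and threshold_ema_dead_code takes its place, defaulting to 0 at the wrapper level so expiry stays off unless requested. A usage sketch, assuming the package-level import path and the commitment keyword (neither is shown in this diff):

import torch
from vector_quantize_pytorch import VectorQuantize  # assumed import path

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    decay = 0.8,                   # EMA decay for cluster sizes and codebook
    commitment = 1.,               # weight on the commitment loss (assumed kwarg)
    threshold_ema_dead_code = 2    # reseed codes whose EMA cluster size drops below 2
)

x = torch.randn(1, 1024, 256)
quantize, embed_ind, commit_loss = vq(x)  # quantized vectors, code indices, loss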