@@ -37,7 +37,8 @@ def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False):
         if use_cosine_sim:
             dists = samples @ means.t()
         else:
-            diffs = rearrange(samples, 'n d -> n () d') - rearrange(means, 'c d -> () c d')
+            diffs = rearrange(samples, 'n d -> n () d') \
+                    - rearrange(means, 'c d -> () c d')
             dists = -(diffs ** 2).sum(dim = -1)

         buckets = dists.max(dim = -1).indices
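Note on the reflowed line above: expanding `samples` to `(n, 1, d)` and `means` to `(1, c, d)` broadcasts the subtraction into an `(n, c, d)` difference tensor, so the negated squared sum over `d` is largest for each sample's nearest mean. A standalone sketch (not part of the commit) checking the trick against `torch.cdist`:

```python
import torch
from einops import rearrange

samples = torch.randn(8, 4)   # (n, d)
means   = torch.randn(3, 4)   # (c, d)

# broadcasted pairwise differences: (n, 1, d) - (1, c, d) -> (n, c, d)
diffs = rearrange(samples, 'n d -> n () d') - rearrange(means, 'c d -> () c d')
dists = -(diffs ** 2).sum(dim = -1)          # (n, c), larger = closer

# same squared distances as torch.cdist, up to floating point error
assert torch.allclose(-dists, torch.cdist(samples, means) ** 2, atol = 1e-5)
```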
@@ -66,7 +67,8 @@ def __init__(
         kmeans_init = False,
         kmeans_iters = 10,
         decay = 0.8,
-        eps = 1e-5
+        eps = 1e-5,
+        threshold_ema_dead_code = 2
     ):
         super().__init__()
         self.decay = decay
@@ -76,6 +78,7 @@ def __init__(
         self.codebook_size = codebook_size
         self.kmeans_iters = kmeans_iters
         self.eps = eps
+        self.threshold_ema_dead_code = threshold_ema_dead_code

         self.register_buffer('initted', torch.Tensor([not kmeans_init]))
         self.register_buffer('cluster_size', torch.zeros(codebook_size))
@@ -89,9 +92,22 @@ def init_embed_(self, data):
         self.initted.data.copy_(torch.Tensor([True]))

     def replace(self, samples, mask):
-        modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
+        modified_codebook = torch.where(
+            mask[..., None],
+            sample_vectors(samples, self.codebook_size),
+            self.embed
+        )
         self.embed.data.copy_(modified_codebook)

+    def expire_codes_(self, batch_samples):
+        if self.threshold_ema_dead_code == 0:
+            return
+
+        expired_codes = self.cluster_size < self.threshold_ema_dead_code
+        if torch.any(expired_codes):
+            batch_samples = rearrange(batch_samples, '... d -> (...) d')
+            self.replace(batch_samples, mask = expired_codes)
+
     def forward(self, x):
         shape, dtype = x.shape, x.dtype
         flatten = rearrange(x, '... d -> (...) d')
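The new `expire_codes_` hook above resets "dead" codes: any entry whose exponentially-moving-average cluster size has dropped below `threshold_ema_dead_code` is overwritten with a vector drawn from the current batch, while live codes are left untouched. A rough standalone sketch of the same mechanic, with random resampling standing in for the repo's `sample_vectors` helper (names and values here are illustrative, not the commit's API):

```python
import torch

threshold_ema_dead_code = 2
codebook = torch.randn(8, 4)                                    # (codebook_size, dim)
cluster_size = torch.tensor([5., 0., 3., 1., 9., 0., 4., 2.])   # EMA usage counts
batch = torch.randn(64, 4)                                      # flattened batch, (n, dim)

# codes whose EMA usage fell below the threshold are considered dead
expired = cluster_size < threshold_ema_dead_code                # bool mask, (codebook_size,)

if torch.any(expired):
    # resample candidate replacements from the batch (stand-in for sample_vectors)
    idx = torch.randint(0, batch.shape[0], (codebook.shape[0],))
    candidates = batch[idx]
    # only dead rows get overwritten; live codes are kept as-is
    codebook = torch.where(expired[..., None], candidates, codebook)
```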
@@ -107,7 +123,7 @@ def forward(self, x):
         )

         embed_ind = dist.max(dim = -1).indices
-        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(x.dtype)
+        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
         embed_ind = embed_ind.view(*shape[:-1])
         quantize = F.embedding(embed_ind, self.embed)

@@ -118,6 +134,7 @@ def forward(self, x):
             cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum()
             embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
             self.embed.data.copy_(embed_normalized)
+            self.expire_codes_(x)

         return quantize, embed_ind

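For reference, the training branch touched above is the standard EMA codebook update: per-code usage counts and per-code vector sums are tracked as exponential moving averages, the counts are Laplace-smoothed so no code divides by zero, and the codebook entry is their ratio. A hedged sketch, assuming the repo's `ema_inplace` and `laplace_smoothing` helpers (not shown in this diff) follow the usual definitions:

```python
import torch

def ema_inplace(moving_avg, new, decay):
    # moving_avg <- decay * moving_avg + (1 - decay) * new
    moving_avg.data.mul_(decay).add_(new, alpha = 1 - decay)

def laplace_smoothing(x, n_categories, eps = 1e-5):
    # keeps every smoothed count strictly positive while preserving the total
    return (x + eps) / (x.sum() + n_categories * eps)

# toy shapes: n flattened vectors of dim d, codebook of c entries
n, d, c, decay = 32, 4, 8, 0.8
flatten = torch.randn(n, d)
embed_onehot = torch.nn.functional.one_hot(torch.randint(0, c, (n,)), c).float()

cluster_size = torch.zeros(c)
embed_avg = torch.randn(c, d)

ema_inplace(cluster_size, embed_onehot.sum(0), decay)             # EMA of usage counts
ema_inplace(embed_avg, (flatten.t() @ embed_onehot).t(), decay)   # EMA of per-code sums
smoothed = laplace_smoothing(cluster_size, c) * cluster_size.sum()
embed = embed_avg / smoothed.unsqueeze(1)                         # (c, d) updated codebook
```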
@@ -129,7 +146,8 @@ def __init__(
         kmeans_init = False,
         kmeans_iters = 10,
         decay = 0.8,
-        eps = 1e-5
+        eps = 1e-5,
+        threshold_ema_dead_code = 2
     ):
         super().__init__()
         self.decay = decay
@@ -142,20 +160,35 @@ def __init__(
         self.codebook_size = codebook_size
         self.kmeans_iters = kmeans_iters
         self.eps = eps
+        self.threshold_ema_dead_code = threshold_ema_dead_code

         self.register_buffer('initted', torch.Tensor([not kmeans_init]))
         self.register_buffer('embed', embed)

     def init_embed_(self, data):
-        embed = kmeans(data, self.codebook_size, self.kmeans_iters, use_cosine_sim = True)
+        embed = kmeans(data, self.codebook_size, self.kmeans_iters,
+                       use_cosine_sim = True)
         self.embed.data.copy_(embed)
         self.initted.data.copy_(torch.Tensor([True]))

     def replace(self, samples, mask):
         samples = l2norm(samples)
-        modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
+        modified_codebook = torch.where(
+            mask[..., None],
+            sample_vectors(samples, self.codebook_size),
+            self.embed
+        )
         self.embed.data.copy_(modified_codebook)

+    def expire_codes_(self, batch_samples):
+        if self.threshold_ema_dead_code == 0:
+            return
+
+        expired_codes = self.cluster_size < self.threshold_ema_dead_code
+        if torch.any(expired_codes):
+            batch_samples = rearrange(batch_samples, '... d -> (...) d')
+            self.replace(batch_samples, mask = expired_codes)
+
     def forward(self, x):
         shape, dtype = x.shape, x.dtype
         flatten = rearrange(x, '... d -> (...) d')
@@ -180,8 +213,10 @@ def forward(self, x):
             embed_sum = flatten.t() @ embed_onehot
             embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
             embed_normalized = l2norm(embed_normalized)
-            embed_normalized = torch.where(zero_mask[..., None], embed, embed_normalized)
+            embed_normalized = torch.where(zero_mask[..., None], embed,
+                                           embed_normalized)
             ema_inplace(self.embed, embed_normalized, self.decay)
+            self.expire_codes_(x)

         return quantize, embed_ind

@@ -200,59 +235,41 @@ def __init__(
         kmeans_init = False,
         kmeans_iters = 10,
         use_cosine_sim = False,
-        max_codebook_misses_before_expiry = 0
+        threshold_ema_dead_code = 0
     ):
         super().__init__()
         n_embed = default(n_embed, codebook_size)

         codebook_dim = default(codebook_dim, dim)
         requires_projection = codebook_dim != dim
-        self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
-        self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
+        self.project_in = nn.Linear(dim, codebook_dim) if requires_projection \
+            else nn.Identity()
+        self.project_out = nn.Linear(codebook_dim, dim) if requires_projection \
+            else nn.Identity()

         self.eps = eps
         self.commitment = commitment

-        klass = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook
+        codebook_class = EuclideanCodebook if not use_cosine_sim \
+            else CosineSimCodebook

-        self._codebook = klass(
+        self._codebook = codebook_class(
             dim = codebook_dim,
             codebook_size = n_embed,
             kmeans_init = kmeans_init,
             kmeans_iters = kmeans_iters,
             decay = decay,
-            eps = eps
+            eps = eps,
+            threshold_ema_dead_code = threshold_ema_dead_code
         )

         self.codebook_size = codebook_size
-        self.max_codebook_misses_before_expiry = max_codebook_misses_before_expiry
-
-        if max_codebook_misses_before_expiry > 0:
-            codebook_misses = torch.zeros(codebook_size)
-            self.register_buffer('codebook_misses', codebook_misses)

     @property
     def codebook(self):
         return self._codebook.codebook

-    def expire_codes_(self, embed_ind, batch_samples):
-        if self.max_codebook_misses_before_expiry == 0:
-            return
-
-        embed_ind = rearrange(embed_ind, '... -> (...)')
-        misses = torch.bincount(embed_ind, minlength = self.codebook_size) == 0
-        self.codebook_misses += misses
-
-        expired_codes = self.codebook_misses >= self.max_codebook_misses_before_expiry
-        if not torch.any(expired_codes):
-            return
-
-        self.codebook_misses.masked_fill_(expired_codes, 0)
-        batch_samples = rearrange(batch_samples, '... d -> (...) d')
-        self._codebook.replace(batch_samples, mask = expired_codes)
-
     def forward(self, x):
-        dtype = x.dtype
         x = self.project_in(x)

         quantize, embed_ind = self._codebook(x)
@@ -262,7 +279,6 @@ def forward(self, x):
         if self.training:
             commit_loss = F.mse_loss(quantize.detach(), x) * self.commitment
             quantize = x + (quantize - x).detach()
-            self.expire_codes_(embed_ind, x)

         quantize = self.project_out(quantize)
         return quantize, embed_ind, commit_loss
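Net effect of the commit: code expiry moves out of the `VectorQuantize` wrapper (which counted consecutive misses per code) and into the codebook classes themselves, where it is driven by the EMA cluster size and enabled through the renamed `threshold_ema_dead_code` argument. A rough usage sketch; the argument values are illustrative, not defaults taken from the commit:

```python
import torch

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    decay = 0.8,
    commitment = 1.,
    kmeans_init = True,
    threshold_ema_dead_code = 2   # codes with EMA cluster size < 2 get resampled from the batch
)

x = torch.randn(1, 1024, 256)
quantize, embed_ind, commit_loss = vq(x)   # quantize: (1, 1024, 256), embed_ind: (1, 1024)
```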