@@ -37,7 +37,8 @@ def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False):
3737 if use_cosine_sim :
3838 dists = samples @ means .t ()
3939 else :
40- diffs = rearrange (samples , 'n d -> n () d' ) - rearrange (means , 'c d -> () c d' )
40+ diffs = rearrange (samples , 'n d -> n () d' ) \
41+ - rearrange (means , 'c d -> () c d' )
4142 dists = - (diffs ** 2 ).sum (dim = - 1 )
4243
4344 buckets = dists .max (dim = - 1 ).indices
@@ -89,7 +90,11 @@ def init_embed_(self, data):
8990 self .initted .data .copy_ (torch .Tensor ([True ]))
9091
9192 def replace (self , samples , mask ):
92- modified_codebook = torch .where (mask [..., None ], sample_vectors (samples , self .codebook_size ), self .embed )
93+ modified_codebook = torch .where (
94+ mask [..., None ],
95+ sample_vectors (samples , self .codebook_size ),
96+ self .embed
97+ )
9398 self .embed .data .copy_ (modified_codebook )
9499
95100 def forward (self , x ):
@@ -147,13 +152,18 @@ def __init__(
147152 self .register_buffer ('embed' , embed )
148153
149154 def init_embed_ (self , data ):
150- embed = kmeans (data , self .codebook_size , self .kmeans_iters , use_cosine_sim = True )
155+ embed = kmeans (data , self .codebook_size , self .kmeans_iters ,
156+ use_cosine_sim = True )
151157 self .embed .data .copy_ (embed )
152158 self .initted .data .copy_ (torch .Tensor ([True ]))
153159
154160 def replace (self , samples , mask ):
155161 samples = l2norm (samples )
156- modified_codebook = torch .where (mask [..., None ], sample_vectors (samples , self .codebook_size ), self .embed )
162+ modified_codebook = torch .where (
163+ mask [..., None ],
164+ sample_vectors (samples , self .codebook_size ),
165+ self .embed
166+ )
157167 self .embed .data .copy_ (modified_codebook )
158168
159169 def forward (self , x ):
@@ -180,7 +190,8 @@ def forward(self, x):
180190 embed_sum = flatten .t () @ embed_onehot
181191 embed_normalized = (embed_sum / bins .unsqueeze (0 )).t ()
182192 embed_normalized = l2norm (embed_normalized )
183- embed_normalized = torch .where (zero_mask [..., None ], embed , embed_normalized )
193+ embed_normalized = torch .where (zero_mask [..., None ], embed ,
194+ embed_normalized )
184195 ema_inplace (self .embed , embed_normalized , self .decay )
185196
186197 return quantize , embed_ind
@@ -207,13 +218,16 @@ def __init__(
207218
208219 codebook_dim = default (codebook_dim , dim )
209220 requires_projection = codebook_dim != dim
210- self .project_in = nn .Linear (dim , codebook_dim ) if requires_projection else nn .Identity ()
211- self .project_out = nn .Linear (codebook_dim , dim ) if requires_projection else nn .Identity ()
221+ self .project_in = nn .Linear (dim , codebook_dim ) if requires_projection \
222+ else nn .Identity ()
223+ self .project_out = nn .Linear (codebook_dim , dim ) if requires_projection \
224+ else nn .Identity ()
212225
213226 self .eps = eps
214227 self .commitment = commitment
215228
216- klass = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook
229+ codebook_class = EuclideanCodebook if not use_cosine_sim \
230+ else CosineSimCodebook
217231
218232 self ._codebook = klass (
219233 dim = codebook_dim ,
0 commit comments