 import torch
-from torch import nn
+from torch import nn, einsum
 import torch.nn.functional as F
+from einops import rearrange, repeat

 def exists(val):
     return val is not None
@@ -11,9 +12,36 @@ def default(val, d):
 def ema_inplace(moving_avg, new, decay):
     moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay))

-def laplace_smoothing(x, n_categories, eps = 1e-5):
+def laplace_smoothing(x, n_categories, eps = 1e-5):
     return (x + eps) / (x.sum() + n_categories * eps)

+def kmeans(x, num_clusters, num_iters = 10):
+    samples = rearrange(x, '... d -> (...) d')
+    num_samples, dim, dtype, device = *samples.shape, x.dtype, x.device
+
+    if num_samples >= num_clusters:
+        indices = torch.randperm(num_samples, device = device)[:num_clusters]
+    else:
+        indices = torch.randint(0, num_samples, (num_clusters,), device = device)
+
+    means = samples[indices]
+
+    for _ in range(num_iters):
+        diffs = rearrange(samples, 'n d -> n () d') - rearrange(means, 'c d -> () c d')
+        dists = (diffs ** 2).sum(dim = -1)
+        buckets = dists.argmin(dim = -1)
+
+        bins = torch.bincount(buckets, minlength = num_clusters)
+        zero_mask = bins == 0
+        bins = bins.masked_fill(zero_mask, 1)
+
+        new_means = buckets.new_zeros(num_clusters, dim, dtype = dtype)
+        new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d = dim), samples)
+        new_means = new_means / bins[..., None]
+        means = torch.where(zero_mask[..., None], means, new_means)
+
+    return rearrange(means, 'n d -> d n')
+
 class VectorQuantize(nn.Module):
     def __init__(
         self,
@@ -23,6 +51,8 @@ def __init__(
         commitment = 1.,
         eps = 1e-5,
         n_embed = None,
+        kmeans_init = False,
+        kmeans_iters = 10
     ):
         super().__init__()
         n_embed = default(n_embed, codebook_size)
@@ -33,26 +63,42 @@ def __init__(
         self.eps = eps
         self.commitment = commitment

-        embed = torch.randn(dim, n_embed)
-        self.register_buffer('embed', embed)
+        init_fn = torch.randn if not kmeans_init else torch.zeros
+        embed = init_fn(dim, n_embed)
+
+        self.kmeans_iters = kmeans_iters
+        self.register_buffer('initted', torch.Tensor([not kmeans_init]))
         self.register_buffer('cluster_size', torch.zeros(n_embed))
+        self.register_buffer('embed', embed)
         self.register_buffer('embed_avg', embed.clone())

     @property
     def codebook(self):
         return self.embed.transpose(0, 1)

+    def init_embed_(self, data):
+        embed = kmeans(data, self.n_embed, self.kmeans_iters)
+        self.embed.data.copy_(embed)
+        self.embed_avg.data.copy_(embed.clone())
+        self.initted.data.copy_(torch.Tensor([True]))
+
     def forward(self, input):
+        if not self.initted:
+            self.init_embed_(input)
+
         dtype = input.dtype
         flatten = input.reshape(-1, self.dim)
         dist = (
             flatten.pow(2).sum(1, keepdim = True)
             - 2 * flatten @ self.embed
             + self.embed.pow(2).sum(0, keepdim = True)
         )
+
         _, embed_ind = (-dist).max(1)
         embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
         embed_ind = embed_ind.view(*input.shape[:-1])
+
+        commit_loss = 0.
         quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))

         if self.training:
@@ -63,6 +109,7 @@ def forward(self, input):
             embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
             self.embed.data.copy_(embed_normalized)

-        loss = F.mse_loss(quantize.detach(), input) * self.commitment
-        quantize = input + (quantize - input).detach()
-        return quantize, embed_ind, loss
+            commit_loss = F.mse_loss(quantize.detach(), input) * self.commitment
+            quantize = input + (quantize - input).detach()
+
+        return quantize, embed_ind, commit_loss
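The net effect of the change: with kmeans_init = True the codebook buffer starts as zeros, the initted flag is False, and the first forward() call runs init_embed_, which fits kmeans() on that first batch and copies the centroids into embed and embed_avg. Below is a minimal usage sketch; the package import path and the concrete sizes are illustrative assumptions, not part of this diff.

    import torch
    from vector_quantize_pytorch import VectorQuantize  # assumed package path, not shown in the diff

    vq = VectorQuantize(
        dim = 256,             # feature dimension (self.dim in forward)
        codebook_size = 512,   # number of codes (n_embed defaults to this)
        kmeans_init = True,    # leave the codebook at zeros until the first batch arrives
        kmeans_iters = 10      # iterations of the kmeans() helper above
    )

    x = torch.randn(1, 1024, 256)               # any input shaped (..., dim)
    quantized, indices, commit_loss = vq(x)     # first call triggers init_embed_(x)

Note that in this version the commitment loss and the straight-through estimator are applied only while the module is in training mode; in eval mode commit_loss is returned as 0.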