@@ -15,7 +15,7 @@ def ema_inplace(moving_avg, new, decay):
 def laplace_smoothing(x, n_categories, eps = 1e-5):
     return (x + eps) / (x.sum() + n_categories * eps)

-def kmeans(x, num_clusters, num_iters = 10):
+def kmeans(x, num_clusters, num_iters = 10, use_cosine_sim = False):
     samples = rearrange(x, '... d -> (...) d')
     num_samples, dim, dtype, device = *samples.shape, x.dtype, x.device

@@ -27,9 +27,13 @@ def kmeans(x, num_clusters, num_iters = 10):
     means = samples[indices]

     for _ in range(num_iters):
-        diffs = rearrange(samples, 'n d -> n () d') - rearrange(means, 'c d -> () c d')
-        dists = (diffs ** 2).sum(dim = -1)
-        buckets = dists.argmin(dim = -1)
+        if use_cosine_sim:
+            dists = samples @ means.t()
+            buckets = dists.max(dim = -1).indices
+        else:
+            diffs = rearrange(samples, 'n d -> n () d') - rearrange(means, 'c d -> () c d')
+            dists = (diffs ** 2).sum(dim = -1)
+            buckets = dists.argmin(dim = -1)

         bins = torch.bincount(buckets, minlength = num_clusters)
         zero_mask = bins == 0
@@ -38,86 +42,186 @@ def kmeans(x, num_clusters, num_iters = 10):
         new_means = buckets.new_zeros(num_clusters, dim, dtype = dtype)
         new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d = dim), samples)
         new_means = new_means / bins[..., None]
+
+        if use_cosine_sim:
+            new_means = F.normalize(new_means, dim = -1)
+
         means = torch.where(zero_mask[..., None], means, new_means)

-    return rearrange(means, 'n d -> d n')
+    return means

-class VectorQuantize(nn.Module):
+# distance types
+
+class EuclideanCodebook(nn.Module):
     def __init__(
         self,
         dim,
         codebook_size,
-        decay = 0.8,
-        commitment = 1.,
-        eps = 1e-5,
-        n_embed = None,
         kmeans_init = False,
         kmeans_iters = 10,
-        codebook_dim = None
+        decay = 0.8,
+        eps = 1e-5
     ):
         super().__init__()
-        n_embed = default(n_embed, codebook_size)
-        self.n_embed = n_embed
-
-        codebook_dim = default(codebook_dim, dim)
-        requires_projection = codebook_dim != dim
-        self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
-        self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
-
         self.decay = decay
-        self.eps = eps
-        self.commitment = commitment
-
         init_fn = torch.randn if not kmeans_init else torch.zeros
-        embed = init_fn(codebook_dim, n_embed)
+        embed = init_fn(codebook_size, dim)

+        self.codebook_size = codebook_size
         self.kmeans_iters = kmeans_iters
+        self.eps = eps
+
         self.register_buffer('initted', torch.Tensor([not kmeans_init]))
-        self.register_buffer('cluster_size', torch.zeros(n_embed))
+        self.register_buffer('cluster_size', torch.zeros(codebook_size))
         self.register_buffer('embed', embed)
         self.register_buffer('embed_avg', embed.clone())

-    @property
-    def codebook(self):
-        return self.embed.transpose(0, 1)
-
     def init_embed_(self, data):
-        embed = kmeans(data, self.n_embed, self.kmeans_iters)
+        embed = kmeans(data, self.codebook_size, self.kmeans_iters)
         self.embed.data.copy_(embed)
         self.embed_avg.data.copy_(embed.clone())
         self.initted.data.copy_(torch.Tensor([True]))

-    def forward(self, input):
-        input = self.project_in(input)
-
+    def forward(self, x):
         if not self.initted:
-            self.init_embed_(input)
+            self.init_embed_(x)
+
+        shape, dtype = x.shape, x.dtype
+        flatten = rearrange(x, '... d -> (...) d')
+        embed = self.embed.t()

-        dtype = input.dtype
-        flatten = rearrange(input, '... d -> (...) d')
-        dist = (
+        dist = -(
             flatten.pow(2).sum(1, keepdim = True)
-            - 2 * flatten @ self.embed
-            + self.embed.pow(2).sum(0, keepdim = True)
+            - 2 * flatten @ embed
+            + embed.pow(2).sum(0, keepdim = True)
         )

-        _, embed_ind = (-dist).max(1)
-        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
-        embed_ind = embed_ind.view(*input.shape[:-1])
-
-        commit_loss = 0.
-        quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))
+        embed_ind = dist.max(dim = -1).indices
+        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(x.dtype)
+        embed_ind = embed_ind.view(*shape[:-1])
+        quantize = F.embedding(embed_ind, self.embed)

         if self.training:
             ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
-            embed_sum = flatten.transpose(0, 1) @ embed_onehot
-            ema_inplace(self.embed_avg, embed_sum, self.decay)
-            cluster_size = laplace_smoothing(self.cluster_size, self.n_embed, self.eps) * self.cluster_size.sum()
-            embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
+            embed_sum = flatten.t() @ embed_onehot
+            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
+            cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum()
+            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
             self.embed.data.copy_(embed_normalized)

-        commit_loss = F.mse_loss(quantize.detach(), input) * self.commitment
-        quantize = input + (quantize - input).detach()
+        return quantize, embed_ind
+
+class CosineSimCodebook(nn.Module):
+    def __init__(
+        self,
+        dim,
+        codebook_size,
+        kmeans_init = False,
+        kmeans_iters = 10,
+        decay = 0.8,
+        eps = 1e-5
+    ):
+        super().__init__()
+        self.decay = decay
+
+        if not kmeans_init:
+            embed = F.normalize(torch.randn(codebook_size, dim), dim = -1)
+        else:
+            embed = torch.zeros(codebook_size, dim)
+
+        self.codebook_size = codebook_size
+        self.kmeans_iters = kmeans_iters
+        self.eps = eps
+
+        self.register_buffer('initted', torch.Tensor([not kmeans_init]))
+        self.register_buffer('embed', embed)
+
+    def init_embed_(self, data):
+        embed = kmeans(data, self.codebook_size, self.kmeans_iters, use_cosine_sim = True)
+        self.embed.data.copy_(embed)
+        self.initted.data.copy_(torch.Tensor([True]))
+
+    def forward(self, x):
+        shape, dtype = x.shape, x.dtype
+        flatten = rearrange(x, '... d -> (...) d')
+        flatten = F.normalize(flatten, dim = -1)
+        embed = F.normalize(self.embed, dim = -1)
+
+        if not self.initted:
+            self.init_embed_(flatten)
+
+        dist = flatten @ embed.t()
+        embed_ind = dist.max(dim = -1).indices
+        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
+        embed_ind = embed_ind.view(*shape[:-1])
+
+        quantize = F.embedding(embed_ind, self.embed)
+
+        if self.training:
+            bins = embed_onehot.sum(0)
+            zero_mask = (bins == 0)
+            bins = bins.masked_fill(zero_mask, 1.)
+
+            embed_sum = flatten.t() @ embed_onehot
+            embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
+            embed_normalized = F.normalize(embed_normalized, dim = -1)
+            embed_normalized = torch.where(zero_mask[..., None], embed, embed_normalized)
+            ema_inplace(self.embed, embed_normalized, self.decay)
+
+        return quantize, embed_ind
+
+# main class
+
+class VectorQuantize(nn.Module):
+    def __init__(
+        self,
+        dim,
+        codebook_size,
+        n_embed = None,
+        codebook_dim = None,
+        decay = 0.8,
+        commitment = 1.,
+        eps = 1e-5,
+        kmeans_init = False,
+        kmeans_iters = 10,
+        use_cosine_sim = False
+    ):
+        super().__init__()
+        n_embed = default(n_embed, codebook_size)
+
+        codebook_dim = default(codebook_dim, dim)
+        requires_projection = codebook_dim != dim
+        self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
+        self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
+
+        self.eps = eps
+        self.commitment = commitment
+
+        klass = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook
+
+        self._codebook = klass(
+            dim = codebook_dim,
+            codebook_size = n_embed,
+            kmeans_init = kmeans_init,
+            kmeans_iters = kmeans_iters,
+            decay = decay,
+            eps = eps
+        )
+
+    @property
+    def codebook(self):
+        return self._codebook.codebook
+
+    def forward(self, x):
+        dtype = x.dtype
+        x = self.project_in(x)
+
+        quantize, embed_ind = self._codebook(x)
+
+        commit_loss = 0.
+        if self.training:
+            commit_loss = F.mse_loss(quantize.detach(), x) * self.commitment
+            quantize = x + (quantize - x).detach()

         quantize = self.project_out(quantize)
         return quantize, embed_ind, commit_loss
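
For reference, a minimal usage sketch of the refactored VectorQuantize module after this change. The constructor arguments and the (quantize, indices, commit_loss) return signature come from the code above; the import path and tensor shapes are assumptions for illustration only.

import torch
from vector_quantize_pytorch import VectorQuantize   # assumed import path

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    kmeans_init = True,        # initialize the codebook from the first batch via kmeans
    kmeans_iters = 10,
    use_cosine_sim = True      # select CosineSimCodebook instead of EuclideanCodebook
)

x = torch.randn(1, 1024, 256)              # (batch, seq, dim) -- illustrative shape
quantize, indices, commit_loss = vq(x)     # quantize: (1, 1024, 256), indices: (1, 1024)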