
Commit 5141c19

rename
1 parent f49a311 commit 5141c19

File tree

2 files changed: 50 additions & 0 deletions
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@

from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize
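If this line sits in the package's __init__.py, it re-exports VectorQuantize so downstream code can write from vector_quantize_pytorch import VectorQuantize instead of spelling out the full module path.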
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
import torch
from torch import nn
import torch.nn.functional as F


def ema_inplace(moving_avg, new, decay):
    # in-place exponential moving average update
    moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay))

def laplace_smoothing(x, n_categories, eps=1e-5):
    # additive (Laplace) smoothing so no cluster count is exactly zero
    return (x + eps) / (x.sum() + n_categories * eps)

class VectorQuantize(nn.Module):
    def __init__(self, dim, n_embed, decay=0.8, commitment=1., eps=1e-5):
        super().__init__()

        self.dim = dim
        self.n_embed = n_embed
        self.decay = decay
        self.eps = eps
        self.commitment = commitment

        embed = torch.randn(dim, n_embed)
        # buffers, not parameters: the codebook is updated by EMA, not by gradients
        self.register_buffer('embed', embed)
        self.register_buffer('cluster_size', torch.zeros(n_embed))
        self.register_buffer('embed_avg', embed.clone())

    def forward(self, input):
        dtype = input.dtype
        flatten = input.reshape(-1, self.dim)
        # squared euclidean distance from each input vector to each codebook entry
        dist = (
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ self.embed
            + self.embed.pow(2).sum(0, keepdim=True)
        )
        # assign each input vector to its nearest codebook entry
        _, embed_ind = (-dist).max(1)
        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
        embed_ind = embed_ind.view(*input.shape[:-1])
        quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))

        if self.training:
            # EMA updates of per-code cluster sizes and summed assigned vectors
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = flatten.transpose(0, 1) @ embed_onehot
            ema_inplace(self.embed_avg, embed_sum, self.decay)
            # normalize the running sums by smoothed cluster sizes to get the new codebook
            cluster_size = laplace_smoothing(self.cluster_size, self.n_embed, self.eps) * self.cluster_size.sum()
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
            self.embed.data.copy_(embed_normalized)

        # commitment loss pulls encoder outputs toward their assigned codes
        loss = F.mse_loss(quantize.detach(), input) * self.commitment
        # straight-through estimator: quantized values forward, identity gradient backward
        quantize = input + (quantize - input).detach()
        return quantize, loss
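For context, a minimal usage sketch (not part of this commit; the shapes below are assumptions based on the forward pass, which flattens all leading dimensions and quantizes over the last one, so the last dimension must equal dim):

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,      # size of each vector to be quantized
    n_embed = 512,  # number of codebook entries
)

x = torch.randn(1, 1024, 256)  # any leading shape works; last dim must equal dim
quantize, loss = vq(x)         # quantize: (1, 1024, 256), loss: scalar commitment loss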
