Skip to content

Commit cbb5339

Browse files
committed
code formatting
1 parent b1f5d8e commit cbb5339

File tree

1 file changed

+22
-8
lines changed

1 file changed

+22
-8
lines changed

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False):
3737
if use_cosine_sim:
3838
dists = samples @ means.t()
3939
else:
40-
diffs = rearrange(samples, 'n d -> n () d') - rearrange(means, 'c d -> () c d')
40+
diffs = rearrange(samples, 'n d -> n () d') \
41+
- rearrange(means, 'c d -> () c d')
4142
dists = -(diffs ** 2).sum(dim = -1)
4243

4344
buckets = dists.max(dim = -1).indices
@@ -89,7 +90,11 @@ def init_embed_(self, data):
8990
self.initted.data.copy_(torch.Tensor([True]))
9091

9192
def replace(self, samples, mask):
92-
modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
93+
modified_codebook = torch.where(
94+
mask[..., None],
95+
sample_vectors(samples, self.codebook_size),
96+
self.embed
97+
)
9398
self.embed.data.copy_(modified_codebook)
9499

95100
def forward(self, x):
@@ -147,13 +152,18 @@ def __init__(
147152
self.register_buffer('embed', embed)
148153

149154
def init_embed_(self, data):
150-
embed = kmeans(data, self.codebook_size, self.kmeans_iters, use_cosine_sim = True)
155+
embed = kmeans(data, self.codebook_size, self.kmeans_iters,
156+
use_cosine_sim = True)
151157
self.embed.data.copy_(embed)
152158
self.initted.data.copy_(torch.Tensor([True]))
153159

154160
def replace(self, samples, mask):
155161
samples = l2norm(samples)
156-
modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
162+
modified_codebook = torch.where(
163+
mask[..., None],
164+
sample_vectors(samples, self.codebook_size),
165+
self.embed
166+
)
157167
self.embed.data.copy_(modified_codebook)
158168

159169
def forward(self, x):
@@ -180,7 +190,8 @@ def forward(self, x):
180190
embed_sum = flatten.t() @ embed_onehot
181191
embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
182192
embed_normalized = l2norm(embed_normalized)
183-
embed_normalized = torch.where(zero_mask[..., None], embed, embed_normalized)
193+
embed_normalized = torch.where(zero_mask[..., None], embed,
194+
embed_normalized)
184195
ema_inplace(self.embed, embed_normalized, self.decay)
185196

186197
return quantize, embed_ind
@@ -207,13 +218,16 @@ def __init__(
207218

208219
codebook_dim = default(codebook_dim, dim)
209220
requires_projection = codebook_dim != dim
210-
self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
211-
self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
221+
self.project_in = nn.Linear(dim, codebook_dim) if requires_projection \
222+
else nn.Identity()
223+
self.project_out = nn.Linear(codebook_dim, dim) if requires_projection \
224+
else nn.Identity()
212225

213226
self.eps = eps
214227
self.commitment = commitment
215228

216-
klass = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook
229+
codebook_class = EuclideanCodebook if not use_cosine_sim \
230+
else CosineSimCodebook
217231

218232
self._codebook = klass(
219233
dim = codebook_dim,

0 commit comments

Comments (0)