Skip to content

Commit 976c3f2

Browse files
committed
cleanup
1 parent eef94ea commit 976c3f2

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.25.1"
3+
version = "1.25.2"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

vector_quantize_pytorch/binary_mapper.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,10 @@ def forward(
9595

9696
# sampling
9797

98-
compare_target = torch.rand_like(logits) if not deterministic else 0.5
99-
100-
sampled_bits = (compare_target <= prob_for_sample).long()
98+
if not deterministic:
99+
sampled_bits = prob_for_sample.bernoulli().long()
100+
else:
101+
sampled_bits = (prob_for_sample > 0.5).long()
101102

102103
indices = (self.power_two * sampled_bits).sum(dim = -1)
103104

0 commit comments

Comments
 (0)