We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent eef94ea commit 976c3f2Copy full SHA for 976c3f2
pyproject.toml
@@ -1,6 +1,6 @@
1
[project]
2
name = "vector-quantize-pytorch"
3
-version = "1.25.1"
+version = "1.25.2"
4
description = "Vector Quantization - Pytorch"
5
authors = [
6
{ name = "Phil Wang", email = "[email protected]" }
vector_quantize_pytorch/binary_mapper.py
@@ -95,9 +95,10 @@ def forward(
95
96
# sampling
97
98
- compare_target = torch.rand_like(logits) if not deterministic else 0.5
99
-
100
- sampled_bits = (compare_target <= prob_for_sample).long()
+ if not deterministic:
+ sampled_bits = prob_for_sample.bernoulli().long()
+ else:
101
+ sampled_bits = (prob_for_sample > 0.5).long()
102
103
indices = (self.power_two * sampled_bits).sum(dim = -1)
104
0 commit comments