Skip to content

Commit eef94ea

Browse files
committed
add deterministic on eval option for binary mapper
1 parent ec49a3d commit eef94ea

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.25.0"
3+
version = "1.25.1"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

vector_quantize_pytorch/binary_mapper.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ class BinaryMapper(Module):
4646
def __init__(
4747
self,
4848
bits = 1,
49-
kl_loss_threshold = NAT # 1 bit
49+
kl_loss_threshold = NAT, # 1 bit
50+
deterministic_on_eval = False
5051
):
5152
super().__init__()
5253

@@ -64,14 +65,21 @@ def __init__(
6465
self.kl_loss_threshold = kl_loss_threshold
6566
self.register_buffer('zero', tensor(0.), persistent = False)
6667

68+
# eval behavior
69+
70+
self.deterministic_on_eval = deterministic_on_eval
71+
6772
def forward(
6873
self,
6974
logits,
7075
temperature = 1.,
7176
straight_through = None,
7277
calc_aux_loss = None,
73-
return_indices = False
78+
deterministic = None,
79+
return_indices = False,
7480
):
81+
deterministic = default(deterministic, self.deterministic_on_eval and not self.training)
82+
7583
straight_through = default(straight_through, self.training)
7684
calc_aux_loss = default(calc_aux_loss, self.training)
7785

@@ -87,7 +95,10 @@ def forward(
8795

8896
# sampling
8997

90-
sampled_bits = (torch.rand_like(logits) <= prob_for_sample).long()
98+
compare_target = torch.rand_like(logits) if not deterministic else 0.5
99+
100+
sampled_bits = (compare_target <= prob_for_sample).long()
101+
91102
indices = (self.power_two * sampled_bits).sum(dim = -1)
92103

93104
one_hot = F.one_hot(indices, self.num_codes).float()
@@ -143,3 +154,8 @@ def forward(
143154
assert sparse_one_hot.shape == (3, 4, 2 ** 8)
144155
assert indices.shape == (3, 4)
145156
assert aux_loss.numel() == 1
157+
158+
binary_mapper.eval()
159+
sparse_one_hot1, _ = binary_mapper(logits, deterministic = True)
160+
sparse_one_hot2, _ = binary_mapper(logits, deterministic = True)
161+
assert torch.allclose(sparse_one_hot1, sparse_one_hot2)

0 commit comments

Comments (0)