@@ -46,7 +46,8 @@ class BinaryMapper(Module):
4646 def __init__ (
4747 self ,
4848 bits = 1 ,
49- kl_loss_threshold = NAT # 1 bit
49+ kl_loss_threshold = NAT , # 1 bit
50+ deterministic_on_eval = False
5051 ):
5152 super ().__init__ ()
5253
@@ -64,14 +65,21 @@ def __init__(
6465 self .kl_loss_threshold = kl_loss_threshold
6566 self .register_buffer ('zero' , tensor (0. ), persistent = False )
6667
68+ # eval behavior
69+
70+ self .deterministic_on_eval = deterministic_on_eval
71+
6772 def forward (
6873 self ,
6974 logits ,
7075 temperature = 1. ,
7176 straight_through = None ,
7277 calc_aux_loss = None ,
73- return_indices = False
78+ deterministic = None ,
79+ return_indices = False ,
7480 ):
81+ deterministic = default (deterministic , self .deterministic_on_eval and not self .training )
82+
7583 straight_through = default (straight_through , self .training )
7684 calc_aux_loss = default (calc_aux_loss , self .training )
7785
@@ -87,7 +95,10 @@ def forward(
8795
8896 # sampling
8997
90- sampled_bits = (torch .rand_like (logits ) <= prob_for_sample ).long ()
98+ compare_target = torch .rand_like (logits ) if not deterministic else 0.5
99+
100+ sampled_bits = (compare_target <= prob_for_sample ).long ()
101+
91102 indices = (self .power_two * sampled_bits ).sum (dim = - 1 )
92103
93104 one_hot = F .one_hot (indices , self .num_codes ).float ()
@@ -143,3 +154,8 @@ def forward(
143154 assert sparse_one_hot .shape == (3 , 4 , 2 ** 8 )
144155 assert indices .shape == (3 , 4 )
145156 assert aux_loss .numel () == 1
157+
158+ binary_mapper .eval ()
159+ sparse_one_hot1 , _ = binary_mapper (logits , deterministic = True )
160+ sparse_one_hot2 , _ = binary_mapper (logits , deterministic = True )
161+ assert torch .allclose (sparse_one_hot1 , sparse_one_hot2 )
0 commit comments