
Commit ec49a3d

add a self-contained file for the proposed BinaryMapper from Free Transformer https://arxiv.org/abs/2510.17558
1 parent d8e82a8 commit ec49a3d

File tree

4 files changed

+161 -1 lines changed


README.md

Lines changed: 13 additions & 0 deletions
@@ -791,3 +791,16 @@ assert loss.item() >= 0
    url = {https://arxiv.org/abs/2509.26469},
}
```

```bibtex
@misc{fleuret2025freetransformer,
    title         = {The Free Transformer},
    author        = {François Fleuret},
    year          = {2025},
    eprint        = {2510.17558},
    archivePrefix = {arXiv},
    primaryClass  = {cs.LG},
    url           = {https://arxiv.org/abs/2510.17558},
}
```

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
  [project]
  name = "vector-quantize-pytorch"
- version = "1.24.2"
+ version = "1.25.0"
  description = "Vector Quantization - Pytorch"
  authors = [
      { name = "Phil Wang", email = "[email protected]" }

vector_quantize_pytorch/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -10,4 +10,6 @@
from vector_quantize_pytorch.sim_vq import SimVQ
from vector_quantize_pytorch.residual_sim_vq import ResidualSimVQ

from vector_quantize_pytorch.binary_mapper import BinaryMapper

from vector_quantize_pytorch.utils import Sequential
vector_quantize_pytorch/binary_mapper.py

Lines changed: 145 additions & 0 deletions

@@ -0,0 +1,145 @@
from __future__ import annotations

# proposed in https://arxiv.org/abs/2510.17558 as a more stable alternative to VAE by François Fleuret

from math import log

import torch
from torch import nn, tensor, arange
import torch.nn.functional as F
from torch.nn import Module

from einops import einsum, pack, unpack

# constants

NAT = log(2)
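# NAT is one bit expressed in nats (ln 2); the entropies and kl terms below are computed in nats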

# helper functions

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

# tensor helpers

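# entropy in nats of independent Bernoulli variables given their logits, summed over the last dimension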
def binary_entropy(logits):
    prob = logits.sigmoid()
    not_prob = 1. - prob
    return -(prob * F.logsigmoid(logits) + not_prob * F.logsigmoid(-logits)).sum(dim = -1)

def pack_with_inverse(t, pattern):
    packed, ps = pack([t], pattern)

    def inverse(out, inv_pattern = None):
        inv_pattern = default(inv_pattern, pattern)
        unpacked, = unpack(out, ps, inv_pattern)
        return unpacked

    return packed, inverse

# binary mapper

class BinaryMapper(Module):
    def __init__(
        self,
        bits = 1,
        kl_loss_threshold = NAT # 1 bit
    ):
        super().__init__()

        self.bits = bits
        self.num_codes = 2 ** bits

        power_two = 2 ** arange(bits)
        codes = (arange(self.num_codes)[:, None].bitwise_and(power_two) != 0).byte().bool()
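        # codes[i] is the little-endian binary expansion of index i,
        # e.g. bits = 2 gives [[0, 0], [1, 0], [0, 1], [1, 1]] as booleans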

        self.register_buffer('power_two', power_two, persistent = False)
        self.register_buffer('codes', codes, persistent = False)

        # aux loss

        self.kl_loss_threshold = kl_loss_threshold
        self.register_buffer('zero', tensor(0.), persistent = False)

    def forward(
        self,
        logits,
        temperature = 1.,
        straight_through = None,
        calc_aux_loss = None,
        return_indices = False
    ):
        straight_through = default(straight_through, self.training)
        calc_aux_loss = default(calc_aux_loss, self.training)

        assert logits.shape[-1] == self.bits, f'logits must have a last dimension of {self.bits}'

        # allow for any number of leading dimensions

        logits, inverse_pack_lead_dims = pack_with_inverse(logits, '* bits')

        # temperature and prob for sampling

        prob_for_sample = (logits / temperature).sigmoid()

        # sampling

        sampled_bits = (torch.rand_like(logits) <= prob_for_sample).long()
        indices = (self.power_two * sampled_bits).sum(dim = -1)

        one_hot = F.one_hot(indices, self.num_codes).float()

        # maybe calculate aux loss

        aux_kl_loss = self.zero

        if calc_aux_loss:
            # kl divergence of the factorized Bernoulli posterior from the uniform prior,
            # which reduces to bits * log(2) minus the binary entropy of the logits

            kl_div = self.bits * NAT - binary_entropy(logits)
            aux_kl_loss = F.relu(kl_div - self.kl_loss_threshold).mean()

        # maybe straight through

        if straight_through:
            # get the soft G for the gradients and do a straight through
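            # soft_G[..., c] is the probability of code c under the per-bit Bernoullis:
            # the product of sigmoid(logit) over set bits and sigmoid(-logit) over unset bits,
            # computed in log space then exponentiated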
            soft_G = (
                einsum(F.logsigmoid(logits), self.codes.float(), '... bits, codes bits -> ... codes') +
                einsum(F.logsigmoid(-logits), (~self.codes).float(), '... bits, codes bits -> ... codes')
            ).exp()

            # straight through

            one_hot = one_hot + soft_G - soft_G.detach()

        # inverse pack

        one_hot = inverse_pack_lead_dims(one_hot)
        indices = inverse_pack_lead_dims(indices, '*')

        # returning

        if not return_indices:
            return one_hot, aux_kl_loss

        # also allow for returning indices, even though it can be derived from sparse output with an argmax

        return one_hot, indices, aux_kl_loss

# allow for quick copy paste

if __name__ == '__main__':

    binary_mapper = BinaryMapper(bits = 8)

    logits = torch.randn(3, 4, 8)

    sparse_one_hot, indices, aux_loss = binary_mapper(logits, return_indices = True)

    assert sparse_one_hot.shape == (3, 4, 2 ** 8)
    assert indices.shape == (3, 4)
    assert aux_loss.numel() == 1
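
For context, here is a minimal sketch, not part of this commit, of how the sampled one-hot latent and the auxiliary KL loss might be wired into a training step in the spirit of the Free Transformer; the encoder head, conditioning projection, dimensions, and loss weight below are illustrative assumptions rather than the paper's architecture.

```python
import torch
from torch import nn
from vector_quantize_pytorch import BinaryMapper

# hypothetical sizes, chosen only for the example
dim, bits = 512, 8

to_latent_logits = nn.Linear(dim, bits)          # encoder head producing per-bit logits
latent_to_cond = nn.Linear(2 ** bits, dim)       # projects the sampled one-hot back to model width
binary_mapper = BinaryMapper(bits = bits)

hiddens = torch.randn(2, 1024, dim)              # stand-in for hidden states from a transformer block

logits = to_latent_logits(hiddens)
one_hot, aux_kl_loss = binary_mapper(logits)     # hard one-hot forward, soft gradients via straight through

conditioned = hiddens + latent_to_cond(one_hot)  # condition the rest of the decoder on the latent

# placeholder reconstruction loss; the auxiliary kl is added with an assumed small weight
recon_loss = conditioned.pow(2).mean()
loss = recon_loss + 0.1 * aux_kl_loss
loss.backward()
```

With the thresholded KL, the latent is only penalized once it deviates from the uniform prior by more than kl_loss_threshold nats, while the straight-through estimator keeps the decoder input discrete and still lets gradients reach the logits.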
