1111import torch
1212import torch .nn as nn
1313from torch .nn import Module
14- from torch import Tensor , int32
14+ from torch import tensor , Tensor , int32
1515from torch .amp import autocast
1616
1717import einx
@@ -47,11 +47,12 @@ def unpack_one(t, ps, pattern):
4747# tensor helpers
4848
def round_ste(z):
    """ round with straight through gradients.

    The forward value is `z.round()`. The rounding residual is detached, so
    the backward pass sees the identity and gradients flow to `z` unchanged.
    """
    zhat = z.round()
    # z + (zhat - z) evaluates to zhat, but only the bare `z` term carries grad
    return z + (zhat - z).detach()
5353
def floor_ste(z):
    """ floor with straight through gradients.

    The forward value is `z.floor()`. The flooring residual is detached, so
    the backward pass sees the identity and gradients flow to `z` unchanged.
    """
    zhat = z.floor()
    # z + (zhat - z) evaluates to zhat, but only the bare `z` term carries grad
    return z + (zhat - z).detach()
5758
@@ -60,26 +61,26 @@ def floor_ste(z):
6061class FSQ (Module ):
6162 def __init__ (
6263 self ,
63- levels : List [int ],
64+ levels : list [int ],
6465 dim : int | None = None ,
6566 num_codebooks = 1 ,
6667 keep_num_codebooks_dim : bool | None = None ,
6768 scale : float | None = None ,
68- allowed_dtypes : Tuple [torch .dtype , ...] = (torch .float32 , torch .float64 ),
69- channel_first : bool = False ,
70- projection_has_bias : bool = True ,
69+ allowed_dtypes : tuple [torch .dtype , ...] = (torch .float32 , torch .float64 ),
70+ channel_first = False ,
71+ projection_has_bias = True ,
7172 return_indices = True ,
7273 force_quantization_f32 = True ,
73- preserve_symmetry : bool = False ,
74- noise_dropout = 0.0 ,
74+ preserve_symmetry = False ,
75+ noise_dropout = 0. ,
7576 ):
7677 super ().__init__ ()
7778
78- _levels = torch . tensor (levels , dtype = int32 )
79- self .register_buffer (" _levels" , _levels , persistent = False )
79+ _levels = tensor (levels , dtype = int32 )
80+ self .register_buffer (' _levels' , _levels , persistent = False )
8081
81- _basis = torch .cumprod (torch . tensor ([1 ] + levels [:- 1 ]), dim = 0 , dtype = int32 )
82- self .register_buffer (" _basis" , _basis , persistent = False )
82+ _basis = torch .cumprod (tensor ([1 ] + levels [:- 1 ]), dim = 0 , dtype = int32 )
83+ self .register_buffer (' _basis' , _basis , persistent = False )
8384
8485 self .scale = scale
8586
@@ -108,56 +109,65 @@ def __init__(
108109 self .has_projections = has_projections
109110
110111 self .return_indices = return_indices
112+
111113 if return_indices :
112114 self .codebook_size = self ._levels .prod ().item ()
113115 implicit_codebook = self ._indices_to_codes (torch .arange (self .codebook_size ))
114- self .register_buffer (" implicit_codebook" , implicit_codebook , persistent = False )
116+ self .register_buffer (' implicit_codebook' , implicit_codebook , persistent = False )
115117
116118 self .allowed_dtypes = allowed_dtypes
117119 self .force_quantization_f32 = force_quantization_f32
118120
def bound(self, z, eps=1e-3):
    """ Bound `z`, an array of shape (..., d), then round and normalize it.

    `z` is squashed with a shifted tanh scaled by `(levels - 1) * (1 + eps) / 2`.
    Dimensions with an even number of levels get a 0.5 offset. The bounded
    value is rounded with straight-through gradients (`round_ste`) and divided
    by `levels // 2`.
    """
    half_l = (self._levels - 1) * (1 + eps) / 2
    # even level counts are shifted by half a step
    offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
    shift = (offset / half_l).atanh()
    bounded_z = (z + shift).tanh() * half_l - offset

    # rounding and normalization happen here, so callers receive the final codes
    half_width = self._levels // 2
    return round_ste(bounded_z) / half_width
125129
126130 # symmetry-preserving and noise-approximated quantization, section 3.2 in https://arxiv.org/abs/2411.19842
127131
def symmetry_preserving_bound(self, z):
    """ QL(x) = 2 / (L - 1) * [(L - 1) * (tanh(x) + 1) / 2 + 0.5] - 1

    The square bracket is a floor, applied here with straight-through
    gradients (`floor_ste`). `L` is the per-dimension level count `self._levels`.
    """
    levels_minus_1 = (self._levels - 1)
    scale = 2. / levels_minus_1
    # (tanh(z) + 1) / 2 lies in (0, 1); scaling by L - 1 and adding 0.5 before the floor rounds to nearest
    bracket = (levels_minus_1 * (z.tanh() + 1) / 2.) + 0.5
    bracket = floor_ste(bracket)
    return scale * bracket - 1.
137139
def quantize(self, z):
    """ Quantizes z, returns quantized zhat, same shape as z.

    Uses `symmetry_preserving_bound` when `self.preserve_symmetry` is set and
    `bound` otherwise. Outside of training, or when `self.noise_dropout` is 0,
    the bounded value is returned as is. Otherwise each element is, with
    probability `noise_dropout`, perturbed by a uniform offset in [-0.5, 0.5).
    """
    noise_dropout = self.noise_dropout
    bound_fn = self.symmetry_preserving_bound if self.preserve_symmetry else self.bound

    bounded_z = bound_fn(z)

    # determine where to add a random offset elementwise
    # if using noise dropout

    if not self.training or noise_dropout == 0.:
        return bounded_z

    offset_mask = torch.bernoulli(torch.full_like(bounded_z, noise_dropout)).bool()
    offset = torch.rand_like(bounded_z) - 0.5
    bounded_z = torch.where(offset_mask, bounded_z + offset, bounded_z)

    return bounded_z
155159
def _scale_and_shift(self, zhat_normalized):
    """ Maps normalized codes to non-negative per-dimension level positions.

    With `preserve_symmetry`, codes are shifted by 1 and divided by the step
    `2 / (levels - 1)`. Otherwise they are scaled and shifted by
    `levels // 2`.
    """
    if self.preserve_symmetry:
        return (zhat_normalized + 1.) / (2. / (self._levels - 1))

    half_width = self._levels // 2
    return (zhat_normalized * half_width) + half_width
159166
def _scale_and_shift_inverse(self, zhat):
    """ Maps per-dimension level positions back to normalized codes.

    With `preserve_symmetry`, positions are multiplied by the step
    `2 / (levels - 1)` and shifted by -1. Otherwise they are centered and
    scaled by `levels // 2`.
    """
    if self.preserve_symmetry:
        return zhat * (2. / (self._levels - 1)) - 1.

    half_width = self._levels // 2
    return (zhat - half_width) / half_width
163173
@@ -166,18 +176,18 @@ def _indices_to_codes(self, indices):
166176 codes = self ._scale_and_shift_inverse (level_indices )
167177 return codes
168178
169- def codes_to_indices (self , zhat ):
170- """ Converts a `code` to an index in the codebook. """
171- assert zhat .shape [- 1 ] == self .codebook_dim
172- zhat = self ._scale_and_shift (zhat )
173- return (zhat * self ._basis ).sum (dim = - 1 ).to (int32 )
174-
def indices_to_level_indices(self, indices):
    """ Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings

    A trailing axis of size 1 is appended to `indices`, which is then divided
    by `self._basis` and reduced modulo `self._levels`.
    """
    indices = rearrange(indices, '... -> ... 1')
    codes_non_centered = (indices // self._basis) % self._levels
    return codes_non_centered
180184
def codes_to_indices(self, zhat):
    """ Converts a `code` to an index in the codebook.

    The last dimension of `zhat` must equal `self.codebook_dim`. The code is
    passed through `_scale_and_shift`, weighted by `self._basis` and summed
    over the last dimension. The result is an int32 tensor.
    """
    assert zhat.shape[-1] == self.codebook_dim
    zhat = self._scale_and_shift(zhat)
    # round before the cast: `.to(int32)` truncates, so a float sum slightly below an integer would land one index low
    return (zhat * self._basis).sum(dim=-1).round().to(int32)
190+
181191 def indices_to_codes (self , indices ):
182192 """ Inverse of `codes_to_indices`. """
183193 assert exists (indices )
0 commit comments