Skip to content

Commit

Permalink
Merge pull request lucidrains#186 from lucasnewman/fsq-noise-dropout
Browse files Browse the repository at this point in the history
FSQ: Use element-wise selection for noise dropout
  • Loading branch information
lucidrains authored Jan 10, 2025
2 parents fa2211d + 708f3c9 commit 4f0fc17
Showing 1 changed file with 32 additions and 16 deletions.
48 changes: 32 additions & 16 deletions vector_quantize_pytorch/finite_scalar_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
return_indices = True,
force_quantization_f32 = True,
preserve_symmetry: bool = False,
noise_approx_prob = 0.0,
noise_dropout = 0.0,
):
super().__init__()

Expand All @@ -79,7 +79,7 @@ def __init__(
self.scale = scale

self.preserve_symmetry = preserve_symmetry
self.noise_approx_prob = noise_approx_prob
self.noise_dropout = noise_dropout

codebook_dim = len(levels)
self.codebook_dim = codebook_dim
Expand Down Expand Up @@ -129,24 +129,40 @@ def symmetry_preserving_bound(self, z):
bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.0) + 0.5
return scale * bracket - 1.0

def noise_approx_bound(self, z):
    """
    Approximates quantization with additive noise: Q_L(x) ~= tanh(x) + U{-1,1} / (L-1)
    """
    # one quantization step per level count (L - 1 intervals across [-1, 1])
    step = self._levels - 1
    perturbation = torch.empty_like(z).uniform_(-1, 1)
    return torch.tanh(z) + perturbation / step

def quantize(self, z, preserve_symmetry = False):
    """ Quantizes z, returns quantized zhat, same shape as z.

    During training, two independent per-sample bernoulli masks (probability
    `self.noise_dropout` each) regularize the quantization element-wise:
    the first passes the raw unquantized values straight through, the second
    replaces them with the unquantized values plus uniform noise of up to
    half a quantization step.

    Works for inputs of any rank >= 1; the masks broadcast over all
    non-batch dimensions (the previous version hard-coded a 4-d mask shape).
    """
    half_width = self._levels // 2

    # round to the nearest level; round_ste keeps gradients flowing
    # straight through the rounding op
    bound_fn = self.symmetry_preserving_bound if preserve_symmetry else self.bound
    quantized = round_ste(bound_fn(z)) / half_width

    # NOTE: early-out when no noise dropout is configured — the masks below
    # would be all-False anyway; this skips the (wasted) RNG draws, which
    # slightly changes RNG-stream consumption versus always drawing
    if not self.training or self.noise_dropout == 0.:
        return quantized

    unquantized = z

    # per-sample mask shape, broadcastable over every non-batch dim
    mask_shape = [z.shape[0], *((1,) * (z.ndim - 1))]

    # determine where to pass the raw unquantized values through, elementwise
    quantize_mask = torch.bernoulli(
        torch.full(mask_shape, self.noise_dropout, device = z.device)
    ).bool().expand_as(z)

    quantized = torch.where(quantize_mask, unquantized, quantized)

    # determine where to add a random offset (up to half a step), elementwise
    offset_mask = torch.bernoulli(
        torch.full(mask_shape, self.noise_dropout, device = z.device)
    ).bool().expand_as(z)

    offset = (torch.rand_like(z) - 0.5) / half_width
    quantized = torch.where(offset_mask, unquantized + offset, quantized)

    return quantized

def _scale_and_shift(self, zhat_normalized):
half_width = self._levels // 2
Expand Down

0 comments on commit 4f0fc17

Please sign in to comment.