Skip to content

Commit

Permalink
update init
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 11, 2024
1 parent 7959292 commit 007209d
Showing 1 changed file with 2 additions and 39 deletions.
41 changes: 2 additions & 39 deletions vector_quantize_pytorch/sim_vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,6 @@ def inverse(out, inv_pattern = None):

return packed, inverse

def l2norm(t, dim = -1):
return F.normalize(t, dim = dim)

def safe_div(num, den, eps = 1e-6):
return num / den.clamp(min = eps)

def efficient_rotation_trick_transform(u, q, e):
"""
4.2 in https://arxiv.org/abs/2410.06424
"""
e = rearrange(e, 'b d -> b 1 d')
w = l2norm(u + q, dim = 1).detach()

return (
e -
2 * (e @ rearrange(w, 'b d -> b d 1') @ rearrange(w, 'b d -> b 1 d')) +
2 * (e @ rearrange(u, 'b d -> b d 1').detach() @ rearrange(q, 'b d -> b 1 d').detach())
)

# class

class SimVQ(Module):
Expand All @@ -61,7 +42,7 @@ def __init__(
super().__init__()
self.accept_image_fmap = accept_image_fmap

codebook = torch.randn(codebook_size, dim)
codebook = torch.randn(codebook_size, dim) * (dim ** -0.5)
codebook = init_fn(codebook)

# the codebook is actually implicit from a linear layer from frozen gaussian or uniform
Expand Down Expand Up @@ -89,25 +70,7 @@ def forward(

# commit loss

commit_loss = (F.pairwise_distance(x, quantized.detach()) ** 2).mean()

# straight through

x, inverse = pack_one(x, '* d')
quantized, _ = pack_one(quantized, '* d')

norm_x = x.norm(dim = -1, keepdim = True)
norm_quantize = quantized.norm(dim = -1, keepdim = True)

rot_quantize = efficient_rotation_trick_transform(
safe_div(x, norm_x),
safe_div(quantized, norm_quantize),
x
).squeeze()

quantized = rot_quantize * safe_div(norm_quantize, norm_x).detach()

x, quantized = inverse(x), inverse(quantized)
commit_loss = (F.pairwise_distance(x, quantized) ** 2).mean()

# quantized = (quantized - x).detach() + x

Expand Down

0 comments on commit 007209d

Please sign in to comment.