Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 11, 2024
1 parent 80f4e84 commit 949b0ba
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
4 changes: 2 additions & 2 deletions vector_quantize_pytorch/sim_vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from einx import get_at
from einops import rearrange, pack, unpack

from vector_quantize_pytorch.vector_quantize_pytorch import rotate_from_to
from vector_quantize_pytorch.vector_quantize_pytorch import rotate_to

# helper functions

Expand Down Expand Up @@ -94,7 +94,7 @@ def forward(

if self.rotation_trick:
# rotation trick from @cfifty
quantized = rotate_from_to(quantized, x)
quantized = rotate_to(x, quantized)
else:

commit_loss = (
Expand Down
18 changes: 9 additions & 9 deletions vector_quantize_pytorch/vector_quantize_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,21 +250,21 @@ def efficient_rotation_trick_transform(u, q, e):
2 * (e @ rearrange(u, 'b d -> b d 1').detach() @ rearrange(q, 'b d -> b 1 d').detach())
)

def rotate_from_to(src, tgt):
def rotate_to(src, tgt):
# rotation trick STE (https://arxiv.org/abs/2410.06424) to get gradients through VQ layer.
tgt, inverse = pack_one(tgt, '* d')
src, _ = pack_one(src, '* d')
src, inverse = pack_one(src, '* d')
tgt, _ = pack_one(tgt, '* d')

norm_tgt = tgt.norm(dim = -1, keepdim = True)
norm_src = src.norm(dim = -1, keepdim = True)
norm_tgt = tgt.norm(dim = -1, keepdim = True)

rotated_src = efficient_rotation_trick_transform(
safe_div(tgt, norm_tgt),
rotated_tgt = efficient_rotation_trick_transform(
safe_div(src, norm_src),
tgt
safe_div(tgt, norm_tgt),
src
).squeeze()

rotated = rotated_src * safe_div(norm_src, norm_tgt).detach()
rotated = rotated_tgt * safe_div(norm_tgt, norm_src).detach()

return inverse(rotated)

Expand Down Expand Up @@ -1118,7 +1118,7 @@ def forward(
commit_quantize = maybe_detach(quantize)

if self.rotation_trick:
quantize = rotate_from_to(quantize, x)
quantize = rotate_to(x, quantize)
else:
# standard STE to get gradients through VQ layer.
quantize = x + (quantize - x).detach()
Expand Down

0 comments on commit 949b0ba

Please sign in to comment.