
Commit

do not do straight through nor rotation trick if input does not require grad, to make Genie2 cleaner
lucidrains committed Jan 7, 2025
1 parent c243e83 commit fa2211d
Showing 3 changed files with 17 additions and 8 deletions.
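
For context, the change targets calls where the input tensor carries no gradient, as when feeding frozen features through a quantizer. A minimal sketch of that scenario (the `dim`/`codebook_size` values are illustrative, taken from typical usage rather than from this commit):

```python
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512
)

# an input with no gradient attached, e.g. frozen encoder features in a
# Genie2-style pipeline; requires_grad is False by default for new tensors
x = torch.randn(1, 1024, 256)

# after this commit, neither the straight-through estimator nor the rotation
# trick runs here, so `quantized` is simply the detached codebook lookup
quantized, indices, commit_loss = vq(x)
```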
2 changes: 1 addition & 1 deletion pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.21.0"
+version = "1.21.1"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
8 changes: 7 additions & 1 deletion tests/test_readme.py

@@ -6,9 +6,11 @@ def exists(v):

 @pytest.mark.parametrize('use_cosine_sim', (True, False))
 @pytest.mark.parametrize('rotation_trick', (True, False))
+@pytest.mark.parametrize('input_requires_grad', (True, False))
 def test_vq(
     use_cosine_sim,
-    rotation_trick
+    rotation_trick,
+    input_requires_grad
 ):
     from vector_quantize_pytorch import VectorQuantize

@@ -22,6 +24,10 @@ def test_vq(
     )

     x = torch.randn(1, 1024, 256)
+
+    if input_requires_grad:
+        x.requires_grad_()
+
     quantized, indices, commit_loss = vq(x)

 def test_vq_eval():
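
A sketch of what the new parametrization exercises, with an extra assertion not present in the test itself (it assumes the default non-learnable EMA codebook, so the codebook lookup carries no gradient of its own):

```python
import torch
from vector_quantize_pytorch import VectorQuantize

for input_requires_grad in (True, False):
    vq = VectorQuantize(dim = 256, codebook_size = 512)

    x = torch.randn(1, 1024, 256)

    if input_requires_grad:
        x.requires_grad_()

    quantized, indices, commit_loss = vq(x)

    # gradients should flow back through `quantized` only when the
    # input required them in the first place
    assert quantized.requires_grad == input_requires_grad
```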
15 changes: 9 additions & 6 deletions vector_quantize_pytorch/vector_quantize_pytorch.py

@@ -1023,7 +1023,7 @@ def forward(
         return_loss_breakdown = False,
         codebook_transform_fn: Callable | None = None
     ):
-        orig_input = x
+        orig_input, input_requires_grad = x, x.requires_grad

         # handle masking, either passed in as `mask` or `lens`

@@ -1117,11 +1117,14 @@ def forward(

         commit_quantize = maybe_detach(quantize)

-        if self.rotation_trick:
-            quantize = rotate_to(x, quantize)
-        else:
-            # standard STE to get gradients through VQ layer.
-            quantize = x + (quantize - x).detach()
+        # spare rotation trick calculation if inputs do not need gradients
+
+        if input_requires_grad:
+            if self.rotation_trick:
+                quantize = rotate_to(x, quantize)
+            else:
+                # standard STE to get gradients through VQ layer.
+                quantize = x + (quantize - x).detach()

         if self.sync_update_v > 0.:
             # (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
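
The skip is safe because both branches exist purely to shape the backward pass: the straight-through estimator never changes the forward value, and the rotation trick likewise leaves the forward output at the codebook vector and only alters gradient flow. A standalone sketch of the straight-through identity (plain PyTorch, not this library's code):

```python
import torch

x = torch.randn(4, requires_grad = True)
quantize = torch.randn(4)  # stand-in for a detached codebook lookup

# straight-through estimator: forward value equals `quantize`, but the
# gradient treats the op as the identity on `x`
ste = x + (quantize - x).detach()

assert torch.allclose(ste, quantize)

ste.sum().backward()
assert torch.allclose(x.grad, torch.ones(4))

# if `x` did not require grad, the same expression would just reproduce
# `quantize` with no autograd graph attached - pure overhead, hence the guard
```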
