Skip to content

Commit

Permalink
add a test for get_output_from_indices
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 18, 2024
1 parent 7444ddc commit ea3b16d
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions tests/test_readme.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,21 @@ def test_vq():
x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)

def test_vq_eval():
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
dim = 256,
codebook_size = 512, # codebook size
decay = 0.8, # the exponential moving average decay, lower means the dictionary will change faster
commitment_weight = 1. # the weight on the commitment loss
)

x = torch.randn(1, 1024, 256)

vq.eval()
quantized, indices, commit_loss = vq(x)
assert torch.allclose(quantized, vq.get_output_from_indices(indices))

def test_residual_vq():
import torch
Expand Down

0 comments on commit ea3b16d

Please sign in to comment.