add .indices_to_codes for SimVQ
lucidrains committed Nov 11, 2024
1 parent 3bb00f5 commit a0e8f2c
Showing 3 changed files with 34 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.20.2"
+version = "1.20.3"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
14 changes: 14 additions & 0 deletions tests/test_readme.py
@@ -362,3 +362,17 @@ def test_latent_q():

     assert image_feats.shape == quantized.shape
     assert (quantized == quantizer.indices_to_codes(indices)).all()
+
+def test_sim_vq():
+    from vector_quantize_pytorch import SimVQ
+
+    sim_vq = SimVQ(
+        dim = 512,
+        codebook_size = 1024,
+    )
+
+    x = torch.randn(1, 1024, 512)
+    quantized, indices, commit_loss = sim_vq(x)
+
+    assert x.shape == quantized.shape
+    assert torch.allclose(quantized, sim_vq.indices_to_codes(indices), atol = 1e-6)
21 changes: 19 additions & 2 deletions vector_quantize_pytorch/sim_vq.py
@@ -58,7 +58,7 @@ def __init__(

         self.codebook_to_codes = codebook_transform

-        self.register_buffer('codebook', codebook)
+        self.register_buffer('frozen_codebook', codebook)


         # whether to use rotation trick from Fifty et al.
@@ -70,6 +70,23 @@ def __init__(

         self.input_to_quantize_commit_loss_weight = input_to_quantize_commit_loss_weight

+    @property
+    def codebook(self):
+        return self.codebook_to_codes(self.frozen_codebook)
+
+    def indices_to_codes(
+        self,
+        indices
+    ):
+        implicit_codebook = self.codebook
+
+        quantized = get_at('[c] d, b ... -> b ... d', implicit_codebook, indices)
+
+        if self.accept_image_fmap:
+            quantized = rearrange(quantized, 'b ... d -> b d ...')
+
+        return quantized
+
     def forward(
         self,
         x
@@ -78,7 +95,7 @@ def forward(
             x = rearrange(x, 'b d h w -> b h w d')
         x, inverse_pack = pack_one(x, 'b * d')

-        implicit_codebook = self.codebook_to_codes(self.codebook)
+        implicit_codebook = self.codebook

         with torch.no_grad():
             dist = torch.cdist(x, implicit_codebook)
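A minimal usage sketch of the new API, assembled from the test and diff above (it is not part of the commit itself). SimVQ keeps a frozen codebook buffer and learns only codebook_transform on top of it, so the effective codebook is implicit: the new codebook property materializes it, and the new indices_to_codes decodes indices through it.

import torch
from vector_quantize_pytorch import SimVQ

sim_vq = SimVQ(
    dim = 512,
    codebook_size = 1024
)

x = torch.randn(1, 1024, 512)

# forward pass returns the quantized vectors, codebook indices, and commitment loss
quantized, indices, commit_loss = sim_vq(x)

# the new property applies codebook_transform to the frozen_codebook buffer
implicit_codebook = sim_vq.codebook   # shape (1024, 512)

# the new method gathers rows of the implicit codebook by index,
# recovering the quantized output up to float tolerance
codes = sim_vq.indices_to_codes(indices)
assert torch.allclose(codes, quantized, atol = 1e-6)

Routing both forward and indices_to_codes through the codebook property keeps the transform application in one place, so the two paths cannot drift apart.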
