document the new SimVQ and ResidualSimVQ
commit loss weighting for sim vq
lucidrains committed Nov 12, 2024
1 parent f883a07 commit 448c4a5
Showing 4 changed files with 48 additions and 2 deletions.
41 changes: 41 additions & 0 deletions README.md
@@ -276,6 +276,47 @@ indices = quantizer(x)

This repository should also automatically synchronize the codebooks in a multi-process setting. If for some reason it does not, please open an issue. You can override whether to synchronize codebooks by setting `sync_codebook = True | False`
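
For instance, a minimal sketch of forcing synchronization off, assuming the standard `VectorQuantize` constructor accepts `sync_codebook` directly as described above:

```python
import torch
from vector_quantize_pytorch import VectorQuantize

# sketch only - force codebook synchronization off even in a multi-process run
vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    sync_codebook = False
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
```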

### Sim VQ

<img src="./images/simvq.png" width="400px"></img>

A <a href="https://arxiv.org/abs/2411.02038">new ICLR 2025 paper</a> proposes a scheme where the codebook is frozen and the codes are implicitly generated through a linear projection. The authors claim this setup leads to less codebook collapse as well as easier convergence. I have found this to perform even better when paired with the <a href="https://arxiv.org/abs/2410.06424">rotation trick</a> from Fifty et al., and when expanding the linear projection to a small one-layer MLP. You can experiment with it as follows

```python
import torch
from vector_quantize_pytorch import SimVQ

sim_vq = SimVQ(
dim = 512,
codebook_size = 1024
)

x = torch.randn(1, 1024, 512)
quantized, indices, commit_loss = sim_vq(x)

assert x.shape == quantized.shape
assert torch.allclose(quantized, sim_vq.indices_to_codes(indices), atol = 1e-6)
```
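
The constructor exposes a few more knobs, with keyword names taken from the `sim_vq.py` diff later in this commit: `rotation_trick` (on by default), `input_to_quantize_commit_loss_weight`, and the newly added `commitment_weight`, which scales the returned commitment loss. A hedged sketch of setting them explicitly:

```python
import torch
from vector_quantize_pytorch import SimVQ

# sketch only - keyword names taken from the sim_vq.py diff in this commit
sim_vq = SimVQ(
    dim = 512,
    codebook_size = 1024,
    rotation_trick = True,                        # rotation trick from Fifty et al., enabled by default
    input_to_quantize_commit_loss_weight = 0.25,  # weight on the input-to-quantize commitment term
    commitment_weight = 1.                        # new in this commit: scales the returned commit loss
)

x = torch.randn(1, 1024, 512)
quantized, indices, commit_loss = sim_vq(x)
```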

For the residual flavor, just import `ResidualSimVQ` instead

```python
import torch
from vector_quantize_pytorch import ResidualSimVQ

residual_sim_vq = ResidualSimVQ(
dim = 512,
num_quantizers = 4,
codebook_size = 1024
)

x = torch.randn(1, 1024, 512)
quantized, indices, commit_loss = residual_sim_vq(x)

assert x.shape == quantized.shape
assert torch.allclose(quantized, residual_sim_vq.get_output_from_indices(indices), atol = 1e-6)
```

### Finite Scalar Quantization

<img src="./images/fsq.png" width="500px"></img>
Binary file added images/simvq.png
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.20.7"
+version = "1.20.8"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
7 changes: 6 additions & 1 deletion vector_quantize_pytorch/sim_vq.py
@@ -44,6 +44,7 @@ def __init__(
         accept_image_fmap = False,
         rotation_trick = True, # works even better with rotation trick turned on, with no straight through and the commit loss from input to quantize
         input_to_quantize_commit_loss_weight = 0.25,
+        commitment_weight = 1.,
         frozen_codebook_dim = None # frozen codebook dim could have different dimensions than projection
     ):
         super().__init__()
@@ -74,6 +75,10 @@ def __init__(
 
         self.input_to_quantize_commit_loss_weight = input_to_quantize_commit_loss_weight
 
+        # total commitment loss weight
+
+        self.commitment_weight = commitment_weight
+
     @property
     def codebook(self):
         return self.code_transform(self.frozen_codebook)
@@ -132,7 +137,7 @@ def forward(
 
         indices = inverse_pack(indices, 'b *')
 
-        return quantized, indices, commit_loss
+        return quantized, indices, commit_loss * self.commitment_weight
 
 # main

