Merge pull request lucidrains#172 from lucidrains/simvq
SimVQ
lucidrains authored Nov 11, 2024
2 parents 723ea9f + 97b9a87 commit 72ede73
Showing 9 changed files with 224 additions and 7 deletions.
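
For orientation, here is a minimal usage sketch of the SimVQ quantizer this commit introduces, mirroring the `__main__` block of the new `vector_quantize_pytorch/sim_vq.py` shown below (assumes the package is installed and importable as `vector_quantize_pytorch`):

```python
import torch
from vector_quantize_pytorch import SimVQ

sim_vq = SimVQ(
    dim = 512,
    codebook_size = 1024,
    accept_image_fmap = True   # quantize (batch, dim, height, width) feature maps
)

x = torch.randn(1, 512, 32, 32)
quantized, indices, commit_loss = sim_vq(x)

assert x.shape == quantized.shape
```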
9 changes: 9 additions & 0 deletions README.md
@@ -714,3 +714,12 @@ assert loss.item() >= 0
    url = {https://api.semanticscholar.org/CorpusID:273229218}
}
```
+
+```bibtex
+@inproceedings{Zhu2024AddressingRC,
+    title = {Addressing Representation Collapse in Vector Quantized Models with One Linear Layer},
+    author = {Yongxin Zhu and Bocheng Li and Yifei Xin and Linli Xu},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:273812459}
+}
+```
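
The newly cited work is what this commit implements: it counters codebook/representation collapse by freezing a randomly initialized codebook and training only a single linear layer on top of it, so each effective code is a learned linear transform of a fixed random vector. A stripped-down sketch of that idea (hypothetical illustration only; the actual implementation is the new `vector_quantize_pytorch/sim_vq.py` further down):

```python
import torch
from torch import nn

class TinySimVQ(nn.Module):
    # hypothetical minimal illustration, not the library class
    def __init__(self, dim, codebook_size):
        super().__init__()
        # frozen random codes, stored as a buffer and never trained
        self.register_buffer('codes', torch.randn(codebook_size, dim) * dim ** -0.5)
        # the one trainable linear layer that re-parameterizes every code
        self.to_codes = nn.Linear(dim, dim, bias = False)

    def forward(self, x):                               # x: (batch, seq, dim)
        implicit_codes = self.to_codes(self.codes)      # (codebook_size, dim)
        with torch.no_grad():
            indices = torch.cdist(x, implicit_codes).argmin(dim = -1)
        return implicit_codes[indices], indices         # quantized vectors, code ids
```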
6 changes: 3 additions & 3 deletions examples/autoencoder.py
@@ -71,12 +71,12 @@ def iterate_dataset(data_loader):
shuffle=True,
)

print("baseline")
torch.random.manual_seed(seed)

model = SimpleVQAutoEncoder(
-    codebook_size=num_codes,
-    rotation_trick=rotation_trick
+    codebook_size = num_codes,
+    rotation_trick = True,
+    straight_through = False
).to(device)

opt = torch.optim.AdamW(model.parameters(), lr=lr)
1 change: 0 additions & 1 deletion examples/autoencoder_fsq.py
@@ -76,7 +76,6 @@ def iterate_dataset(data_loader):
shuffle=True,
)

print("baseline")
torch.random.manual_seed(seed)
model = SimpleFSQAutoEncoder(levels).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=lr)
2 changes: 0 additions & 2 deletions examples/autoencoder_lfq.py
@@ -87,8 +87,6 @@ def iterate_dataset(data_loader):
shuffle=True,
)

print("baseline")

torch.random.manual_seed(seed)

model = LFQAutoEncoder(
84 changes: 84 additions & 0 deletions examples/autoencoder_sim_vq.py
@@ -0,0 +1,84 @@
# FashionMnist VQ experiment with various settings.
# From https://github.com/minyoungg/vqtorch/blob/main/examples/autoencoder.py

from tqdm.auto import trange

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from vector_quantize_pytorch import SimVQ, Sequential

lr = 3e-4
train_iter = 10000
num_codes = 256
seed = 1234
rotation_trick = True
device = "cuda" if torch.cuda.is_available() else "cpu"

def SimVQAutoEncoder(**vq_kwargs):
return Sequential(
nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.GELU(),
nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
nn.MaxPool2d(kernel_size=2, stride=2),
SimVQ(dim=32, accept_image_fmap = True, **vq_kwargs),
nn.Upsample(scale_factor=2, mode="nearest"),
nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
nn.GELU(),
nn.Upsample(scale_factor=2, mode="nearest"),
nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
)

def train(model, train_loader, train_iterations=1000, alpha=10):
def iterate_dataset(data_loader):
data_iter = iter(data_loader)
while True:
try:
x, y = next(data_iter)
except StopIteration:
data_iter = iter(data_loader)
x, y = next(data_iter)
yield x.to(device), y.to(device)

for _ in (pbar := trange(train_iterations)):
opt.zero_grad()
x, _ = next(iterate_dataset(train_loader))

out, indices, cmt_loss = model(x)
out = out.clamp(-1., 1.)

rec_loss = (out - x).abs().mean()
(rec_loss + alpha * cmt_loss).backward()

opt.step()

pbar.set_description(
f"rec loss: {rec_loss.item():.3f} | "
+ f"cmt loss: {cmt_loss.item():.3f} | "
+ f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
)

transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

train_dataset = DataLoader(
datasets.FashionMNIST(
root="~/data/fashion_mnist", train=True, download=True, transform=transform
),
batch_size=256,
shuffle=True,
)

torch.random.manual_seed(seed)

model = SimVQAutoEncoder(
codebook_size = num_codes,
rotation_trick = rotation_trick
).to(device)

opt = torch.optim.AdamW(model.parameters(), lr=lr)
train(model, train_dataset, train_iterations=train_iter)
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "vector-quantize-pytorch"
version = "1.19.5"
version = "1.20.0"
description = "Vector Quantization - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
1 change: 1 addition & 0 deletions vector_quantize_pytorch/__init__.py
@@ -6,5 +6,6 @@
from vector_quantize_pytorch.residual_lfq import ResidualLFQ, GroupedResidualLFQ
from vector_quantize_pytorch.residual_fsq import ResidualFSQ, GroupedResidualFSQ
from vector_quantize_pytorch.latent_quantization import LatentQuantize
+from vector_quantize_pytorch.sim_vq import SimVQ

from vector_quantize_pytorch.utils import Sequential
124 changes: 124 additions & 0 deletions vector_quantize_pytorch/sim_vq.py
@@ -0,0 +1,124 @@
from typing import Callable

import torch
from torch import nn
from torch.nn import Module
import torch.nn.functional as F

from einx import get_at
from einops import rearrange, pack, unpack

from vector_quantize_pytorch.vector_quantize_pytorch import rotate_from_to

# helper functions

def exists(v):
return v is not None

def identity(t):
return t

def default(v, d):
return v if exists(v) else d

def pack_one(t, pattern):
packed, packed_shape = pack([t], pattern)

def inverse(out, inv_pattern = None):
inv_pattern = default(inv_pattern, pattern)
out, = unpack(out, packed_shape, inv_pattern)
return out

return packed, inverse

# class

class SimVQ(Module):
def __init__(
self,
dim,
codebook_size,
init_fn: Callable = identity,
accept_image_fmap = False,
rotation_trick = True, # works even better with rotation trick turned on, with no asymmetric commit loss or straight through
commit_loss_input_to_quantize_weight = 0.25,
):
super().__init__()
self.accept_image_fmap = accept_image_fmap

codebook = torch.randn(codebook_size, dim) * (dim ** -0.5)
codebook = init_fn(codebook)

# the codebook is actually implicit from a linear layer from frozen gaussian or uniform

self.codebook_to_codes = nn.Linear(dim, dim, bias = False)
self.register_buffer('codebook', codebook)


# whether to use rotation trick from Fifty et al.
# https://arxiv.org/abs/2410.06424

self.rotation_trick = rotation_trick
self.register_buffer('zero', torch.tensor(0.), persistent = False)

# commit loss weighting - weighing input to quantize a bit less is crucial for it to work

self.commit_loss_input_to_quantize_weight = commit_loss_input_to_quantize_weight

def forward(
self,
x
):
if self.accept_image_fmap:
x = rearrange(x, 'b d h w -> b h w d')
x, inverse_pack = pack_one(x, 'b * d')

implicit_codebook = self.codebook_to_codes(self.codebook)

with torch.no_grad():
dist = torch.cdist(x, implicit_codebook)
indices = dist.argmin(dim = -1)

# select codes

quantized = get_at('[c] d, b n -> b n d', implicit_codebook, indices)

if self.rotation_trick:
# rotation trick from @cfifty

quantized = rotate_from_to(quantized, x)

commit_loss = self.zero
else:
# commit loss and straight through, as was done in the paper

commit_loss = (
F.mse_loss(x, quantized.detach()) * self.commit_loss_input_to_quantize_weight +
F.mse_loss(x.detach(), quantized)
)

quantized = (quantized - x).detach() + x

if self.accept_image_fmap:
quantized = inverse_pack(quantized)
quantized = rearrange(quantized, 'b h w d-> b d h w')

indices = inverse_pack(indices, 'b *')

return quantized, indices, commit_loss

# main

if __name__ == '__main__':

x = torch.randn(1, 512, 32, 32)

sim_vq = SimVQ(
dim = 512,
codebook_size = 1024,
accept_image_fmap = True
)

quantized, indices, commit_loss = sim_vq(x)

assert x.shape == quantized.shape
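
As a hypothetical sanity check (not part of this commit), the gradient flow of the paper-style straight-through path can be verified directly: a loss on the quantized output reaches the input through the straight-through estimator, the single linear layer is trained through the commitment term, and the frozen codebook buffer itself never trains.

```python
import torch
from vector_quantize_pytorch import SimVQ

sim_vq = SimVQ(
    dim = 32,
    codebook_size = 64,
    accept_image_fmap = True,
    rotation_trick = False      # use straight-through + asymmetric commit loss, as in the paper
)

x = torch.randn(1, 32, 8, 8, requires_grad = True)
quantized, indices, commit_loss = sim_vq(x)

(quantized.pow(2).mean() + commit_loss).backward()

assert x.grad is not None                                # gradient reaches the encoder side
assert sim_vq.codebook_to_codes.weight.grad is not None  # the one trainable linear layer learns
assert not sim_vq.codebook.requires_grad                 # frozen random codebook stays fixed
```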
2 changes: 2 additions & 0 deletions vector_quantize_pytorch/utils.py
@@ -12,6 +12,7 @@
from vector_quantize_pytorch.residual_lfq import ResidualLFQ, GroupedResidualLFQ
from vector_quantize_pytorch.residual_fsq import ResidualFSQ, GroupedResidualFSQ
from vector_quantize_pytorch.latent_quantization import LatentQuantize
+from vector_quantize_pytorch.sim_vq import SimVQ

QUANTIZE_KLASSES = (
VectorQuantize,
@@ -20,6 +21,7 @@
RandomProjectionQuantizer,
FSQ,
LFQ,
+    SimVQ,
ResidualLFQ,
GroupedResidualLFQ,
ResidualFSQ,
