forked from lucidrains/vector-quantize-pytorch
Commit
Merge pull request lucidrains#172 from lucidrains/simvq
SimVQ
Showing 9 changed files with 224 additions and 7 deletions.
@@ -0,0 +1,84 @@
# FashionMnist VQ experiment with various settings.
# From https://github.com/minyoungg/vqtorch/blob/main/examples/autoencoder.py

from tqdm.auto import trange

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from vector_quantize_pytorch import SimVQ, Sequential

lr = 3e-4
train_iter = 10000
num_codes = 256
seed = 1234
rotation_trick = True
device = "cuda" if torch.cuda.is_available() else "cpu"

def SimVQAutoEncoder(**vq_kwargs):
    return Sequential(
        nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.GELU(),
        nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
        nn.MaxPool2d(kernel_size=2, stride=2),
        SimVQ(dim=32, accept_image_fmap = True, **vq_kwargs),
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
        nn.GELU(),
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
    )

def train(model, train_loader, train_iterations=1000, alpha=10):
    def iterate_dataset(data_loader):
        data_iter = iter(data_loader)
        while True:
            try:
                x, y = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x, y = next(data_iter)
            yield x.to(device), y.to(device)

    for _ in (pbar := trange(train_iterations)):
        opt.zero_grad()
        x, _ = next(iterate_dataset(train_loader))

        out, indices, cmt_loss = model(x)
        out = out.clamp(-1., 1.)

        rec_loss = (out - x).abs().mean()
        (rec_loss + alpha * cmt_loss).backward()

        opt.step()

        pbar.set_description(
            f"rec loss: {rec_loss.item():.3f} | "
            + f"cmt loss: {cmt_loss.item():.3f} | "
            + f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
        )

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

train_dataset = DataLoader(
    datasets.FashionMNIST(
        root="~/data/fashion_mnist", train=True, download=True, transform=transform
    ),
    batch_size=256,
    shuffle=True,
)

torch.random.manual_seed(seed)

model = SimVQAutoEncoder(
    codebook_size = num_codes,
    rotation_trick = rotation_trick
).to(device)

opt = torch.optim.AdamW(model.parameters(), lr=lr)
train(model, train_dataset, train_iterations=train_iter)
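
The script relies on the library's Sequential wrapper, which forwards the reconstruction, the code indices, and the commitment loss from its SimVQ layer. A minimal post-training sanity check, not part of the commit and simply reusing the model, train_dataset, device, and num_codes defined above, could look like this:

model.eval()
with torch.no_grad():
    x, _ = next(iter(train_dataset))
    x = x.to(device)

    recon, indices, _ = model(x)
    recon = recon.clamp(-1., 1.)

    # mean absolute reconstruction error, the same metric the progress bar tracks
    print(f"rec loss: {(recon - x).abs().mean().item():.3f}")

    # how much of the codebook a single batch touches
    print(f"codes used: {indices.unique().numel()} / {num_codes}")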
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.19.5"
+version = "1.20.0"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
@@ -0,0 +1,124 @@
from typing import Callable

import torch
from torch import nn
from torch.nn import Module
import torch.nn.functional as F

from einx import get_at
from einops import rearrange, pack, unpack

from vector_quantize_pytorch.vector_quantize_pytorch import rotate_from_to

# helper functions

def exists(v):
    return v is not None

def identity(t):
    return t

def default(v, d):
    return v if exists(v) else d

def pack_one(t, pattern):
    packed, packed_shape = pack([t], pattern)

    def inverse(out, inv_pattern = None):
        inv_pattern = default(inv_pattern, pattern)
        out, = unpack(out, packed_shape, inv_pattern)
        return out

    return packed, inverse

# class

class SimVQ(Module):
    def __init__(
        self,
        dim,
        codebook_size,
        init_fn: Callable = identity,
        accept_image_fmap = False,
        rotation_trick = True,  # works even better with rotation trick turned on, with no asymmetric commit loss or straight through
        commit_loss_input_to_quantize_weight = 0.25,
    ):
        super().__init__()
        self.accept_image_fmap = accept_image_fmap

        codebook = torch.randn(codebook_size, dim) * (dim ** -0.5)
        codebook = init_fn(codebook)

        # the codebook is actually implicit from a linear layer from frozen gaussian or uniform

        self.codebook_to_codes = nn.Linear(dim, dim, bias = False)
        self.register_buffer('codebook', codebook)

        # whether to use rotation trick from Fifty et al.
        # https://arxiv.org/abs/2410.06424

        self.rotation_trick = rotation_trick
        self.register_buffer('zero', torch.tensor(0.), persistent = False)

        # commit loss weighting - weighing input to quantize a bit less is crucial for it to work

        self.commit_loss_input_to_quantize_weight = commit_loss_input_to_quantize_weight

    def forward(
        self,
        x
    ):
        if self.accept_image_fmap:
            x = rearrange(x, 'b d h w -> b h w d')
            x, inverse_pack = pack_one(x, 'b * d')

        implicit_codebook = self.codebook_to_codes(self.codebook)

        with torch.no_grad():
            dist = torch.cdist(x, implicit_codebook)
            indices = dist.argmin(dim = -1)

        # select codes

        quantized = get_at('[c] d, b n -> b n d', implicit_codebook, indices)

        if self.rotation_trick:
            # rotation trick from @cfifty

            quantized = rotate_from_to(quantized, x)

            commit_loss = self.zero
        else:
            # commit loss and straight through, as was done in the paper

            commit_loss = (
                F.mse_loss(x, quantized.detach()) * self.commit_loss_input_to_quantize_weight +
                F.mse_loss(x.detach(), quantized)
            )

            quantized = (quantized - x).detach() + x

        if self.accept_image_fmap:
            quantized = inverse_pack(quantized)
            quantized = rearrange(quantized, 'b h w d -> b d h w')

            indices = inverse_pack(indices, 'b *')

        return quantized, indices, commit_loss

# main

if __name__ == '__main__':

    x = torch.randn(1, 512, 32, 32)

    sim_vq = SimVQ(
        dim = 512,
        codebook_size = 1024,
        accept_image_fmap = True
    )

    quantized, indices, commit_loss = sim_vq(x)

    assert x.shape == quantized.shape
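
The __main__ block above only exercises the image feature-map path with the default rotation trick. For completeness, here is a minimal sketch, not part of the diff, of the other configuration: flattened token sequences with the straight-through estimator and asymmetric commit loss instead of the rotation trick.

import torch
from vector_quantize_pytorch import SimVQ

sim_vq = SimVQ(
    dim = 512,
    codebook_size = 1024,
    rotation_trick = False  # fall back to commit loss + straight-through
)

tokens = torch.randn(1, 1024, 512)  # (batch, sequence, dim), no image fmap rearrange needed

quantized, indices, commit_loss = sim_vq(tokens)

assert quantized.shape == tokens.shape  # straight-through preserves the input shape
assert indices.shape == (1, 1024)       # one code index per token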