Commit
offer a way to turn off distributed replacement of codes during expiration for now
lucidrains committed Jun 25, 2024
1 parent 133f738 commit 013ff84
Showing 3 changed files with 25 additions and 12 deletions.
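In short, this commit adds a `distributed_replace_codes` flag (defaulting to `True`) to `VectorQuantize` and to both codebook classes. When it is set to `False`, expired (dead) codes are replaced by sampling from the local batch via `batched_sample_vectors` instead of the distributed sampler `sample_vectors_distributed`, even when the codebook is otherwise synchronized across processes. A minimal usage sketch, assuming a `torch.distributed` process group is already initialized; every argument except the new `distributed_replace_codes` mirrors the test below:

import torch
from vector_quantize_pytorch import VectorQuantize

# sketch only: assumes this runs inside an initialized DDP process group
vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    decay = 0.8,
    commitment_weight = 1.,
    sync_codebook = True,               # keep codebooks in sync across processes
    distributed_replace_codes = False   # new flag: replace expired codes from local samples only
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)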
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.14.30"
+version = "1.14.31"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
12 changes: 8 additions & 4 deletions tests/test_readme.py
@@ -4,14 +4,18 @@
 def exists(v):
     return v is not None
 
-def test_vq():
+@pytest.mark.parametrize('use_cosine_sim', (True, False))
+def test_vq(
+    use_cosine_sim
+):
     from vector_quantize_pytorch import VectorQuantize
 
     vq = VectorQuantize(
         dim = 256,
-        codebook_size = 512,     # codebook size
-        decay = 0.8,             # the exponential moving average decay, lower means the dictionary will change faster
-        commitment_weight = 1.   # the weight on the commitment loss
+        codebook_size = 512,     # codebook size
+        decay = 0.8,             # the exponential moving average decay, lower means the dictionary will change faster
+        commitment_weight = 1.,  # the weight on the commitment loss
+        use_cosine_sim = use_cosine_sim
     )
 
     x = torch.randn(1, 1024, 256)
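The test is now parametrized over `use_cosine_sim`, so the cosine-similarity codebook path (`CosineSimCodebook`) is exercised alongside the default Euclidean one. A minimal sketch of that variant, assuming the same API as the test above:

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    use_cosine_sim = True   # uses CosineSimCodebook instead of EuclideanCodebook
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)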
23 changes: 16 additions & 7 deletions vector_quantize_pytorch/vector_quantize_pytorch.py
@@ -1,6 +1,7 @@
 from functools import partial
 
 import torch
+from torch.nn import Module
 from torch import nn, einsum
 import torch.nn.functional as F
 import torch.distributed as distributed
@@ -245,7 +246,7 @@ def orthogonal_loss_fn(t):
 
 # distance types
 
-class EuclideanCodebook(nn.Module):
+class EuclideanCodebook(Module):
     def __init__(
         self,
         dim,
@@ -259,6 +260,7 @@ def __init__(
         threshold_ema_dead_code = 2,
         reset_cluster_size = None,
         use_ddp = False,
+        distributed_replace_codes = True,
         learnable_codebook = False,
         gumbel_sample = gumbel_sample,
         sample_codebook_temp = 1.,
@@ -292,6 +294,8 @@ def __init__(
         assert not (use_ddp and num_codebooks > 1 and kmeans_init), 'kmeans init is not compatible with multiple codebooks in distributed environment for now'
 
         self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
+        self.replace_sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans and distributed_replace_codes else batched_sample_vectors
+
         self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
         self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
 
@@ -422,7 +426,7 @@ def replace(self, batch_samples, batch_mask):
             if not torch.any(mask):
                 continue
 
-            sampled = self.sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item())
+            sampled = self.replace_sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item())
             sampled = rearrange(sampled, '1 ... -> ...')
 
             self.embed.data[ind][mask] = sampled
@@ -520,7 +524,7 @@ def forward(
 
         return quantize, embed_ind, dist
 
-class CosineSimCodebook(nn.Module):
+class CosineSimCodebook(Module):
     def __init__(
         self,
         dim,
@@ -534,10 +538,11 @@ def __init__(
         threshold_ema_dead_code = 2,
         reset_cluster_size = None,
         use_ddp = False,
+        distributed_replace_codes = True,
         learnable_codebook = False,
         gumbel_sample = gumbel_sample,
         sample_codebook_temp = 1.,
-        ema_update = True
+        ema_update = True,
     ):
         super().__init__()
         self.transform_input = l2norm
@@ -563,6 +568,8 @@ def __init__(
         self.sample_codebook_temp = sample_codebook_temp
 
         self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
+        self.replace_sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans and distributed_replace_codes else batched_sample_vectors
+
         self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
         self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
 
@@ -608,7 +615,7 @@ def replace(self, batch_samples, batch_mask):
             if not torch.any(mask):
                 continue
 
-            sampled = self.sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item())
+            sampled = self.replace_sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item())
             sampled = rearrange(sampled, '1 ... -> ...')
 
             self.embed.data[ind][mask] = sampled
@@ -696,7 +703,7 @@ def forward(
 
 # main class
 
-class VectorQuantize(nn.Module):
+class VectorQuantize(Module):
     def __init__(
         self,
         dim,
@@ -723,6 +730,7 @@ def __init__(
         stochastic_sample_codes = False,
         sample_codebook_temp = 1.,
         straight_through = False,
+        distributed_replace_codes = True,
         reinmax = False, # using reinmax for improved straight-through, assuming straight through helps at all
         sync_codebook = None,
         sync_affine_param = False,
@@ -798,7 +806,8 @@ def __init__(
             learnable_codebook = has_codebook_orthogonal_loss or learnable_codebook,
             sample_codebook_temp = sample_codebook_temp,
             gumbel_sample = gumbel_sample_fn,
-            ema_update = ema_update
+            ema_update = ema_update,
+            distributed_replace_codes = distributed_replace_codes
         )
 
         if affine_param:
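Both codebook classes now keep a separate `replace_sample_fn` that is used only when replacing expired codes, while k-means initialization continues to use `sample_fn`. The dispatch added above can be summarized by the following standalone sketch; `local_sample` and `distributed_sample` are hypothetical stand-ins for the library's `batched_sample_vectors` and `sample_vectors_distributed`:

def local_sample(samples, num):
    # stand-in: sample `num` vectors from this process's batch only
    ...

def distributed_sample(samples, num):
    # stand-in: sample `num` vectors across all processes (needs torch.distributed)
    ...

def pick_replace_sampler(use_ddp, sync_kmeans, distributed_replace_codes):
    # mirrors the conditional added in this commit for `self.replace_sample_fn`
    if use_ddp and sync_kmeans and distributed_replace_codes:
        return distributed_sample
    return local_sample

# with the new flag turned off, expired codes are refreshed from local samples even under DDP
assert pick_replace_sampler(True, True, False) is local_sample
assert pick_replace_sampler(True, True, True) is distributed_sample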
