From 7b88639e7efd692bd952d4038f56a71581b5ed37 Mon Sep 17 00:00:00 2001
From: Hanqing Liu
Date: Mon, 29 Jul 2024 19:43:49 -0700
Subject: [PATCH] just some notes

---
 .../vector_quantize_pytorch.py | 29 +++++++++++++++++--
 1 file changed, 26 insertions(+), 3 deletions(-)

diff --git a/vector_quantize_pytorch/vector_quantize_pytorch.py b/vector_quantize_pytorch/vector_quantize_pytorch.py
index 0ad4c9d..262e705 100644
--- a/vector_quantize_pytorch/vector_quantize_pytorch.py
+++ b/vector_quantize_pytorch/vector_quantize_pytorch.py
@@ -39,6 +39,13 @@ def Sequential(*modules):
     return nn.Sequential(*modules)
 
 def cdist(x, y):
+    """
+    Calculate the pairwise Euclidean distance sqrt(sum((x - y) ** 2)) via the expansion x^2 + y^2 - 2xy
+
+    x: (b, i, d)
+    y: (b, j, d)
+    returns: (b, i, j)
+    """
     x2 = reduce(x ** 2, 'b n d -> b n', 'sum')
     y2 = reduce(y ** 2, 'b n d -> b n', 'sum')
     xy = einsum('b i d, b j d -> b i j', x, y) * -2
@@ -788,12 +795,14 @@ def __init__(
 
         self.learnable_codebook = learnable_codebook
 
+        # the orthogonal regularization loss encourages the active codebook codes to be orthogonal to each other
         has_codebook_orthogonal_loss = orthogonal_reg_weight > 0.
         self.has_codebook_orthogonal_loss = has_codebook_orthogonal_loss
         self.orthogonal_reg_weight = orthogonal_reg_weight
         self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
         self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
 
+        # the codebook diversity loss encourages a more diverse set of codebook codes to be used
         has_codebook_diversity_loss = codebook_diversity_loss_weight > 0.
         self.has_codebook_diversity_loss = has_codebook_diversity_loss
         self.codebook_diversity_temperature = codebook_diversity_temperature
@@ -898,6 +907,7 @@ def get_codes_from_indices(self, indices):
         return codes
 
     def get_output_from_indices(self, indices):
+        """Used when only indices are provided and the decoded output is needed."""
        codes = self.get_codes_from_indices(indices)
         return self.project_out(codes)
 
@@ -933,6 +943,12 @@ def forward(
         need_transpose = not self.channel_last and not self.accept_image_fmap
         should_inplace_optimize = exists(self.in_place_codebook_optimizer)
 
+        # einops notation
+        # b: batch size
+        # d: feature dimension of the input x
+        # n: number of tokens; for a 2-D input, n = 1
+        # h: number of heads for the multi-headed codebook
+
         # rearrange inputs
 
         if self.accept_image_fmap:
@@ -944,7 +960,7 @@ def forward(
             x = rearrange(x, 'b d n -> b n d')
 
         # project input
-
+        # a simple linear layer (with optional LayerNorm)
         x = self.project_in(x)
 
         # handle multi-headed separate codebooks
@@ -1005,13 +1021,18 @@
 
             commit_quantize = maybe_detach(quantize)
 
-            # straight through
-
+            # the straight-through trick
+            # this makes the gradient pass through the quantization step during backpropagation
+            # by adding x to the detached difference (quantize - x).detach(),
+            # quantize effectively equals x plus a constant offset in the backward pass
+            # since x is part of the computational graph, gradients computed with
+            # respect to quantize now flow directly to x
             quantize = x + (quantize - x).detach()
 
             if self.sync_update_v > 0.:
                 # (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
                 quantize = quantize + self.sync_update_v * (quantize - quantize.detach())
+                # the sync_update_v term affects only the gradient: (quantize - quantize.detach()) is zero in the forward pass
 
         # function for calculating cross entropy loss to distance matrix
         # used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss
@@ -1057,6 +1078,8 @@ def calculate_ce_loss(codes):
 
         if self.training:
             # calculate codebook diversity loss (negative of entropy) if needed
+            # the diversity loss computes the entropy of the softmax distribution over code distances for each token
+            # the idea is to encourage a diverse set of codes to be used across tokens
 
             if self.has_codebook_diversity_loss:
                 prob = (-distances * self.codebook_diversity_temperature).softmax(dim = -1)
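
As a sanity check on the cdist docstring above: the function never forms (x - y) directly, it uses the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>. Below is a minimal standalone sketch (not part of the patch) comparing that expansion against torch.cdist; it assumes torch and einops are installed.

    import torch
    from einops import reduce

    def cdist_expanded(x, y):
        # ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 * <x_i, y_j>
        x2 = reduce(x ** 2, 'b n d -> b n', 'sum')             # (b, i)
        y2 = reduce(y ** 2, 'b n d -> b n', 'sum')             # (b, j)
        xy = torch.einsum('b i d, b j d -> b i j', x, y) * -2  # (b, i, j)
        # clamp away tiny negatives caused by floating-point cancellation before sqrt
        return (xy + x2[:, :, None] + y2[:, None, :]).clamp(min = 0).sqrt()

    x, y = torch.randn(2, 5, 8), torch.randn(2, 7, 8)
    assert torch.allclose(cdist_expanded(x, y), torch.cdist(x, y), atol = 1e-5)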
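
The straight-through comment block is easy to verify numerically. A minimal sketch (not part of the patch), with round() standing in for the non-differentiable nearest-code lookup purely for illustration: the forward value equals the quantized tensor, yet the gradient reaches x, and the sync_update_v term rescales that gradient by (1 + v) without changing the forward value.

    import torch

    x = torch.randn(4, 8, requires_grad = True)

    # round() is a stand-in for the non-differentiable codebook lookup
    with torch.no_grad():
        quantize = x.round()

    v = 0.25  # plays the role of sync_update_v

    out = x + (quantize - x).detach()     # forward value == quantize, gradient flows to x
    out = out + v * (out - out.detach())  # zero in the forward pass, scales the gradient by (1 + v)

    out.sum().backward()
    assert torch.allclose(out, quantize)                        # forward pass unchanged
    assert torch.allclose(x.grad, torch.full_like(x, 1. + v))   # gradient scaled by 1 + v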
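
On the codebook diversity loss, a minimal sketch of the idea only (not the library's exact implementation, whose reduction and weighting may differ): distances are softmaxed into per-token probabilities over codes, averaged into a marginal usage distribution, and the negative entropy of that marginal is the loss, so minimizing it pushes toward uniform code usage. The shapes and temperature value here are illustrative assumptions.

    import torch

    distances = torch.rand(2, 16, 512)  # (batch, tokens, codes), stand-in distance matrix
    temperature = 100.                  # plays the role of codebook_diversity_temperature

    prob = (-distances * temperature).softmax(dim = -1)  # closer codes get higher probability
    avg_prob = prob.mean(dim = (0, 1))                   # marginal code usage across all tokens
    diversity_loss = (avg_prob * (avg_prob + 1e-9).log()).sum()  # negative entropy

    # minimizing the negative entropy maximizes entropy, i.e. encourages uniform code usage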