From 1ea2ef6a07bce0adcd11275fcfffda732af51388 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Mon, 6 May 2024 09:38:34 -0700
Subject: [PATCH] throw in an option to use code agnostic commit loss for LFQ,
 found empirically to work well by @MattMcPartlon

---
 pyproject.toml                   |  2 +-
 .../lookup_free_quantization.py  | 21 +++++++++++++++----
 2 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 030f4ec..3c5f154 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.14.12"
+version = "1.14.15"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/vector_quantize_pytorch/lookup_free_quantization.py b/vector_quantize_pytorch/lookup_free_quantization.py
index 7d30c5e..695343d 100644
--- a/vector_quantize_pytorch/lookup_free_quantization.py
+++ b/vector_quantize_pytorch/lookup_free_quantization.py
@@ -62,8 +62,9 @@ def __init__(
         straight_through_activation = nn.Identity(),
         num_codebooks = 1,
         keep_num_codebooks_dim = None,
-        codebook_scale = 1.,                    # for residual LFQ, codebook scaled down by 2x at each layer
-        frac_per_sample_entropy = 1.            # make less than 1. to only use a random fraction of the probs for per sample entropy
+        codebook_scale = 1.,                        # for residual LFQ, codebook scaled down by 2x at each layer
+        frac_per_sample_entropy = 1.,               # make less than 1. to only use a random fraction of the probs for per sample entropy
+        use_code_agnostic_commit_loss = False
     ):
         super().__init__()
 
@@ -110,6 +111,7 @@ def __init__(
         # commitment loss
 
         self.commitment_loss_weight = commitment_loss_weight
+        self.use_code_agnostic_commit_loss = use_code_agnostic_commit_loss
 
         # for no auxiliary loss, during inference
 
@@ -259,8 +261,19 @@ def forward(
 
         # commit loss
 
-        if self.training:
-            commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')
+        if self.training and self.commitment_loss_weight > 0.:
+
+            if self.use_code_agnostic_commit_loss:
+                # credit goes to @MattMcPartlon for sharing this in https://github.com/lucidrains/vector-quantize-pytorch/issues/120#issuecomment-2095089337
+
+                commit_loss = F.mse_loss(
+                    original_input ** 2,
+                    codebook_value ** 2,
+                    reduction = 'none'
+                )
+
+            else:
+                commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')
 
             if exists(mask):
                 commit_loss = commit_loss[mask]
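
Note (not part of the patch): the "code agnostic" variant appears to compare the squared input against the squared codebook value, so it pulls each coordinate's magnitude toward the codebook scale without committing to any particular sign pattern (code). Below is a minimal usage sketch of how the new flag might be enabled; the dimensions and loss weights are illustrative, and the exact contents of the returned auxiliary loss follow from the rest of lookup_free_quantization.py rather than from this patch.

    import torch
    from vector_quantize_pytorch import LFQ

    quantizer = LFQ(
        dim = 16,                                 # feature dimension of the input
        codebook_size = 65536,                    # 2 ** 16, LFQ expects a power of two
        commitment_loss_weight = 0.25,            # commit loss is only computed when this is > 0. and training
        use_code_agnostic_commit_loss = True      # the option added by this patch
    )

    image_feats = torch.randn(1, 16, 32, 32)

    # during training, the returned auxiliary loss includes the (code agnostic) commitment term
    quantized, indices, aux_loss = quantizer(image_feats)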