throw in an option to use code agnostic commit loss for LFQ, found empirically to work well by @MattMcPartlon
lucidrains committed May 6, 2024
1 parent 190ac99 commit 1ea2ef6
Showing 2 changed files with 18 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "vector-quantize-pytorch"
-version = "1.14.12"
+version = "1.14.15"
description = "Vector Quantization - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
21 changes: 17 additions & 4 deletions vector_quantize_pytorch/lookup_free_quantization.py
@@ -62,8 +62,9 @@ def __init__(
straight_through_activation = nn.Identity(),
num_codebooks = 1,
keep_num_codebooks_dim = None,
-    codebook_scale = 1., # for residual LFQ, codebook scaled down by 2x at each layer
-    frac_per_sample_entropy = 1. # make less than 1. to only use a random fraction of the probs for per sample entropy
+    codebook_scale = 1., # for residual LFQ, codebook scaled down by 2x at each layer
+    frac_per_sample_entropy = 1., # make less than 1. to only use a random fraction of the probs for per sample entropy
+    use_code_agnostic_commit_loss = False
):
super().__init__()

@@ -110,6 +111,7 @@ def __init__(
# commitment loss

self.commitment_loss_weight = commitment_loss_weight
+        self.use_code_agnostic_commit_loss = use_code_agnostic_commit_loss

# for no auxiliary loss, during inference

@@ -259,8 +261,19 @@ def forward(

# commit loss

-        if self.training:
-            commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')
+        if self.training and self.commitment_loss_weight > 0.:
+
+            if self.use_code_agnostic_commit_loss:
+                # credit goes to @MattMcPartlon for sharing this in https://github.com/lucidrains/vector-quantize-pytorch/issues/120#issuecomment-2095089337
+
+                commit_loss = F.mse_loss(
+                    original_input ** 2,
+                    codebook_value ** 2,
+                    reduction = 'none'
+                )
+
+            else:
+                commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')

if exists(mask):
commit_loss = commit_loss[mask]
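For context, the new `use_code_agnostic_commit_loss` flag swaps the usual commitment loss, which pulls the input toward the detached quantized code, for one that only matches squared magnitudes: since every LFQ code has the same magnitude `codebook_scale`, the loss is identical whichever sign (code) the input lands on, hence "code agnostic". Below is a minimal standalone sketch of the two variants; the tensor shapes and the names `standard` and `agnostic` are illustrative, not from the library.

```python
import torch
import torch.nn.functional as F

# illustrative stand-ins for the tensors in LFQ's forward
original_input = torch.randn(2, 4, 16)                    # pre-quantization features
codebook_scale = 1.                                       # LFQ codes are +/- this value
quantized = torch.sign(original_input) * codebook_scale   # LFQ quantizes by sign
codebook_value = torch.full_like(original_input, codebook_scale)

# standard commitment loss: pull the input toward the (detached) chosen code
standard = F.mse_loss(original_input, quantized.detach(), reduction = 'none')

# code agnostic variant: penalize only the squared-magnitude mismatch, so an
# input of the right magnitude incurs zero loss regardless of its sign
agnostic = F.mse_loss(original_input ** 2, codebook_value ** 2, reduction = 'none')
```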
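And a hedged usage sketch of enabling the option at construction time, following the LFQ constructor arguments shown in the repository README; the specific sizes and input shape here are just examples:

```python
import torch
from vector_quantize_pytorch import LFQ

quantizer = LFQ(
    codebook_size = 65536,                  # must be a power of 2
    dim = 16,                               # log2(codebook_size) in this example
    use_code_agnostic_commit_loss = True    # the option added by this commit
)

x = torch.randn(1, 1024, 16)                # (batch, seq, dim) features
quantized, indices, aux_loss = quantizer(x)
```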
