From b8d077d4f53f5ade4075a26da6d8f0bea6fe85f1 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Wed, 26 Jun 2024 06:45:52 -0700 Subject: [PATCH] able to return a loss breakdown --- pyproject.toml | 2 +- .../vector_quantize_pytorch.py | 25 +++++++++++++++++-- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d4f703c..454687f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "vector-quantize-pytorch" -version = "1.14.31" +version = "1.14.32" description = "Vector Quantization - Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/vector_quantize_pytorch/vector_quantize_pytorch.py b/vector_quantize_pytorch/vector_quantize_pytorch.py index 0f647c0..63fd7e8 100644 --- a/vector_quantize_pytorch/vector_quantize_pytorch.py +++ b/vector_quantize_pytorch/vector_quantize_pytorch.py @@ -1,4 +1,5 @@ from functools import partial +from collections import namedtuple import torch from torch.nn import Module @@ -703,6 +704,12 @@ def forward( # main class +LossBreakdown = namedtuple('LossBreakdown', [ + 'commitment', + 'orthogonal_reg', + 'inplace_optimize' +]) + class VectorQuantize(Module): def __init__( self, @@ -829,6 +836,8 @@ def __init__( self.accept_image_fmap = accept_image_fmap self.channel_last = channel_last + self.register_buffer('zero', torch.tensor(0.), persistent = False) + @property def codebook(self): codebook = self._codebook.embed @@ -877,7 +886,8 @@ def forward( indices = None, mask = None, sample_codebook_temp = None, - freeze_codebook = False + freeze_codebook = False, + return_loss_breakdown = False ): orig_input = x @@ -927,6 +937,10 @@ def forward( quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs) + # losses for loss breakdown + + commit_loss = orthogonal_reg_loss = inplace_optimize_loss = self.zero + # one step in-place update if should_inplace_optimize and self.training and not freeze_codebook: @@ -947,6 +961,8 @@ def forward( self.in_place_codebook_optimizer.step() self.in_place_codebook_optimizer.zero_grad() + inplace_optimize_loss = loss + # quantize again quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs) @@ -1084,4 +1100,9 @@ def calculate_ce_loss(codes): orig_input ) - return quantize, embed_ind, loss + if not return_loss_breakdown: + return quantize, embed_ind, loss + + loss_breakdown = LossBreakdown(commit_loss, orthogonal_reg_loss, inplace_optimize_loss) + + return quantize, embed_ind, loss, loss_breakdown