
able to return a loss breakdown
lucidrains committed Jun 26, 2024
1 parent 013ff84 commit b8d077d
Showing 2 changed files with 24 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.14.31"
+version = "1.14.32"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
25 changes: 23 additions & 2 deletions vector_quantize_pytorch/vector_quantize_pytorch.py
@@ -1,4 +1,5 @@
 from functools import partial
+from collections import namedtuple
 
 import torch
 from torch.nn import Module
@@ -703,6 +704,12 @@ def forward(

 # main class
 
+LossBreakdown = namedtuple('LossBreakdown', [
+    'commitment',
+    'orthogonal_reg',
+    'inplace_optimize'
+])
+
 class VectorQuantize(Module):
     def __init__(
         self,
@@ -829,6 +836,8 @@ def __init__(
         self.accept_image_fmap = accept_image_fmap
         self.channel_last = channel_last
 
+        self.register_buffer('zero', torch.tensor(0.), persistent = False)
+
     @property
     def codebook(self):
         codebook = self._codebook.embed
@@ -877,7 +886,8 @@ def forward(
         indices = None,
         mask = None,
         sample_codebook_temp = None,
-        freeze_codebook = False
+        freeze_codebook = False,
+        return_loss_breakdown = False
     ):
         orig_input = x

@@ -927,6 +937,10 @@ def forward(

         quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)
 
+        # losses for loss breakdown
+
+        commit_loss = orthogonal_reg_loss = inplace_optimize_loss = self.zero
+
         # one step in-place update
 
         if should_inplace_optimize and self.training and not freeze_codebook:
@@ -947,6 +961,8 @@
             self.in_place_codebook_optimizer.step()
             self.in_place_codebook_optimizer.zero_grad()
 
+            inplace_optimize_loss = loss
+
             # quantize again
 
             quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)
@@ -1084,4 +1100,9 @@ def calculate_ce_loss(codes):
                 orig_input
             )
 
-        return quantize, embed_ind, loss
+        if not return_loss_breakdown:
+            return quantize, embed_ind, loss
+
+        loss_breakdown = LossBreakdown(commit_loss, orthogonal_reg_loss, inplace_optimize_loss)
+
+        return quantize, embed_ind, loss, loss_breakdown
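
For orientation, a minimal usage sketch of the new flag. The constructor arguments (`dim`, `codebook_size`) and the input shape are assumed from the library's usual API and are not part of this diff:

```python
import torch
from vector_quantize_pytorch import VectorQuantize

# assumed constructor arguments, not shown in this commit
vq = VectorQuantize(
    dim = 256,
    codebook_size = 512
)

x = torch.randn(1, 1024, 256)

# the default call is unchanged and still returns a 3-tuple
quantize, indices, loss = vq(x)

# opting in appends the LossBreakdown namedtuple introduced by this commit
quantize, indices, loss, breakdown = vq(x, return_loss_breakdown = True)

# each field is a scalar tensor; any term that was never computed
# stays at the module's registered zero buffer
print(breakdown.commitment, breakdown.orthogonal_reg, breakdown.inplace_optimize)
```

Registering `zero` as a non-persistent buffer keeps the fallback value on the module's device and out of the state dict, so every field of the breakdown is a tensor even when its loss term did not apply.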
