Skip to content

Commit

Permalink
'fix_ctc_loss_issue'
Browse files Browse the repository at this point in the history
  • Loading branch information
ku21fan committed Aug 7, 2019
1 parent 8f7255f commit 1c6efa5
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 2 deletions.
5 changes: 5 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,12 @@ def validation(model, criterion, evaluation_loader, converter, opt):
# Calculate evaluation loss for CTC decoder.
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
preds = preds.permute(1, 0, 2) # to use CTCloss format

# To avoid a ctc_loss issue, cudnn is disabled while the ctc_loss is computed
# https://github.com/jpuigcerver/PyLaia/issues/16
torch.backends.cudnn.enabled = False
cost = criterion(preds, text_for_loss, preds_size, length_for_loss)
torch.backends.cudnn.enabled = True

# Select max probability (greedy decoding), then decode index to character
_, preds_index = preds.max(2)
Expand Down
7 changes: 6 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,14 @@ def train(opt):

if 'CTC' in opt.Prediction:
preds = model(image, text).log_softmax(2)
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
preds_size = torch.IntTensor([preds.size(1)] * batch_size).to(device)
preds = preds.permute(1, 0, 2) # to use CTCLoss format

# To avoid a ctc_loss issue, cudnn is disabled while the ctc_loss is computed
# https://github.com/jpuigcerver/PyLaia/issues/16
torch.backends.cudnn.enabled = False
cost = criterion(preds, text, preds_size, length)
torch.backends.cudnn.enabled = True

else:
preds = model(image, text[:, :-1]) # align with Attention.forward
Expand Down
2 changes: 1 addition & 1 deletion utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def encode(self, text, batch_max_length=25):
text = ''.join(text)
text = [self.dict[char] for char in text]

return (torch.IntTensor(text), torch.IntTensor(length))
return (torch.IntTensor(text).to(device), torch.IntTensor(length).to(device))

def decode(self, text_index, length):
""" convert text-index into text-label. """
Expand Down

0 comments on commit 1c6efa5

Please sign in to comment.