Skip to content

Commit

Permalink
fix minor
Browse files Browse the repository at this point in the history
  • Loading branch information
Baek JeongHun committed Apr 7, 2019
1 parent e4249be commit 34e6316
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 8 deletions.
4 changes: 2 additions & 2 deletions model.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(self, imgH, imgW, input_channel, output_channel, hidden_size, num_c
else:
raise Exception('Prediction is neither CTC or Attn')

def forward(self, input, length, text, is_train=True):
def forward(self, input, text, is_train=True):
""" Transformation stage """
if not self.stages['Trans'] == "None":
input = self.Transformation(input)
Expand All @@ -87,6 +87,6 @@ def forward(self, input, length, text, is_train=True):
if self.stages['Pred'] == 'CTC':
prediction = self.Prediction(contextual_feature.contiguous())
else:
prediction = self.Prediction(contextual_feature.contiguous(), length, text, is_train)
prediction = self.Prediction(contextual_feature.contiguous(), text, is_train)

return prediction
3 changes: 1 addition & 2 deletions modules/prediction.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@ def _char_to_onehot(self, input_char, onehot_dim=38):
one_hot = one_hot.scatter_(1, input_char, 1)
return one_hot

def forward(self, batch_H, length, text, is_train=True, batch_max_length=25):
def forward(self, batch_H, text, is_train=True, batch_max_length=25):
"""
input:
batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x num_classes]
length : the length of each label. train: [3, 7, ....], test: [25, 25, 25, ...] [batch_size]
text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO].
output: probability distribution at each step [batch_size x num_steps x num_classes]
"""
Expand Down
4 changes: 2 additions & 2 deletions test.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def validation(model, criterion, evaluation_loader, converter, opt):

start_time = time.time()
if 'CTC' in opt.Prediction:
preds = model(image, length_for_pred, text_for_pred)
preds = model(image, text_for_pred)
forward_time = time.time() - start_time

# Calculate evaluation loss for CTC deocder.
Expand All @@ -99,7 +99,7 @@ def validation(model, criterion, evaluation_loader, converter, opt):
sim_preds = converter.decode(preds.data, preds_size.data)

else:
preds = model(image, length_for_pred, text_for_pred, is_train=False)
preds = model(image, text_for_pred, is_train=False)
forward_time = time.time() - start_time

preds = preds[:, :text_for_loss.shape[1] - 1, :]
Expand Down
4 changes: 2 additions & 2 deletions train.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,13 @@ def train(opt):
batch_size = image.size(0)

if 'CTC' in opt.Prediction:
preds = model(image, length, text)
preds = model(image, text)
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
preds = preds.permute(1, 0, 2) # to use CTCLoss format
cost = criterion(preds, text, preds_size, length) / batch_size

else:
preds = model(image, length, text)
preds = model(image, text)
target = text[:, 1:] # without [GO] Symbol
cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

Expand Down

0 comments on commit 34e6316

Please sign in to comment.