Skip to content

Commit

Permalink
Merge pull request clovaai#137 from akarazniewicz/feature/torch_1.4.0…
Browse files Browse the repository at this point in the history
…_grid_sample_breking_change

Handling breaking change in pytorch grid_sample
  • Loading branch information
ku21fan authored Jan 16, 2020
2 parents 24749cd + c529003 commit 382a9b5
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions modules/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import torch.nn.functional as F
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class TPS_SpatialTransformerNetwork(nn.Module):
""" Rectification Network of RARE, namely TPS based STN """

Expand All @@ -30,7 +29,11 @@ def forward(self, batch_I):
batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2
build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) # batch_size x n (= I_r_width x I_r_height) x 2
build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2])
batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border')

if torch.__version__ > "1.2.0":
batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border', align_corner=True)
else:
batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border')

return batch_I_r

Expand Down

0 comments on commit 382a9b5

Please sign in to comment.