Skip to content

Commit

Permalink
Merge pull request graykode#34 from lvyilin/master
Browse files Browse the repository at this point in the history
  • Loading branch information
graykode authored Aug 13, 2020
2 parents cb4881e + e785aa3 commit 443c6e7
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions 2-1.TextCNN/TextCNN-Torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,19 @@ def __init__(self):
self.W = nn.Parameter(torch.empty(vocab_size, embedding_size).uniform_(-1, 1)).type(dtype)
self.Weight = nn.Parameter(torch.empty(self.num_filters_total, num_classes).uniform_(-1, 1)).type(dtype)
self.Bias = nn.Parameter(0.1 * torch.ones([num_classes])).type(dtype)
self.filter_list = nn.ModuleList(
[nn.Conv2d(1, num_filters, (size, embedding_size), bias=True) for size in filter_sizes])

def forward(self, X):
embedded_chars = self.W[X] # [batch_size, sequence_length, sequence_length]
embedded_chars = embedded_chars.unsqueeze(1) # add channel(=1) [batch, channel(=1), sequence_length, embedding_size]

pooled_outputs = []
for filter_size in filter_sizes:
for i, conv in enumerate(self.filter_list):
# conv : [input_channel(=1), output_channel(=3), (filter_height, filter_width), bias_option]
conv = nn.Conv2d(1, num_filters, (filter_size, embedding_size), bias=True)(embedded_chars)
h = F.relu(conv)
h = F.relu(conv(embedded_chars))
# mp : ((filter_height, filter_width))
mp = nn.MaxPool2d((sequence_length - filter_size + 1, 1))
mp = nn.MaxPool2d((sequence_length - filter_sizes[i] + 1, 1))
# pooled : [batch_size(=6), output_height(=1), output_width(=1), output_channel(=3)]
pooled = mp(h).permute(0, 3, 2, 1)
pooled_outputs.append(pooled)
Expand Down

0 comments on commit 443c6e7

Please sign in to comment.