Skip to content

Commit c0674ee

Browse files
authored
Merge pull request graykode#4 from likejazz/patch-1
Add `W` parameter as written in paper.
2 parents 0570614 + 5027be2 commit c0674ee

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

1-1.NNLM/NNLM-Torch.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,15 @@ def __init__(self):
3939
super(NNLM, self).__init__()
4040

4141
self.H = nn.Parameter(torch.randn(n_step * n_class, n_hidden).type(dtype))
42+
self.W = nn.Parameter(torch.randn(n_step * n_class, n_class).type(dtype))
4243
self.d = nn.Parameter(torch.randn(n_hidden).type(dtype))
4344
self.U = nn.Parameter(torch.randn(n_hidden, n_class).type(dtype))
4445
self.b = nn.Parameter(torch.randn(n_class).type(dtype))
4546

4647
def forward(self, X):
4748
input = X.view(-1, n_step * n_class) # [batch_size, n_step * n_class]
4849
tanh = nn.functional.tanh(self.d + torch.mm(input, self.H)) # [batch_size, n_hidden]
49-
output = torch.mm(tanh, self.U) + self.b # [batch_size, n_class]
50+
output = self.b + torch.mm(input, self.W) + torch.mm(tanh, self.U) # [batch_size, n_class]
5051
return output
5152

5253
model = NNLM()
@@ -76,4 +77,4 @@ def forward(self, X):
7677
predict = model(input_batch).data.max(1, keepdim=True)[1]
7778

7879
# Test
79-
print([sen.split()[:2] for sen in sentences], '->', [number_dict[n.item()] for n in predict.squeeze()])
80+
print([sen.split()[:2] for sen in sentences], '->', [number_dict[n.item()] for n in predict.squeeze()])

0 commit comments

Comments
 (0)