@@ -39,14 +39,15 @@ def __init__(self):
39
39
super (NNLM , self ).__init__ ()
40
40
41
41
self .H = nn .Parameter (torch .randn (n_step * n_class , n_hidden ).type (dtype ))
42
+ self .W = nn .Parameter (torch .randn (n_step * n_class , n_class ).type (dtype ))
42
43
self .d = nn .Parameter (torch .randn (n_hidden ).type (dtype ))
43
44
self .U = nn .Parameter (torch .randn (n_hidden , n_class ).type (dtype ))
44
45
self .b = nn .Parameter (torch .randn (n_class ).type (dtype ))
45
46
46
47
def forward (self , X ):
47
48
input = X .view (- 1 , n_step * n_class ) # [batch_size, n_step * n_class]
48
49
tanh = nn .functional .tanh (self .d + torch .mm (input , self .H )) # [batch_size, n_hidden]
49
- output = torch .mm (tanh , self .U ) + self .b # [batch_size, n_class]
50
+ output = self . b + torch .mm (input , self .W ) + torch . mm ( tanh , self .U ) # [batch_size, n_class]
50
51
return output
51
52
52
53
model = NNLM ()
@@ -76,4 +77,4 @@ def forward(self, X):
76
77
predict = model (input_batch ).data .max (1 , keepdim = True )[1 ]
77
78
78
79
# Test
79
- print ([sen .split ()[:2 ] for sen in sentences ], '->' , [number_dict [n .item ()] for n in predict .squeeze ()])
80
+ print ([sen .split ()[:2 ] for sen in sentences ], '->' , [number_dict [n .item ()] for n in predict .squeeze ()])
0 commit comments