
Commit 1db491f

committed
add graykode#10 ipynb file
1 parent 52c4514 commit 1db491f

File tree

1 file changed: +20 -26 lines changed


1-1.NNLM/NNLM_Torch.ipynb (+20, -26)
@@ -19,10 +19,10 @@
 "metadata": {
 "id": "mvlw9p3tPJjr",
 "colab_type": "code",
-"outputId": "a9d7624b-4a3b-4078-9a89-11c2d6d177d5",
+"outputId": "2a4a8f52-315e-42b3-9d49-e9c7c3358979",
 "colab": {
 "base_uri": "https://localhost:8080/",
-"height": 161
+"height": 129
 }
 },
 "cell_type": "code",
@@ -45,8 +45,9 @@
 "n_class = len(word_dict) # number of Vocabulary\n",
 "\n",
 "# NNLM Parameter\n",
-"n_step = 2 # n-1\n",
-"n_hidden = 2 # h\n",
+"n_step = 2 # n-1 in paper\n",
+"n_hidden = 2 # h in paper\n",
+"m = 2 # m in paper\n",
 "\n",
 "def make_batch(sentences):\n",
 "    input_batch = []\n",
@@ -57,7 +58,7 @@
 "        input = [word_dict[n] for n in word[:-1]]\n",
 "        target = word_dict[word[-1]]\n",
 "\n",
-"        input_batch.append(np.eye(n_class)[input])\n",
+"        input_batch.append(input)\n",
 "        target_batch.append(target)\n",
 "\n",
 "    return input_batch, target_batch\n",
@@ -66,17 +67,18 @@
 "class NNLM(nn.Module):\n",
 "    def __init__(self):\n",
 "        super(NNLM, self).__init__()\n",
-"\n",
-"        self.H = nn.Parameter(torch.randn(n_step * n_class, n_hidden).type(dtype))\n",
-"        self.W = nn.Parameter(torch.randn(n_step * n_class, n_class).type(dtype))\n",
+"        self.C = nn.Embedding(n_class, m)\n",
+"        self.H = nn.Parameter(torch.randn(n_step * m, n_hidden).type(dtype))\n",
+"        self.W = nn.Parameter(torch.randn(n_step * m, n_class).type(dtype))\n",
 "        self.d = nn.Parameter(torch.randn(n_hidden).type(dtype))\n",
 "        self.U = nn.Parameter(torch.randn(n_hidden, n_class).type(dtype))\n",
 "        self.b = nn.Parameter(torch.randn(n_class).type(dtype))\n",
 "\n",
 "    def forward(self, X):\n",
-"        input = X.view(-1, n_step * n_class) # [batch_size, n_step * n_class]\n",
-"        tanh = nn.functional.tanh(self.d + torch.mm(input, self.H)) # [batch_size, n_hidden]\n",
-"        output = self.b + torch.mm(input, self.W) + torch.mm(tanh, self.U) # [batch_size, n_class]\n",
+"        X = self.C(X)\n",
+"        X = X.view(-1, n_step * m) # [batch_size, n_step * n_class]\n",
+"        tanh = torch.tanh(self.d + torch.mm(X, self.H)) # [batch_size, n_hidden]\n",
+"        output = self.b + torch.mm(X, self.W) + torch.mm(tanh, self.U) # [batch_size, n_class]\n",
 "        return output\n",
 "\n",
 "model = NNLM()\n",
@@ -85,7 +87,7 @@
 "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
 "\n",
 "input_batch, target_batch = make_batch(sentences)\n",
-"input_batch = Variable(torch.Tensor(input_batch))\n",
+"input_batch = Variable(torch.LongTensor(input_batch))\n",
 "target_batch = Variable(torch.LongTensor(target_batch))\n",
 "\n",
 "# Training\n",
@@ -108,24 +110,16 @@
 "# Test\n",
 "print([sen.split()[:2] for sen in sentences], '->', [number_dict[n.item()] for n in predict.squeeze()])"
 ],
-"execution_count": 0,
+"execution_count": 1,
 "outputs": [
 {
 "output_type": "stream",
 "text": [
-"/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1320: UserWarning: nn.functional.tanh is deprecated. Use torch.tanh instead.\n",
-"  warnings.warn(\"nn.functional.tanh is deprecated. Use torch.tanh instead.\")\n"
-],
-"name": "stderr"
-},
-{
-"output_type": "stream",
-"text": [
-"Epoch: 1000 cost = 0.283353\n",
-"Epoch: 2000 cost = 0.058013\n",
-"Epoch: 3000 cost = 0.023128\n",
-"Epoch: 4000 cost = 0.011383\n",
-"Epoch: 5000 cost = 0.006090\n",
+"Epoch: 1000 cost = 0.147408\n",
+"Epoch: 2000 cost = 0.026562\n",
+"Epoch: 3000 cost = 0.010481\n",
+"Epoch: 4000 cost = 0.005095\n",
+"Epoch: 5000 cost = 0.002696\n",
 "[['i', 'like'], ['i', 'love'], ['i', 'hate']] -> ['dog', 'coffee', 'milk']\n"
 ],
 "name": "stdout"

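For reference, the change above swaps the one-hot context encoding (np.eye(n_class)[input]) for an nn.Embedding lookup, so make_batch now returns plain word indices and H and W shrink from n_step * n_class rows to n_step * m rows. Below is a minimal standalone sketch of what the full cell amounts to after this commit. The corpus, dictionary setup, training loop, and test code are not shown in this diff, so they are reconstructed here from the visible fragments and the printed output; the deprecated Variable wrapper and the .type(dtype) casts (dtype is defined elsewhere in the notebook) are dropped so the sketch runs as-is on current PyTorch.

# Hypothetical standalone reconstruction of the NNLM cell after this commit;
# the corpus and training loop are assumed from the notebook's printed output,
# not taken from the diff itself.
import torch
import torch.nn as nn
import torch.optim as optim

sentences = ["i like dog", "i love coffee", "i hate milk"]  # assumed toy corpus

word_list = list(set(" ".join(sentences).split()))
word_dict = {w: i for i, w in enumerate(word_list)}
number_dict = {i: w for i, w in enumerate(word_list)}
n_class = len(word_dict)  # number of Vocabulary

# NNLM Parameter
n_step = 2    # n-1 in paper
n_hidden = 2  # h in paper
m = 2         # m in paper

def make_batch(sentences):
    # After this commit the batch holds word indices, not one-hot rows,
    # because the model looks them up with nn.Embedding.
    input_batch, target_batch = [], []
    for sen in sentences:
        word = sen.split()
        input_batch.append([word_dict[n] for n in word[:-1]])
        target_batch.append(word_dict[word[-1]])
    return input_batch, target_batch

class NNLM(nn.Module):
    def __init__(self):
        super(NNLM, self).__init__()
        self.C = nn.Embedding(n_class, m)  # word feature matrix C
        self.H = nn.Parameter(torch.randn(n_step * m, n_hidden))
        self.W = nn.Parameter(torch.randn(n_step * m, n_class))
        self.d = nn.Parameter(torch.randn(n_hidden))
        self.U = nn.Parameter(torch.randn(n_hidden, n_class))
        self.b = nn.Parameter(torch.randn(n_class))

    def forward(self, X):
        X = self.C(X)               # [batch_size, n_step, m]
        X = X.view(-1, n_step * m)  # [batch_size, n_step * m]
        tanh = torch.tanh(self.d + torch.mm(X, self.H))                # [batch_size, n_hidden]
        return self.b + torch.mm(X, self.W) + torch.mm(tanh, self.U)   # [batch_size, n_class]

model = NNLM()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

input_batch, target_batch = make_batch(sentences)
input_batch = torch.LongTensor(input_batch)    # indices for nn.Embedding
target_batch = torch.LongTensor(target_batch)

# Training
for epoch in range(5000):
    optimizer.zero_grad()
    output = model(input_batch)
    loss = criterion(output, target_batch)
    if (epoch + 1) % 1000 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.item()))
    loss.backward()
    optimizer.step()

# Test
predict = model(input_batch).data.max(1, keepdim=True)[1]
print([sen.split()[:2] for sen in sentences], '->', [number_dict[n.item()] for n in predict.squeeze()])

Running this prints the same kind of epoch costs and the [['i', 'like'], ['i', 'love'], ['i', 'hate']] -> ['dog', 'coffee', 'milk'] prediction shown in the cell output above; exact cost values differ from run to run because of random initialization.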