|
19 | 19 | "metadata": {
|
20 | 20 | "id": "mvlw9p3tPJjr",
|
21 | 21 | "colab_type": "code",
|
22 |
| - "outputId": "a9d7624b-4a3b-4078-9a89-11c2d6d177d5", |
| 22 | + "outputId": "2a4a8f52-315e-42b3-9d49-e9c7c3358979", |
23 | 23 | "colab": {
|
24 | 24 | "base_uri": "https://localhost:8080/",
|
25 |
| - "height": 161 |
| 25 | + "height": 129 |
26 | 26 | }
|
27 | 27 | },
|
28 | 28 | "cell_type": "code",
|
|
45 | 45 | "n_class = len(word_dict) # number of Vocabulary\n",
|
46 | 46 | "\n",
|
47 | 47 | "# NNLM Parameter\n",
|
48 |
| - "n_step = 2 # n-1\n", |
49 |
| - "n_hidden = 2 # h\n", |
| 48 | + "n_step = 2 # n-1 in paper\n", |
| 49 | + "n_hidden = 2 # h in paper\n", |
| 50 | + "m = 2 # m in paper\n", |
50 | 51 | "\n",
|
51 | 52 | "def make_batch(sentences):\n",
|
52 | 53 | " input_batch = []\n",
|
|
57 | 58 | " input = [word_dict[n] for n in word[:-1]]\n",
|
58 | 59 | " target = word_dict[word[-1]]\n",
|
59 | 60 | "\n",
|
60 |
| - " input_batch.append(np.eye(n_class)[input])\n", |
| 61 | + " input_batch.append(input)\n", |
61 | 62 | " target_batch.append(target)\n",
|
62 | 63 | "\n",
|
63 | 64 | " return input_batch, target_batch\n",
|
|
66 | 67 | "class NNLM(nn.Module):\n",
|
67 | 68 | " def __init__(self):\n",
|
68 | 69 | " super(NNLM, self).__init__()\n",
|
69 |
| - "\n", |
70 |
| - " self.H = nn.Parameter(torch.randn(n_step * n_class, n_hidden).type(dtype))\n", |
71 |
| - " self.W = nn.Parameter(torch.randn(n_step * n_class, n_class).type(dtype))\n", |
| 70 | + " self.C = nn.Embedding(n_class, m)\n", |
| 71 | + " self.H = nn.Parameter(torch.randn(n_step * m, n_hidden).type(dtype))\n", |
| 72 | + " self.W = nn.Parameter(torch.randn(n_step * m, n_class).type(dtype))\n", |
72 | 73 | " self.d = nn.Parameter(torch.randn(n_hidden).type(dtype))\n",
|
73 | 74 | " self.U = nn.Parameter(torch.randn(n_hidden, n_class).type(dtype))\n",
|
74 | 75 | " self.b = nn.Parameter(torch.randn(n_class).type(dtype))\n",
|
75 | 76 | "\n",
|
76 | 77 | " def forward(self, X):\n",
|
77 |
| - " input = X.view(-1, n_step * n_class) # [batch_size, n_step * n_class]\n", |
78 |
| - " tanh = nn.functional.tanh(self.d + torch.mm(input, self.H)) # [batch_size, n_hidden]\n", |
79 |
| - " output = self.b + torch.mm(input, self.W) + torch.mm(tanh, self.U) # [batch_size, n_class]\n", |
| 78 | + " X = self.C(X)\n", |
| 79 | + " X = X.view(-1, n_step * m) # [batch_size, n_step * m]\n", |
| 80 | + " tanh = torch.tanh(self.d + torch.mm(X, self.H)) # [batch_size, n_hidden]\n", |
| 81 | + " output = self.b + torch.mm(X, self.W) + torch.mm(tanh, self.U) # [batch_size, n_class]\n", |
80 | 82 | " return output\n",
|
81 | 83 | "\n",
|
82 | 84 | "model = NNLM()\n",
|
|
85 | 87 | "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
|
86 | 88 | "\n",
|
87 | 89 | "input_batch, target_batch = make_batch(sentences)\n",
|
88 |
| - "input_batch = Variable(torch.Tensor(input_batch))\n", |
| 90 | + "input_batch = Variable(torch.LongTensor(input_batch))\n", |
89 | 91 | "target_batch = Variable(torch.LongTensor(target_batch))\n",
|
90 | 92 | "\n",
|
91 | 93 | "# Training\n",
|
|
108 | 110 | "# Test\n",
|
109 | 111 | "print([sen.split()[:2] for sen in sentences], '->', [number_dict[n.item()] for n in predict.squeeze()])"
|
110 | 112 | ],
|
111 |
| - "execution_count": 0, |
| 113 | + "execution_count": 1, |
112 | 114 | "outputs": [
|
113 | 115 | {
|
114 | 116 | "output_type": "stream",
|
115 | 117 | "text": [
|
116 |
| - "/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1320: UserWarning: nn.functional.tanh is deprecated. Use torch.tanh instead.\n", |
117 |
| - " warnings.warn(\"nn.functional.tanh is deprecated. Use torch.tanh instead.\")\n" |
118 |
| - ], |
119 |
| - "name": "stderr" |
120 |
| - }, |
121 |
| - { |
122 |
| - "output_type": "stream", |
123 |
| - "text": [ |
124 |
| - "Epoch: 1000 cost = 0.283353\n", |
125 |
| - "Epoch: 2000 cost = 0.058013\n", |
126 |
| - "Epoch: 3000 cost = 0.023128\n", |
127 |
| - "Epoch: 4000 cost = 0.011383\n", |
128 |
| - "Epoch: 5000 cost = 0.006090\n", |
| 118 | + "Epoch: 1000 cost = 0.147408\n", |
| 119 | + "Epoch: 2000 cost = 0.026562\n", |
| 120 | + "Epoch: 3000 cost = 0.010481\n", |
| 121 | + "Epoch: 4000 cost = 0.005095\n", |
| 122 | + "Epoch: 5000 cost = 0.002696\n", |
129 | 123 | "[['i', 'like'], ['i', 'love'], ['i', 'hate']] -> ['dog', 'coffee', 'milk']\n"
|
130 | 124 | ],
|
131 | 125 | "name": "stdout"
|
|
0 commit comments