{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "TextLSTM-Tensor.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "metadata": {
        "id": "HggP9sCLZUrt",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 269
        },
        "outputId": "ee905d35-c761-4ce1-bd06-f0313843b4a9"
      },
      "cell_type": "code",
      "source": [
        "'''\n",
        "  code by Tae Hwan Jung(Jeff Jung) @graykode\n",
        "'''\n",
        "import tensorflow as tf\n",
        "import numpy as np\n",
        "\n",
        "tf.reset_default_graph()\n",
        "\n",
        "char_arr = [c for c in 'abcdefghijklmnopqrstuvwxyz']\n",
        "word_dict = {n: i for i, n in enumerate(char_arr)}\n",
        "number_dict = {i: w for i, w in enumerate(char_arr)}\n",
        "n_class = len(word_dict) # number of class(=number of vocab)\n",
        "\n",
        "seq_data = ['make', 'need', 'coal', 'word', 'love', 'hate', 'live', 'home', 'hash', 'star']\n",
        "\n",
        "# TextLSTM Parameters\n",
        "n_step = 3\n",
        "n_hidden = 128\n",
        "\n",
        "def make_batch(seq_data):\n",
        "    input_batch, target_batch = [], []\n",
        "\n",
        "    for seq in seq_data:\n",
        "        input = [word_dict[n] for n in seq[:-1]] # 'm', 'a' , 'k' is input\n",
        "        target = word_dict[seq[-1]] # 'e' is target\n",
        "        input_batch.append(np.eye(n_class)[input])\n",
        "        target_batch.append(np.eye(n_class)[target])\n",
        "\n",
        "    return input_batch, target_batch\n",
        "\n",
        "# Model\n",
        "X = tf.placeholder(tf.float32, [None, n_step, n_class]) # [batch_size, n_step, n_class]\n",
        "Y = tf.placeholder(tf.float32, [None, n_class])         # [batch_size, n_class]\n",
        "\n",
        "W = tf.Variable(tf.random_normal([n_hidden, n_class]))\n",
        "b = tf.Variable(tf.random_normal([n_class]))\n",
        "\n",
        "cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden)\n",
        "outputs, states = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)\n",
        "\n",
        "# outputs : [batch_size, n_step, n_hidden]\n",
        "outputs = tf.transpose(outputs, [1, 0, 2]) # [n_step, batch_size, n_hidden]\n",
        "outputs = outputs[-1] # [batch_size, n_hidden]\n",
        "model = tf.matmul(outputs, W) + b # model : [batch_size, n_class]\n",
        "\n",
        "cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=model, labels=Y))\n",
        "optimizer = tf.train.AdamOptimizer(0.001).minimize(cost)\n",
        "\n",
        "prediction = tf.cast(tf.argmax(model, 1), tf.int32)\n",
        "\n",
        "# Training\n",
        "init = tf.global_variables_initializer()\n",
        "sess = tf.Session()\n",
        "sess.run(init)\n",
        "\n",
        "input_batch, target_batch = make_batch(seq_data)\n",
        "\n",
        "for epoch in range(1000):\n",
        "    _, loss = sess.run([optimizer, cost], feed_dict={X: input_batch, Y: target_batch})\n",
        "    if (epoch + 1)%100 == 0:\n",
        "        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))\n",
        "\n",
        "inputs = [sen[:3] for sen in seq_data]\n",
        "\n",
        "predict =  sess.run([prediction], feed_dict={X: input_batch})\n",
        "print(inputs, '->', [number_dict[n] for n in predict[0]])"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "WARNING:tensorflow:From <ipython-input-1-22a40a8beb7a>:38: BasicLSTMCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.\n",
            "Instructions for updating:\n",
            "This class is deprecated, please use tf.nn.rnn_cell.LSTMCell, which supports all the feature this cell currently has. Please replace the existing code with tf.nn.rnn_cell.LSTMCell(name='basic_lstm_cell').\n",
            "Epoch: 0100 cost = 0.017741\n",
            "Epoch: 0200 cost = 0.004202\n",
            "Epoch: 0300 cost = 0.001867\n",
            "Epoch: 0400 cost = 0.001056\n",
            "Epoch: 0500 cost = 0.000681\n",
            "Epoch: 0600 cost = 0.000476\n",
            "Epoch: 0700 cost = 0.000352\n",
            "Epoch: 0800 cost = 0.000271\n",
            "Epoch: 0900 cost = 0.000215\n",
            "Epoch: 1000 cost = 0.000174\n",
            "['mak', 'nee', 'coa', 'wor', 'lov', 'hat', 'liv', 'hom', 'has', 'sta'] -> ['e', 'd', 'l', 'd', 'e', 'e', 'e', 'e', 'h', 'r']\n"
          ],
          "name": "stdout"
        }
      ]
    }
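    ,
    {
      "metadata": {
        "id": "predict-new-prefix-note",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "A minimal follow-up sketch, assuming the session `sess`, the `prediction` op, and the dictionaries from the cell above are still in scope: one-hot encode an arbitrary three-letter prefix and run it through the trained graph. The helper name `predict_last_char` is illustrative, not part of the original code."
      ]
    },
    {
      "metadata": {
        "id": "predict-new-prefix-code",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Hypothetical helper; assumes sess, prediction, X, word_dict, number_dict,\n",
        "# n_class and np from the training cell above are still defined.\n",
        "def predict_last_char(prefix):\n",
        "    idx = [word_dict[c] for c in prefix]              # character indices\n",
        "    one_hot = np.eye(n_class)[idx][np.newaxis, :, :]  # [1, n_step, n_class]\n",
        "    pred = sess.run(prediction, feed_dict={X: one_hot})\n",
        "    return number_dict[pred[0]]\n",
        "\n",
        "print('mak ->', predict_last_char('mak'))"
      ],
      "execution_count": null,
      "outputs": []
    }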
  ]
}