{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "STAT 479: Deep Learning (Spring 2019)  \n",
    "Instructor: Sebastian Raschka (sraschka@wisc.edu)  \n",
    "Course website: http://pages.stat.wisc.edu/~sraschka/teaching/stat479-ss2019/  \n",
    "GitHub repository: https://github.com/rasbt/stat479-deep-learning-ss19"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.7.1\n",
      "IPython 7.2.0\n",
      "\n",
      "torch 1.0.1\n",
      "pandas 0.24.0\n",
      "matplotlib 3.0.2\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch,pandas,matplotlib"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ADALINE with Stochastic Gradient Descent (Minibatch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<br>\n",
    "<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load & Prepare a Toy Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x1</th>\n",
       "      <th>x2</th>\n",
       "      <th>x3</th>\n",
       "      <th>x4</th>\n",
       "      <th>y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>145</th>\n",
       "      <td>6.7</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.2</td>\n",
       "      <td>2.3</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>146</th>\n",
       "      <td>6.3</td>\n",
       "      <td>2.5</td>\n",
       "      <td>5.0</td>\n",
       "      <td>1.9</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>147</th>\n",
       "      <td>6.5</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.2</td>\n",
       "      <td>2.0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>148</th>\n",
       "      <td>6.2</td>\n",
       "      <td>3.4</td>\n",
       "      <td>5.4</td>\n",
       "      <td>2.3</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>149</th>\n",
       "      <td>5.9</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.1</td>\n",
       "      <td>1.8</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      x1   x2   x3   x4  y\n",
       "145  6.7  3.0  5.2  2.3  1\n",
       "146  6.3  2.5  5.0  1.9  1\n",
       "147  6.5  3.0  5.2  2.0  1\n",
       "148  6.2  3.4  5.4  2.3  1\n",
       "149  5.9  3.0  5.1  1.8  1"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv('./datasets/iris.data', index_col=None, header=None)\n",
    "df.columns = ['x1', 'x2', 'x3', 'x4', 'y']\n",
    "df = df.iloc[50:150]\n",
    "df['y'] = df['y'].apply(lambda x: 0 if x == 'Iris-versicolor' else 1)\n",
    "df.tail()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assign features and target\n",
    "\n",
    "X = torch.tensor(df[['x2', 'x3']].values, dtype=torch.float)\n",
    "y = torch.tensor(df['y'].values, dtype=torch.int)\n",
    "\n",
    "# Shuffling & train/test split\n",
    "\n",
    "torch.manual_seed(123)\n",
    "shuffle_idx = torch.randperm(y.size(0), dtype=torch.long)\n",
    "\n",
    "X, y = X[shuffle_idx], y[shuffle_idx]\n",
    "\n",
    "percent70 = int(shuffle_idx.size(0)*0.7)\n",
    "\n",
    "X_train, X_test = X[shuffle_idx[:percent70]], X[shuffle_idx[percent70:]]\n",
    "y_train, y_test = y[shuffle_idx[:percent70]], y[shuffle_idx[percent70:]]\n",
    "\n",
    "# Normalize (mean zero, unit variance)\n",
    "\n",
    "mu, sigma = X_train.mean(dim=0), X_train.std(dim=0)\n",
    "X_train = (X_train - mu) / sigma\n",
    "X_test = (X_test - mu) / sigma"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAD8CAYAAABjAo9vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAFpFJREFUeJzt3X+MXXWZx/HP02FIJ6Gh0mli6W2ZUkgFp50SR0rTCBurKRh+iUpAA6JiXSOUFbfapgRqBcXUYGIw1pKqIdsFSsQRgXVE6y5ZshBby/ZHSi0l1s5AIg5bIOkQptNn/7gznU5nOnPvnHPv+Z7veb8SMtwzt9/7fM+9PNx+n/N8j7m7AADxmJR1AACAdJHYASAyJHYAiAyJHQAiQ2IHgMiQ2AEgMiR2AIgMiR0AIkNiB4DInJbFizY3N3tLS0sWLw0AubV9+/Z/uPv08Z6XSWJvaWnRtm3bsnhpAMgtMztYyfNYigGAyJDYASAyJHYAiEwma+yj6evrU1dXl959992sQwnG5MmTVSqV1NjYmHUoAHIkmMTe1dWlKVOmqKWlRWaWdTiZc3f19PSoq6tLc+bMyTocADkSzFLMu+++q2nTppHUB5iZpk2bxt9gAFQtmMQuiaR+Es4HgIkIKrEjZ3ZukX7YKq2dWv65c0vWEQEQiX1ca9eu1Q9+8IOajL19+3bNnz9f5513nlasWKFc3X925xbpNyuktw5J8vLP36wguQMBILFn6Ktf/ao2btyo/fv3a//+/frtb3+bdUiV+8M6qa93+LG+3vJxAJnKbWLv2NGtJfdv1ZxVT2vJ/VvVsaM78ZgPP/ywFixYoLa2Nt10000jfv/QQw/pwx/+sNra2vSpT31KR44ckSQ9/vjjam1tVVtbmy699FJJ0p49e3TxxRdr4cKFWrBggfbv3z9srNdff11vv/22Fi9eLDPTzTffrI6OjsRzqJu3uqo7DqBugrncsRodO7q1+old6u3rlyR1H+7V6id2SZKuvWjmhMbcs2eP7rvvPj3//PNqbm7Wm2++OeI51113nb785S9Lku666y5t2rRJt99+u9atW6fOzk7NnDlThw8fliRt2LBBd9xxhz73uc/pvffeU39//7Cxuru7VSqVjj8ulUrq7k7+P6e6ObM0sAwzynEAmcrlN/b1nfuOJ/VBvX39Wt+5b8Jjbt26VZ/+9KfV3NwsSTrrrLNGPGf37t36yEc+ovnz52vz5s3as2ePJGnJkiW65ZZb9NBDDx1P4IsXL9Z3v/tdff/739fBgwfV1NQ0bKzR1tNzdRXM0rulxuFzUmNT+TiATOUysb92uLeq45Vw93ET6y233KIHH3xQu3bt0j333HP8GvMNGzbo3nvv1aFDh7Rw4UL19PTos5/9rJ588kk1NTVp2bJl2rp167CxSqWSurqGli26urp09tlnTzj+ultwvXTVj6QzZ0my8s+rflQ+DiBTuUzsZ09tqup4JZYuXaotW7aop6dHkkZdinnnnXc0Y8YM9fX1afPmzcePHzhwQIsWLdK6devU3NysQ4cO6dVXX9W5556rFStW6Oqrr9bOnTuHjTVjxgxNmTJFL7zwgtxdDz/8sK655poJx5+JBddLX98trT1c/klSB4KQy8S+ctk8NTU2DDvW1NiglcvmTXjMD37wg1qzZo0uu+wytbW16c477xzxnO985ztatGiRPv7xj+sDH/jAUDwrV2r+/PlqbW3VpZdeqra2Nj322GNqbW3VwoUL9fLLL+vmm28eMd5PfvIT3XrrrTrvvPM0d+5cXXHFFROOHwAGWRbXTre3t/vJN9rYu3evLrjggorH6NjRrfWd+/Ta4V6dPbVJK5fNm3DhNGTVnhcA8TKz7e7ePt7zcnlVjFS++iXGRI4J2LmlfP38W13lq3KW3s2yEAott4kdkDTUATvYLDXYASuR3FFYuVxjB46jAxYYgcSOfKMDFhiB
xI58O1WnKx2wKDASO/KNDlhgBBL7OGq5be+aNWs0a9YsnXHGGTUZvxDogAVG4KqYDF111VW67bbbdP7552cdSr4tuJ5EDpwgv9/Ya3D3nnpu2ytJl1xyiWbMmJE4bgA4UT6/sdfg2uV6b9sLALWS+Bu7mc0ysz+a2V4z22Nmd6QR2JhqcO1yvbftBYBaSWMp5qikb7j7BZIukfQ1M7swhXFPrQbXLtd7214AqJXEid3dX3f3Pw/8+zuS9kqq7SYuNbh2ud7b9gJAraRaPDWzFkkXSXoxzXFHqMG1y1ls2/vNb35TpVJJR44cUalU0tq1ayccPwAMSm3bXjM7Q9J/SbrP3Z8Y5ffLJS2XpNmzZ3/o4MGDw35f9fa0BdnRj217AQyq67a9ZtYo6ZeSNo+W1CXJ3TdK2iiV92NP/KJcuwwgdBl9AU2c2K1ccdwkaa+7P5A8JACIQIZbSqexxr5E0k2SPmpmLw3884mJDJTF3ZxCxvkAcizDLaUTf2N39/+WNPZ1ghWYPHmyenp6NG3atHEvOywCd1dPT48mT56cdSgAJiLDLaWD6TwtlUrq6urSG2+8kXUowZg8ebJKJbafBXLpzFJ5+WW04zUWTGJvbGzUnDlzsg4DRVSQK6xQZ0vvHr7GLtVtS+lgEjuQCe6ZiloZ/Pzk8aoYINfGKnCR2JFURpdl53fbXiAN3DMVESKxo9i4ZyoiRGJHsXHPVESIxI5i456piBDFU4B9hxAZvrEDQGRI7AAQGRI7kIadW6Qftkprp5Z/7tySdUQo8HvCGjuQFN2r4Sn4e8I3diCpDLdnxSkU/D0hsQNJ0b0anoK/JyR2ICm6V8NT8PeExA4kFVL3aoELhsOE9J5kgOIpkFSG27MOU/CC4TChvCcZsSzuq9ne3u7btm2r++sCUfth6ynu2DNL+vru+seD1JnZdndvH+95LMUAsSh4wRBDSOxALApeMMQQEjsQi4IXDDGExA7Egi2IMYCrYoCYsAUxxDd2AIgOiR0AIkNiLyo6FIFoscZeRHQoAlHjG3sRFXxLUyB2JPYiokMRiFoqid3MfmZmfzczNqTIAzoUgail9Y39F5IuT2ks1BodivFKWhR/6k7p22dJa88s/3zqztrEiZpKpXjq7s+ZWUsaY6EOCr6labSSFsWfulPatmnosfcPPb7ygXRjRU2ltm3vQGJ/yt1bx3su2/YCNZB0295vn1VO5iezBumeN5PHh8SC27bXzJab2TYz2/bGG2/U62WB4khaFB8tqY91HMGqW2J3943u3u7u7dOnT6/XywLFkbQobg3VHUewuNwRExdK92rSgl8o80gaR9Ki+Iduqe74qYRyPgssleKpmT0i6Z8kNZtZl6R73H3T2H8KuRZK92rSgl8o80gjjqRF8cHztf0X5fNoDeWkXk3hNJTzWXDc8xQTE8r9NZMW/EKZRyhxJBXLPAIVXPEUkQmlezVpwS+UeYQSR1KxzCPnSOyYmFC6V5MW/EKZRyhxJBXLPHKOxI6JCaV7NWnBL615JC3gLr1bajh9+LGG0/PXDRzK56LgSOyYmFDur3nlA1L7l4a+oVtD+XGlBb805jFYwB1c/hks4Fab3E+ud2VQ/0oslM9FwVE8BZJKo2OToiMqQPEUqJc0OjYpOiJFJHYgqTQ6Nik6IkUkdiCpNDo2l94tTWocfmxSYz6LjnSeZo57ngJJpdGxKUlmYz/OAzpPg0DxFAhBLMXTWOYRKIqnQJ7EUjyNZR45R2IHQhBL8TSWeeQciR3ZotBWllbHZtbnk87TIFA8RXYotA1J4z60IZxP7qcbBIqnyA6FtnRxPqNH8RTho9CWLs4nBpDYkR0KbenifGIAiR3ZianQlnXRUorrfCIRiqfITiyFthCKlie+Vt7PJxKjeAokRdESdULxFKgXipYIDIkdSIqiJQJDYkemDvz8Kzq69n3ye87U0bXv04Gff6X6QbIuXFK0RGBI7MjMgZ9/Ref+9VGdpmMyk07TMZ37
10erS+6Dhcu3DknyocJlPZM79/lEYCieIjNH175Pp+nYyOOapNPW/l9lg1C4RIFQPEXwGnxkUh/r+KgoXAIjkNiRmX4b/eN3quOjonAJjEBir1bWhbq0BDCPg+dcr5NXAt3LxysWSOGyY0e3lty/VXNWPa0l929Vx47uur4+cCISezVCKNSlIZB5zP3CT/Vqyw06qklyL6+tv9pyg+Z+4aeVDxJA4bJjR7dWP7FL3Yd75ZK6D/dq9RO7SO7IDMXTasRSqItlHoFYcv9WdR/uHXF85tQmPb/qoxlEhFjVtXhqZpeb2T4ze8XMVqUxZpBiKdTFMo9AvDZKUh/rOFBriRO7mTVI+rGkKyRdKOlGM7sw6bhBiqVQF8s8AnH21KaqjgO1lsY39oslveLur7r7e5IelXRNCuOGJ5BCXWJL79bRhsnDDh1tmJy/eQzKuBC8ctk8NTU2DDvW1NiglcvmVTUOBVikJY3EPlPSiQu2XQPH4hNAoS4NHf1LtKrvVnUda9YxN3Uda9aqvlvV0b8k69CqF0Ah+NqLZup7183XzKlNMpXX1r933Xxde1Hl/xlQgEWaEhdPzewzkpa5+60Dj2+SdLG7337S85ZLWi5Js2fP/tDBgwcTvS4mLqpiXySF4KjeE9RMPYunXZJmnfC4JOm1k5/k7hvdvd3d26dPn57Cy2Kioir2RVIIjuo9QebSSOx/knS+mc0xs9Ml3SDpyRTGRY1EVeyLpBAc1XuCzCVO7O5+VNJtkjol7ZW0xd33JB0XtZNWsS8IS++WGk4ffqzh9PoXghMWcCnAIk2p3PPU3Z+R9EwaY6H2Bot66zv36bXDvTp7apNWLptXVbEvKKPtS1BPKdzzNI33ZLAA29vXL2moAHvi+CgGOk+RbyEUT0OIQRRgi4Bte1EMIRRPQ4hBFGAxhMSOfAuheBpCDKIAiyEkduRbCF20KXUkJy18rlw2T42TbHgYk6yqAizF1ziQ2JFrQXTRptCRnFrnqY3zuB4xIHMUT5FrsRQM05hH0jFiOZcxo3iKQoilYJjGPJKOEcu5BIkdORdLwTCNeSQdI5ZzCRI7cm7lsnlqbDipYNiQv4JhGp2nSceIqiO54FLpPAUydXKZqIqyUSjdmml0niYdI7qO5AKjeIpco2CIIqF4ikKgYAiMRGJHrlEwBEZijb1KHTu641iD3LlF+sO68n4mZ5bKXZI5u8WfVC74rXz8f9V3bGhJsZpuy6R/ftBdHbv0yIuH1O+uBjPduGiW7r12flVjpCGEz2cIMRQdib0KoRTaEkthm9mgJOi2TOPP39WxS//2wt+OP+53P/64nsk9hM9nCDGApZiqrO/cd/wDO6i3r1/rO/dlFNEE/WHdUFIf1NdbPp4z6zv3qa9/+AUAff1e8XuS9M9L0iMvjrJl7xjHayWEz2cIMYDEXpVoCm2BbDObhhCKp/2nuLLsVMdrJYTPZwgxgMRelWgKbYFsM5uGEIqnDTb62s2pjtdKCJ/PEGIAib0q0dyXMpBtZtMQQrfljYtmVXW8VkLoHA0hBlA8rUo096UcLJAmuComiHkojG7L9nPO0r+/8DcdO+HYpIHj9RRC52gIMYDO07qLpdMxlnmkgXOBeqHzNFCxFJdimUcaOBcIDYm9zmIpLsUyjzRwLhAaEnudxVJcimUeUrnBaO7qZ9Sy6mnNXf2M7urYVdWfj+lcIA4UT+ssluJSLPNIo2s0lnOBeFA8RaHNXf3MqI1EDWY68L1PZBARcGoUT4EKhNI1CqSJxI5CC6VrFEgTiR2FllbXaAhduMAgEjsKrf2cs9Qwafi384ZJVlXX6GAXbvfhXrmGunBJ7sgKiR2Ftr5zn/qPDV9P7z9W3ba9bFWL0CRK7Gb2GTPbY2bHzGzcSi0QmjS6Ruk8RWiSfmPfLek6Sc+lEAtQd2l0jdJ5itAkSuzuvtfd+fsmJiyNomOSMdLoGl25bJ4aG4av0zc2VH/fVCAtdes8NbPlkpZL0uzZs+v1
sghYGlv/Jh0jta7Rky975zJ4ZGjczlMz+72k94/yqzXu/uuB5/ynpH9194raSek8hZTOdrchbJkbQgwohko7T8f9xu7uH0snJGC4WAqXIcQAnIjLHZGZWAqXIcQAnCjp5Y6fNLMuSYslPW1mnemEFa6si31pCWEeaRUus94yN4QYgBMlKp66+68k/SqlWIIXQrEvDaHMI43CZQhb5oYQA3Aitu2tAsW+dMcAUB227a0Bin3pjgGgNkjsVaDYl+4YAGqDxF4Fin3pjgGgNrjnaRUo9qU7BoDaoHgKADlB8RQACorEDgCRIbEDQGRI7AAQGRI7AESGxA4AkSGxA0BkaFDKqY4d3TQHARgViT2HQtj6F0C4WIrJofWd+44n9UG9ff1a37kvo4gAhITEnkNsmQtgLCT2HGLLXABjIbHnEFvmAhgLxdMcYstcAGMhsefUtRfNJJEDGBVLMQAQGRI7AEQmV0sxdFsCwPhyk9jptgSAyuRmKYZuSwCoTG4SO92WAFCZ3CR2ui0BoDK5Sex0WwJAZXJTPKXbEgAqkyixm9l6SVdJek/SAUlfcPfDaQQ2GrotAWB8SZdinpXU6u4LJP1F0urkIQEAkkiU2N39d+5+dODhC5JKyUMCACSR5hr7FyU9dqpfmtlyScslafbs2Sm+bDHRhQvgVMZN7Gb2e0nvH+VXa9z91wPPWSPpqKTNpxrH3TdK2ihJ7e3tPqFoIYkuXABjGzexu/vHxvq9mX1e0pWSlro7CbsOxurCJbEDSHpVzOWSviXpMnc/kk5IGA9duADGkvSqmAclTZH0rJm9ZGYbUogJ46ALF8BYkl4Vc567z3L3hQP//HNageHU6MIFMJbcdJ5iCF24AMZCYs8punABnEpuNgEDAFSGxA4AkSGxA0BkSOwAEBkSOwBEhsQOAJEhsQNAZLiOvaDY9heIF4m9gNj2F4gbSzEFNNa2vwDyj8ReQGz7C8SNxF5AbPsLxI3EXkBs+wvEjeJpAbHtLxA3EntBse0vEC+WYgAgMiR2AIgMiR0AIkNiB4DIkNgBIDIkdgCIjLl7/V/U7A1JB+v+whPTLOkfWQdRJ0WZa1HmKRVnrkWZ5znuPn28J2WS2PPEzLa5e3vWcdRDUeZalHlKxZlrUeZZKZZiACAyJHYAiAyJfXwbsw6gjooy16LMUyrOXIsyz4qwxg4AkeEbOwBEhsReATNbb2Yvm9lOM/uVmU3NOqZaMbPPmNkeMztmZtFdZWBml5vZPjN7xcxWZR1PrZjZz8zs72a2O+tYasnMZpnZH81s78Dn9o6sYwoBib0yz0pqdfcFkv4iaXXG8dTSbknXSXou60DSZmYNkn4s6QpJF0q60cwuzDaqmvmFpMuzDqIOjkr6hrtfIOkSSV+L+D2tGIm9Au7+O3c/OvDwBUmlLOOpJXff6+6x3tX6YkmvuPur7v6epEclXZNxTDXh7s9JejPrOGrN3V939z8P/Ps7kvZKKvyNBkjs1fuipP/IOghMyExJh0543CWSQDTMrEXSRZJezDaS7HEHpQFm9ntJ7x/lV2vc/dcDz1mj8l/9NtcztrRVMtdI2SjHuCwsAmZ2hqRfSvoXd38763iyRmIf4O4fG+v3ZvZ5SVdKWuo5v0Z0vLlGrEvSrBMelyS9llEsSImZNaqc1De7+xNZxxMClmIqYGaXS/qWpKvd/UjW8WDC/iTpfDObY2anS7pB0pMZx4QEzMwkbZK0190fyDqeUJDYK/OgpCmSnjWzl8xsQ9YB1YqZfdLMuiQtlvS0mXVmHVNaBgrgt0nqVLnItsXd92QbVW2Y2SOS/kfSPDPrMrMvZR1TjSyRdJOkjw78t/mSmX0i66CyRucpAESGb+wAEBkSOwBEhsQOAJEhsQNAZEjsABAZEjsARIbEDgCRIbEDQGT+HzaMz5IBVbXSAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.scatter(X_train[y_train == 0, 0], X_train[y_train == 0, 1], label='class 0')\n",
    "plt.scatter(X_train[y_train == 1, 0], X_train[y_train == 1, 1], label='class 1')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAD8CAYAAABjAo9vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAEqtJREFUeJzt3X+MXWWdx/HP13FMJ6FpA9PEtne6rdZUcH6UeKU0jZhYTcFYwKqNYGC7qzRLVos/ti6kBLoVFVMDCcHYlGAMsYuOEUcMrrNq3ZA1gk4p23a2dgsktXcgsQ4pkHQIw/DdP+70x7TTzsw5597znOe8XwkZ7unlOd9zQz89fZ7vea65uwAA8Xhb3gUAALJFsANAZAh2AIgMwQ4AkSHYASAyBDsARIZgB4DIEOwAEBmCHQAi8/Y8Ttre3u6LFy/O49QAUFh79uz5m7vPm+p9uQT74sWLNTAwkMepAaCwzOzIdN7HVAwARIZgB4DIEOwAEJlc5tgnMzo6qlqtptdffz3vUoIxa9YsVSoVtba25l0KgAIJJthrtZpmz56txYsXy8zyLid37q7h4WHVajUtWbIk73IAFEgwUzGvv/66LrnkEkJ9nJnpkksu4W8wAGYsmGCXRKifhc8DQBJBBTsARGdfr3R/p7R1bv3nvt6Gn5Jgn8LWrVv1ne98pyFj79mzR11dXVq6dKk2bdokvn8WiMy+XukXm6RXjkry+s9fbGp4uBPsObr11lu1c+dOHT58WIcPH9avfvWrvEsCkKXfbpNGRyYeGx2pH2+gwgZ7394hrbp3t5bc/oRW3btbfXuHUo/5yCOPqLu7Wz09PbrpppvO+fWHHnpIH/jAB9TT06NPfvKTOnHihCTpJz/5iTo7O9XT06OrrrpKkjQ4OKgrrrhCy5cvV3d3tw4fPjxhrJdeekmvvvqqVq5cKTPTzTffrL6+vtTXACAgr9RmdjwjwbQ7zkTf3iHd8dh+jYyOSZKGjo/ojsf2S5Kuv3xhojEHBwf1jW98Q7///e/V3t6ul19++Zz3rFu3Trfccosk6c4779TDDz+sL37xi9q2bZv6+/u1cOFCHT9+XJK0Y8cO3XbbbfrsZz+rN954Q2NjYxPGGhoaUqVSOfW6UqloaCj9H04AAjKnMj4NM8nxBirkHfv2/kOnQv2kkdExbe8/lHjM3bt361Of+pTa29slSRdffPE57zlw4IA++MEPqqurS7t27dLg4KAkadWqVdqwYYMeeuihUwG+cuVKffOb39S3v/1tHTlyRG1tbRPGmmw+nS6YBshh4Qo4ZfVdUuvE3/tqbasfb6BCBvuLx0dmdHw63H3KYN2wYYMefPBB7d+/X3ffffepHvMdO3bonnvu0dGjR7V8+XINDw/rxhtv1OOPP662tjatWbNGu3fvnjBWpVJRrXb6r2O1Wk0LFixIXD8mkdPCFXBK93pp7QPSnA5JVv+59oH68QYqZLAvmNs2o+PTsXr1avX29mp4eFiSJp2Kee211zR//nyNjo5q165dp44///zzWrFihbZt26b29nYdPXpUL7zwgt71rndp06ZNuvbaa7Vv374JY82fP1+zZ8/WU089JXfXI488ouuuuy5x/ZhETgtXwATd66UvH5C2Hq//bHCoSwUN9s1rlqmttWXCsbbWFm1esyzxmO973/u0ZcsWfehDH1JPT4++8pWvnPOer3/961qxYoU++tGP6r3vfe/pejZvVldXlzo7O3XVVVepp6dHP/7xj9XZ2anly5frz3/+s26++eZzxvve976nz3/+81q6dKne/e5365prrklcPyaR08IVkDfLo3e6Wq362V+0cfDgQV166aXTHqNv75C29x/Si8dHtGBumzavWZZ44TRkM/1ccIb7O8+zcNVRv3MCCsbM9rh7dar3FbIrRqp3v8QY5MjQ6rvqc+pnTsc0YeEKyFvqqRgz6zCz35nZQTMbNLPb
sigMSC2nhSsgb1ncsb8p6avu/oyZzZa0x8x+7e7/m8HYQDrd6wlylE7qO3Z3f8ndnxn/99ckHZTEHAkA5CTTrhgzWyzpcklPZzkuAGD6Mgt2M7tI0k8lfcndX53k1zea2YCZDRw7diyr0wIAzpJJsJtZq+qhvsvdH5vsPe6+092r7l6dN29eFqdtikZu27tlyxZ1dHTooosuasj4AMopi64Yk/SwpIPufl/6kspj7dq1+uMf/5h3GQAik8Ud+ypJN0n6sJk9O/7PxzIY98IasLlTM7ftlaQrr7xS8+fPT103AJwpdbuju/+3pOZuS3hyc6eTD56c3NxJStza1uxtewGgUQq5V0wjNndq9ra9ANAoxQz2Bmzu1OxtewGgUYoZ7Of79pEU30rS7G17AaBRihnsDfhWkjy27f3a176mSqWiEydOqFKpaOvWrYnrB4CTCrttr/b11ufUX6nV79RX3xXlniBs2wvgpOi37WVzJwCYXDGnYgAA5xVUsOcxLRQyPg8ASQQT7LNmzdLw8DBhNs7dNTw8rFmzZuVdCoCCCWaOvVKpqFariZ0fT5s1a5YqleQtnADKKZhgb21t1ZIlS/IuAwAKL5ipGABANgh2AIgMwQ4AkSHYkU4D9sUHkE4wi6cooAbsiw8gPe7YkVwD9sUHkB7BjuQasC8+gPQIdiTXgH3xAaRHsCO5BuyLDyA9gh3Jda+X1j4gzemQZPWfax9g4TRPdClBdMUgLfbFDwddShjHHTsQC7qUMI5gB2JBlxLGEexALOhSwjiCHYgFXUoYR7AjbqF3iWRZH11KGEdXDOIVepdII+qjSwnijh0xC71LJPT6UFgEO+IVepdI6PWhsAh2xCv0LpHQ60NhEeyI1+q7dO7/4m8Lp0uELhY0CMGOeP3lKUlvnXXwrfHjAaCLBQ1i7t70k1arVR8YGGj6eVEy/3ax5GPnHrcW6e6Xm18PkJKZ7XH36lTv444d8Zos1C90HIgEwY54WcvMjgORyCTYzez7ZvZXMzuQxXhAJt6/YWbHgUhkdcf+A0lXZzQWkI2P3ydVP3f6Dt1a6q8/fl++dQENlsmWAu7+pJktzmIsIFMfv48gR+kwxw4AkWlasJvZRjMbMLOBY8eONeu0AFA6TQt2d9/p7lV3r86bN69ZpwWA0mEqBgAik1W746OS/iBpmZnVzOxzWYwLAJi5rLpibshiHABAekzFAEBkCHYAiAzBDgCRIdgBIDIEOwBEhmAHgMgQ7AAQGYIdACJDsANAZAh2AIgMwQ4AkSHYASAyBDsARIZgB4DIEOwAEBmCHQAiQ7ADQGQIdgCIDMEOAJEh2AEgMgQ7AESGYEfc9vVK93dKW+fWf+7rzbsioOHenncBQMPs65V+sUkaHam/fuVo/bUkda/Pry6gwbhjR7x+u+10qJ80OlI/DkSMYEe8XqnN7DgQCYId8ZpTmdlxIBIEO+K1+i6ptW3isda2+nEgYgQ70gm566R7vdRzo2Qt9dfWUn+dZuE06+sN+fNDYdEVg+RC7zrZ1yv9z79LPlZ/7WP114uuTFZf1tcb+ueHwuKOHcmF3nWSdX2hjweMI9iRXOhdJ1nXF/p4wDiCHcmF3nWSdX2hjweMI9izVLaFsNC7TrKuL/TxgHEEe1ZOLoS9clSSn14Iizncu9dLax+Q5nRIsvrPtQ+Es/CXdX2hjweMM3dv+kmr1aoPDAw0/bwNdX/neKifZU6H9OUDza8HQHTMbI+7V6d6H3fsWWEhDEAgMgl2M7vazA6Z2XNmdnsWYxYOC2EAApE62M2sRdJ3JV0j6TJJN5jZZWnHLRwWwgAEIos79iskPefuL7j7G5J+JOm6DMYtliIshJWta0cq5zWj9LLYUmChpDNXDWuSVmQwbvF0rw8ryM9UxsfXy3jNgLK5Y7dJjp3TamNmG81swMwGjh07lsFpMSNlfHy9jNcMKJtgr0nqOON1RdKLZ7/J3Xe6e9Xdq/PmzcvgtJiRMnbtlPGaAWUT
7H+S9B4zW2Jm75D0GUmPZzAuslTGrp0yXjOgDILd3d+U9AVJ/ZIOSup198G04yJjZezaKeM1A8poP3Z3/6WkX2YxFhrk5GLhb7fVpyLmVOoBl3IRsW/vkLb3H9KLx0e0YG6bNq9ZpusvX5hBwRlo0DUDoWNLASTWt3dIdzy2XyOjY6eOtbW26FvrusIJdyAibCmAhtvef2hCqEvSyOiYtvcfyqkiABLBjhRePD4yo+MAmoNgR2IL5rbN6DiA5iDYkdjmNcvU1toy4Vhba4s2r1mWU0UApIy6YlBOJxdIs+yKCbrLBigIgh2pXH/5wsyC9+wum6HjI7rjsf2nzgNgepiKQTDosgGyQbAjGHTZANkg2BEMumyAbBDsCEYjumz69g5p1b27teT2J7Tq3t3q2zuUtkwgeCyeIhhZd9mwGIuyItgRlCy7bC60GEuwI2ZMxSBaLMairAh2RIvFWJQVwY5oseUByoo5dkSrEVseAEVAsCNqWS7GAkXBVAwARIZgB4DIEOwAEBmCHQAiw+IpMAN8EQiKgGAHpom9Z1AUTMUA08QXgaAoCHZgmth7BkVBsAPTxN4zKAqCHZgm9p5BUbB4CkwTe8+gKAh2YAbYewZFwFQMAESGYAeAyBDsABAZ5thLpIyPw5fxmgGCvSTK+Dh8Ga8ZkJiKKY0yPg5fxmsGpJTBbmafNrNBM3vLzKpZFYXslfFx+DJeMyClv2M/IGmdpCczqAUNVMbH4ct4zYCUMtjd/aC78/faAijj4/BlvGZAYvG0NMr4OHwZrxmQJHP3C7/B7DeS3jnJL21x95+Pv+e/JP2Luw9cYJyNkjZK0qJFi95/5MiRpDUDQCmZ2R53n3I9c8o7dnf/SBYFuftOSTslqVqtXvhPEwBAYrQ7AkBk0rY7fsLMapJWSnrCzPqzKQsAkFSqxVN3/5mkn2VUCwAgA3TFIBX2YgHCQ7AjMfZiAcLE4ikSYy8WIEwEOxJjLxYgTAQ7EmMvFiBMBDsSYy8WIEwsnmYo9A6RO/v269Gnj2rMXS1mumFFh+65vivxeOzFAoSJYM9I6B0id/bt1w+f+sup12Pup16nDfcQrg/AaUzFZCT0DpFHnz46o+MAiotgz0joHSJj59nF83zHARQXwZ6R0DtEWsxmdBxAcRHsGQm9Q+SGFR0zOg6guEq9eJplF0voHSInF0iz7IoBEKYpv0GpEarVqg8MnPfLlpri7C4WqX6H/a11XcGEMQCcabrfoFTaqZjQu1gAIKnSBnvoXSwAkFRpgz30LhYASKq0wR56FwsAJFXarpjQu1gAIKnSBrvEPicA4lTaqRgAiBXBDgCRIdgBIDKlnmNHeEL/shKgCAh2BCP0LysBioKpGASDbR6AbBDsCAbbPADZINgRDLZ5ALJBsCMYbPMAZIPF04CVrUOEbR6AbBDsgSprhwjbPADpMRUTKDpEACRFsAeKDhEASRHsgaJDBEBSBHug6BABkBSLp4GiQwRAUqmC3cy2S1or6Q1Jz0v6B3c/nkVhoEMEQDJpp2J+LanT3bsl/Z+kO9KXBABII1Wwu/t/uvub4y+fklRJXxIAII0sF0//UdJ/ZDgeACCBKefYzew3kt45yS9tcfefj79ni6Q3Je26wDgbJW2UpEWLFiUqtmyP2ANAEubu6QYw+3tJ/yRptbufmM5/U61WfWBgYEbnOfsRe6ne/vetdV2EO4BSMLM97l6d6n2ppmLM7GpJ/yrp2umGelI8Yg8A05N2jv1BSbMl/drMnjWzHRnUNCkesQeA6UnVx+7uS7MqZCoL5rZpaJIQ5xF7AJioMFsK8Ig9AExPYbYU4BF7AJiewgS7xCP2ADAdhZmKAQBMD8EOAJEh2AEgMgQ7AESGYAeAyBDsABCZ1JuAJTqp2TFJR5p+4ny0S/pb3kUEiM9lcnwu58dnI/2du8+b6k25BHuZmNnAdHZjKxs+l8nxuZwfn830MRUDAJEh2AEgMgR74+3Mu4BA8blMjs/l/Phs
pok5dgCIDHfsABAZgr0JzOzTZjZoZm+ZWelX9c3sajM7ZGbPmdntedcTAjP7vpn91cwO5F1LSMysw8x+Z2YHx38P3ZZ3TUVAsDfHAUnrJD2ZdyF5M7MWSd+VdI2kyyTdYGaX5VtVEH4g6eq8iwjQm5K+6u6XSrpS0j/z/8vUCPYmcPeD7s63btddIek5d3/B3d+Q9CNJ1+VcU+7c/UlJL+ddR2jc/SV3f2b831+TdFASX8owBYIdzbZQ0tEzXtfEb1RMg5ktlnS5pKfzrSR8hfoGpZCZ2W8kvXOSX9ri7j9vdj0Bs0mO0ZqFCzKziyT9VNKX3P3VvOsJHcGeEXf/SN41FERNUscZryuSXsypFhSAmbWqHuq73P2xvOspAqZi0Gx/kvQeM1tiZu+Q9BlJj+dcEwJlZibpYUkH3f2+vOspCoK9CczsE2ZWk7RS0hNm1p93TXlx9zclfUFSv+oLYb3uPphvVfkzs0cl/UHSMjOrmdnn8q4pEKsk3STpw2b27Pg/H8u7qNDx5CkARIY7dgCIDMEOAJEh2AEgMgQ7AESGYAeAyBDsABAZgh0AIkOwA0Bk/h9KNAHoCqXDewAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.scatter(X_test[y_test == 0, 0], X_test[y_test == 0, 1], label='class 0')\n",
    "plt.scatter(X_test[y_test == 1, 0], X_test[y_test == 1, 1], label='class 1')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<br>\n",
    "<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Implement ADALINE Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Adaline1():\n",
    "    def __init__(self, num_features):\n",
    "        self.num_features = num_features\n",
    "        self.weights = torch.zeros(num_features, 1, \n",
    "                                   dtype=torch.float)\n",
    "        self.bias = torch.zeros(1, dtype=torch.float)\n",
    "\n",
    "    def forward(self, x):\n",
    "        netinputs = torch.add(torch.mm(x, self.weights), self.bias)\n",
    "        activations = netinputs\n",
    "        return activations.view(-1)\n",
    "        \n",
    "    def backward(self, x, yhat, y):  \n",
    "        \n",
    "        grad_loss_yhat = y - yhat\n",
    "        \n",
    "        grad_yhat_weights = x\n",
    "        grad_yhat_bias = 1.\n",
    "        \n",
    "        # Chain rule: inner times outer\n",
    "        grad_loss_weights = 2* -torch.mm(grad_yhat_weights.t(),\n",
    "                                         grad_loss_yhat.view(-1, 1)) / y.size(0)\n",
    "\n",
    "        grad_loss_bias = 2* -torch.sum(grad_yhat_bias*grad_loss_yhat) / y.size(0)\n",
    "        \n",
    "        return (-1)*grad_loss_weights, (-1)*grad_loss_bias"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<br>\n",
    "<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Training and Evaluation Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "####################################################\n",
    "##### Training and evaluation wrappers\n",
    "###################################################\n",
    "\n",
    "def loss(yhat, y):\n",
    "    return torch.mean((yhat - y)**2)\n",
    "\n",
    "\n",
    "def train(model, x, y, num_epochs,\n",
    "          learning_rate=0.01, seed=123, minibatch_size=10):\n",
    "    cost = []\n",
    "    \n",
    "    torch.manual_seed(seed)\n",
    "    for e in range(num_epochs):\n",
    "        \n",
    "        #### Shuffle epoch\n",
    "        shuffle_idx = torch.randperm(y.size(0), dtype=torch.long)\n",
    "        minibatches = torch.split(shuffle_idx, minibatch_size)\n",
    "        \n",
    "        for minibatch_idx in minibatches:\n",
    "\n",
    "            #### Compute outputs ####\n",
    "            yhat = model.forward(x[minibatch_idx])\n",
    "\n",
    "            #### Compute gradients ####\n",
    "            negative_grad_w, negative_grad_b = \\\n",
    "                model.backward(x[minibatch_idx], yhat, y[minibatch_idx])\n",
    "\n",
    "            #### Update weights ####\n",
    "            model.weights += learning_rate * negative_grad_w\n",
    "            model.bias += learning_rate * negative_grad_b\n",
    "            \n",
    "            #### Logging ####\n",
    "            minibatch_loss = loss(yhat, y[minibatch_idx])\n",
    "            print('    Minibatch MSE: %.3f' % minibatch_loss)\n",
    "\n",
    "        #### Logging ####\n",
    "        yhat = model.forward(x)\n",
    "        curr_loss = loss(yhat, y)\n",
    "        print('Epoch: %03d' % (e+1), end=\"\")\n",
    "        print(' | MSE: %.5f' % curr_loss)\n",
    "        cost.append(curr_loss)\n",
    "\n",
    "    return cost"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<br>\n",
    "<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train ADALINE Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    Minibatch MSE: 0.500\n",
      "    Minibatch MSE: 0.341\n",
      "    Minibatch MSE: 0.220\n",
      "    Minibatch MSE: 0.245\n",
      "    Minibatch MSE: 0.157\n",
      "    Minibatch MSE: 0.133\n",
      "    Minibatch MSE: 0.144\n",
      "Epoch: 001 | MSE: 0.12142\n",
      "    Minibatch MSE: 0.107\n",
      "    Minibatch MSE: 0.147\n",
      "    Minibatch MSE: 0.064\n",
      "    Minibatch MSE: 0.079\n",
      "    Minibatch MSE: 0.185\n",
      "    Minibatch MSE: 0.063\n",
      "    Minibatch MSE: 0.135\n",
      "Epoch: 002 | MSE: 0.09932\n",
      "    Minibatch MSE: 0.093\n",
      "    Minibatch MSE: 0.064\n",
      "    Minibatch MSE: 0.128\n",
      "    Minibatch MSE: 0.099\n",
      "    Minibatch MSE: 0.079\n",
      "    Minibatch MSE: 0.157\n",
      "    Minibatch MSE: 0.080\n",
      "Epoch: 003 | MSE: 0.09693\n",
      "    Minibatch MSE: 0.131\n",
      "    Minibatch MSE: 0.146\n",
      "    Minibatch MSE: 0.050\n",
      "    Minibatch MSE: 0.095\n",
      "    Minibatch MSE: 0.106\n",
      "    Minibatch MSE: 0.072\n",
      "    Minibatch MSE: 0.102\n",
      "Epoch: 004 | MSE: 0.09658\n",
      "    Minibatch MSE: 0.107\n",
      "    Minibatch MSE: 0.204\n",
      "    Minibatch MSE: 0.149\n",
      "    Minibatch MSE: 0.054\n",
      "    Minibatch MSE: 0.060\n",
      "    Minibatch MSE: 0.056\n",
      "    Minibatch MSE: 0.069\n",
      "Epoch: 005 | MSE: 0.09657\n",
      "    Minibatch MSE: 0.068\n",
      "    Minibatch MSE: 0.111\n",
      "    Minibatch MSE: 0.092\n",
      "    Minibatch MSE: 0.115\n",
      "    Minibatch MSE: 0.157\n",
      "    Minibatch MSE: 0.074\n",
      "    Minibatch MSE: 0.087\n",
      "Epoch: 006 | MSE: 0.09650\n",
      "    Minibatch MSE: 0.057\n",
      "    Minibatch MSE: 0.070\n",
      "    Minibatch MSE: 0.133\n",
      "    Minibatch MSE: 0.127\n",
      "    Minibatch MSE: 0.062\n",
      "    Minibatch MSE: 0.153\n",
      "    Minibatch MSE: 0.103\n",
      "Epoch: 007 | MSE: 0.09683\n",
      "    Minibatch MSE: 0.102\n",
      "    Minibatch MSE: 0.110\n",
      "    Minibatch MSE: 0.101\n",
      "    Minibatch MSE: 0.065\n",
      "    Minibatch MSE: 0.126\n",
      "    Minibatch MSE: 0.124\n",
      "    Minibatch MSE: 0.076\n",
      "Epoch: 008 | MSE: 0.09648\n",
      "    Minibatch MSE: 0.120\n",
      "    Minibatch MSE: 0.056\n",
      "    Minibatch MSE: 0.100\n",
      "    Minibatch MSE: 0.102\n",
      "    Minibatch MSE: 0.106\n",
      "    Minibatch MSE: 0.075\n",
      "    Minibatch MSE: 0.144\n",
      "Epoch: 009 | MSE: 0.09740\n",
      "    Minibatch MSE: 0.073\n",
      "    Minibatch MSE: 0.071\n",
      "    Minibatch MSE: 0.084\n",
      "    Minibatch MSE: 0.152\n",
      "    Minibatch MSE: 0.099\n",
      "    Minibatch MSE: 0.108\n",
      "    Minibatch MSE: 0.118\n",
      "Epoch: 010 | MSE: 0.09636\n",
      "    Minibatch MSE: 0.058\n",
      "    Minibatch MSE: 0.070\n",
      "    Minibatch MSE: 0.145\n",
      "    Minibatch MSE: 0.081\n",
      "    Minibatch MSE: 0.093\n",
      "    Minibatch MSE: 0.127\n",
      "    Minibatch MSE: 0.115\n",
      "Epoch: 011 | MSE: 0.09638\n",
      "    Minibatch MSE: 0.123\n",
      "    Minibatch MSE: 0.091\n",
      "    Minibatch MSE: 0.085\n",
      "    Minibatch MSE: 0.093\n",
      "    Minibatch MSE: 0.091\n",
      "    Minibatch MSE: 0.143\n",
      "    Minibatch MSE: 0.081\n",
      "Epoch: 012 | MSE: 0.09718\n",
      "    Minibatch MSE: 0.096\n",
      "    Minibatch MSE: 0.076\n",
      "    Minibatch MSE: 0.149\n",
      "    Minibatch MSE: 0.092\n",
      "    Minibatch MSE: 0.116\n",
      "    Minibatch MSE: 0.093\n",
      "    Minibatch MSE: 0.091\n",
      "Epoch: 013 | MSE: 0.09638\n",
      "    Minibatch MSE: 0.095\n",
      "    Minibatch MSE: 0.104\n",
      "    Minibatch MSE: 0.107\n",
      "    Minibatch MSE: 0.120\n",
      "    Minibatch MSE: 0.102\n",
      "    Minibatch MSE: 0.045\n",
      "    Minibatch MSE: 0.124\n",
      "Epoch: 014 | MSE: 0.09685\n",
      "    Minibatch MSE: 0.121\n",
      "    Minibatch MSE: 0.051\n",
      "    Minibatch MSE: 0.095\n",
      "    Minibatch MSE: 0.122\n",
      "    Minibatch MSE: 0.030\n",
      "    Minibatch MSE: 0.158\n",
      "    Minibatch MSE: 0.121\n",
      "Epoch: 015 | MSE: 0.09745\n",
      "    Minibatch MSE: 0.080\n",
      "    Minibatch MSE: 0.119\n",
      "    Minibatch MSE: 0.091\n",
      "    Minibatch MSE: 0.095\n",
      "    Minibatch MSE: 0.044\n",
      "    Minibatch MSE: 0.092\n",
      "    Minibatch MSE: 0.180\n",
      "Epoch: 016 | MSE: 0.09693\n",
      "    Minibatch MSE: 0.054\n",
      "    Minibatch MSE: 0.075\n",
      "    Minibatch MSE: 0.184\n",
      "    Minibatch MSE: 0.105\n",
      "    Minibatch MSE: 0.121\n",
      "    Minibatch MSE: 0.066\n",
      "    Minibatch MSE: 0.096\n",
      "Epoch: 017 | MSE: 0.09699\n",
      "    Minibatch MSE: 0.076\n",
      "    Minibatch MSE: 0.050\n",
      "    Minibatch MSE: 0.198\n",
      "    Minibatch MSE: 0.105\n",
      "    Minibatch MSE: 0.054\n",
      "    Minibatch MSE: 0.136\n",
      "    Minibatch MSE: 0.099\n",
      "Epoch: 018 | MSE: 0.09672\n",
      "    Minibatch MSE: 0.158\n",
      "    Minibatch MSE: 0.099\n",
      "    Minibatch MSE: 0.087\n",
      "    Minibatch MSE: 0.070\n",
      "    Minibatch MSE: 0.103\n",
      "    Minibatch MSE: 0.112\n",
      "    Minibatch MSE: 0.081\n",
      "Epoch: 019 | MSE: 0.09638\n",
      "    Minibatch MSE: 0.095\n",
      "    Minibatch MSE: 0.093\n",
      "    Minibatch MSE: 0.111\n",
      "    Minibatch MSE: 0.147\n",
      "    Minibatch MSE: 0.083\n",
      "    Minibatch MSE: 0.102\n",
      "    Minibatch MSE: 0.065\n",
      "Epoch: 020 | MSE: 0.09635\n"
     ]
    }
   ],
   "source": [
    "model = Adaline1(num_features=X_train.size(1))\n",
    "cost = train(model, \n",
    "             X_train, y_train.float(),\n",
    "             num_epochs=20,\n",
    "             learning_rate=0.1,\n",
    "             seed=123,\n",
    "             minibatch_size=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<br>\n",
    "<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate ADALINE Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot Loss (MSE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(range(len(cost)), cost)\n",
    "plt.ylabel('Mean Squared Error')\n",
    "plt.xlabel('Epoch')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<br>\n",
    "<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compare with the Analytical Solution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weights tensor([[-0.0763],\n",
      "        [ 0.4181]])\n",
      "Bias tensor([0.4888])\n"
     ]
    }
   ],
   "source": [
    "print('Weights', model.weights)\n",
    "print('Bias', model.bias)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Analytical weights tensor([[-0.0703],\n",
      "        [ 0.4219]])\n",
      "Analytical bias tensor([0.4857])\n"
     ]
    }
   ],
   "source": [
    "def analytical_solution(x, y):\n",
    "    Xb = torch.cat( (torch.ones((x.size(0), 1)), x), dim=1)\n",
    "    w = torch.zeros(x.size(1))\n",
    "    z = torch.inverse(torch.matmul(Xb.t(), Xb))\n",
    "    params = torch.matmul(z, torch.matmul(Xb.t(), y))\n",
    "    b, w = torch.tensor([params[0]]), params[1:].view(x.size(1), 1)\n",
    "    return w, b\n",
    "\n",
    "w, b = analytical_solution(X_train, y_train.float())\n",
    "print('Analytical weights', w)\n",
    "print('Analytical bias', b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<br>\n",
    "<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate on Evaluation Metric (Prediction Accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def custom_where(cond, x_1, x_2):\n",
    "    return (cond * x_1) + ((1-cond) * x_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training Accuracy: 90.00\n",
      "Test Accuracy: 96.67\n"
     ]
    }
   ],
   "source": [
    "train_pred = model.forward(X_train)\n",
    "train_acc = torch.mean(\n",
    "    (custom_where(train_pred > 0.5, 1, 0).int() == y_train).float())\n",
    "\n",
    "test_pred = model.forward(X_test)\n",
    "test_acc = torch.mean(\n",
    "    (custom_where(test_pred > 0.5, 1, 0).int() == y_test).float())\n",
    "\n",
    "print('Training Accuracy: %.2f' % (train_acc*100))\n",
    "print('Test Accuracy: %.2f' % (test_acc*100))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<br>\n",
    "<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Decision Boundary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 504x216 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "##########################\n",
    "### 2D Decision Boundary\n",
    "##########################\n",
    "\n",
    "w, b = model.weights, model.bias - 0.5\n",
    "\n",
    "x_min = -3\n",
    "y_min = ( (-(w[0] * x_min) - b[0]) \n",
    "          / w[1] )\n",
    "\n",
    "x_max = 3\n",
    "y_max = ( (-(w[0] * x_max) - b[0]) \n",
    "          / w[1] )\n",
    "\n",
    "\n",
    "fig, ax = plt.subplots(1, 2, sharex=True, figsize=(7, 3))\n",
    "\n",
    "ax[0].plot([x_min, x_max], [y_min, y_max])\n",
    "ax[1].plot([x_min, x_max], [y_min, y_max])\n",
    "\n",
    "ax[0].scatter(X_train[y_train==0, 0], X_train[y_train==0, 1], label='class 0', marker='o')\n",
    "ax[0].scatter(X_train[y_train==1, 0], X_train[y_train==1, 1], label='class 1', marker='s')\n",
    "\n",
    "ax[1].scatter(X_test[y_test==0, 0], X_test[y_test==0, 1], label='class 0', marker='o')\n",
    "ax[1].scatter(X_test[y_test==1, 0], X_test[y_test==1, 1], label='class 1', marker='s')\n",
    "\n",
    "ax[1].legend(loc='upper left')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}