{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "STAT 479: Deep Learning (Spring 2019)  \n",
    "Instructor: Sebastian Raschka (sraschka@wisc.edu)  \n",
    "Course website: http://pages.stat.wisc.edu/~sraschka/teaching/stat479-ss2019/  \n",
    "GitHub repository: https://github.com/rasbt/stat479-deep-learning-ss19"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.7.1\n",
      "IPython 7.2.0\n",
      "\n",
      "torch 1.0.1\n",
      "pandas 0.24.0\n",
      "matplotlib 3.0.2\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch,pandas,matplotlib"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ADALINE with Stochastic Gradient Descent (Minibatch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<br>\n",
    "<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load & Prepare a Toy Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x1</th>\n",
       "      <th>x2</th>\n",
       "      <th>x3</th>\n",
       "      <th>x4</th>\n",
       "      <th>y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>145</th>\n",
       "      <td>6.7</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.2</td>\n",
       "      <td>2.3</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>146</th>\n",
       "      <td>6.3</td>\n",
       "      <td>2.5</td>\n",
       "      <td>5.0</td>\n",
       "      <td>1.9</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>147</th>\n",
       "      <td>6.5</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.2</td>\n",
       "      <td>2.0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>148</th>\n",
       "      <td>6.2</td>\n",
       "      <td>3.4</td>\n",
       "      <td>5.4</td>\n",
       "      <td>2.3</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>149</th>\n",
       "      <td>5.9</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.1</td>\n",
       "      <td>1.8</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      x1   x2   x3   x4  y\n",
       "145  6.7  3.0  5.2  2.3  1\n",
       "146  6.3  2.5  5.0  1.9  1\n",
       "147  6.5  3.0  5.2  2.0  1\n",
       "148  6.2  3.4  5.4  2.3  1\n",
       "149  5.9  3.0  5.1  1.8  1"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv('./datasets/iris.data', index_col=None, header=None)\n",
    "df.columns = ['x1', 'x2', 'x3', 'x4', 'y']\n",
    "\n",
    "# Keep rows 50-149 (assumed to be Iris-versicolor and Iris-virginica, per the\n",
    "# standard UCI row order). `.copy()` makes this an independent frame, so the\n",
    "# column assignment below cannot trigger pandas' SettingWithCopyWarning.\n",
    "df = df.iloc[50:150].copy()\n",
    "\n",
    "# Binary target: 0 for Iris-versicolor, 1 for any other label\n",
    "df['y'] = (df['y'] != 'Iris-versicolor').astype('int64')\n",
    "df.tail()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assign features (columns x2, x3) and the binary target\n",
    "\n",
    "X = torch.tensor(df[['x2', 'x3']].values, dtype=torch.float)\n",
    "y = torch.tensor(df['y'].values, dtype=torch.int)\n",
    "\n",
    "# Shuffle, then split 70% train / 30% test\n",
    "\n",
    "torch.manual_seed(123)\n",
    "shuffle_idx = torch.randperm(y.size(0), dtype=torch.long)\n",
    "X, y = X[shuffle_idx], y[shuffle_idx]\n",
    "\n",
    "# NOTE(review): X and y are already permuted above, so indexing them with\n",
    "# shuffle_idx again applies the permutation a second time. The split is\n",
    "# still random and the two parts are disjoint; it is left as-is so the\n",
    "# saved outputs of this notebook stay reproducible.\n",
    "percent70 = int(shuffle_idx.size(0)*0.7)\n",
    "train_idx, test_idx = shuffle_idx[:percent70], shuffle_idx[percent70:]\n",
    "X_train, y_train = X[train_idx], y[train_idx]\n",
    "X_test, y_test = X[test_idx], y[test_idx]\n",
    "\n",
    "# Standardize (zero mean, unit variance) with training statistics only,\n",
    "# so nothing from the test split leaks into preprocessing\n",
    "mu = X_train.mean(dim=0)\n",
    "sigma = X_train.std(dim=0)\n",
    "X_train, X_test = (X_train - mu) / sigma, (X_test - mu) / sigma"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAD8CAYAAABjAo9vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAFpFJREFUeJzt3X+MXXWZx/HP02FIJ6Gh0mli6W2ZUkgFp50SR0rTCBurKRh+iUpAA6JiXSOUFbfapgRqBcXUYGIw1pKqIdsFSsQRgXVE6y5ZshBby/ZHSi0l1s5AIg5bIOkQptNn/7gznU5nOnPvnHPv+Z7veb8SMtwzt9/7fM+9PNx+n/N8j7m7AADxmJR1AACAdJHYASAyJHYAiAyJHQAiQ2IHgMiQ2AEgMiR2AIgMiR0AIkNiB4DInJbFizY3N3tLS0sWLw0AubV9+/Z/uPv08Z6XSWJvaWnRtm3bsnhpAMgtMztYyfNYigGAyJDYASAyJHYAiEwma+yj6evrU1dXl959992sQwnG5MmTVSqV1NjYmHUoAHIkmMTe1dWlKVOmqKWlRWaWdTiZc3f19PSoq6tLc+bMyTocADkSzFLMu+++q2nTppHUB5iZpk2bxt9gAFQtmMQuiaR+Es4HgIkIKrEjZ3ZukX7YKq2dWv65c0vWEQEQiX1ca9eu1Q9+8IOajL19+3bNnz9f5513nlasWKFc3X925xbpNyuktw5J8vLP36wguQMBILFn6Ktf/ao2btyo/fv3a//+/frtb3+bdUiV+8M6qa93+LG+3vJxAJnKbWLv2NGtJfdv1ZxVT2vJ/VvVsaM78ZgPP/ywFixYoLa2Nt10000jfv/QQw/pwx/+sNra2vSpT31KR44ckSQ9/vjjam1tVVtbmy699FJJ0p49e3TxxRdr4cKFWrBggfbv3z9srNdff11vv/22Fi9eLDPTzTffrI6OjsRzqJu3uqo7DqBugrncsRodO7q1+old6u3rlyR1H+7V6id2SZKuvWjmhMbcs2eP7rvvPj3//PNqbm7Wm2++OeI51113nb785S9Lku666y5t2rRJt99+u9atW6fOzk7NnDlThw8fliRt2LBBd9xxhz73uc/pvffeU39//7Cxuru7VSqVjj8ulUrq7k7+P6e6ObM0sAwzynEAmcrlN/b1nfuOJ/VBvX39Wt+5b8Jjbt26VZ/+9KfV3NwsSTrrrLNGPGf37t36yEc+ovnz52vz5s3as2ePJGnJkiW65ZZb9NBDDx1P4IsXL9Z3v/tdff/739fBgwfV1NQ0bKzR1tNzdRXM0rulxuFzUmNT+TiATOUysb92uLeq45Vw93ET6y233KIHH3xQu3bt0j333HP8GvMNGzbo3nvv1aFDh7Rw4UL19PTos5/9rJ588kk1NTVp2bJl2rp167CxSqWSurqGli26urp09tlnTzj+ultwvXTVj6QzZ0my8s+rflQ+DiBTuUzsZ09tqup4JZYuXaotW7aop6dHkkZdinnnnXc0Y8YM9fX1afPmzcePHzhwQIsWLdK6devU3NysQ4cO6dVXX9W5556rFStW6Oqrr9bOnTuHjTVjxgxNmTJFL7zwgtxdDz/8sK655poJx5+JBddLX98trT1c/klSB4KQy8S+ctk8NTU2DDvW1NiglcvmTXjMD37wg1qzZo0uu+wytbW16c477xzxnO985ztatGiRPv7xj+sDH/jAUDwrV2r+/PlqbW3VpZdeqra2Nj322GNqbW3VwoUL9fLLL+vmm28eMd5PfvIT3XrrrTrvvPM0d+5cXXHFFROOHwAGWRbXTre3t/vJN9rYu3evLrjggorH6NjRrfWd+/Ta4V6dPbVJK5fNm3DhNGTVnhcA8TKz7e7ePt7zcnlVjFS++iXGRI4J2LmlfP38W13lq3KW3s2yEAott4kdkDTUATvYLDXYASuR3FFYuVxjB46jAxYYgcSOfKMDFhiB
xI58O1WnKx2wKDASO/KNDlhgBBL7OGq5be+aNWs0a9YsnXHGGTUZvxDogAVG4KqYDF111VW67bbbdP7552cdSr4tuJ5EDpwgv9/Ya3D3nnpu2ytJl1xyiWbMmJE4bgA4UT6/sdfg2uV6b9sLALWS+Bu7mc0ysz+a2V4z22Nmd6QR2JhqcO1yvbftBYBaSWMp5qikb7j7BZIukfQ1M7swhXFPrQbXLtd7214AqJXEid3dX3f3Pw/8+zuS9kqq7SYuNbh2ud7b9gJAraRaPDWzFkkXSXoxzXFHqMG1y1ls2/vNb35TpVJJR44cUalU0tq1ayccPwAMSm3bXjM7Q9J/SbrP3Z8Y5ffLJS2XpNmzZ3/o4MGDw35f9fa0BdnRj217AQyq67a9ZtYo6ZeSNo+W1CXJ3TdK2iiV92NP/KJcuwwgdBl9AU2c2K1ccdwkaa+7P5A8JACIQIZbSqexxr5E0k2SPmpmLw3884mJDJTF3ZxCxvkAcizDLaUTf2N39/+WNPZ1ghWYPHmyenp6NG3atHEvOywCd1dPT48mT56cdSgAJiLDLaWD6TwtlUrq6urSG2+8kXUowZg8ebJKJbafBXLpzFJ5+WW04zUWTGJvbGzUnDlzsg4DRVSQK6xQZ0vvHr7GLtVtS+lgEjuQCe6ZiloZ/Pzk8aoYINfGKnCR2JFURpdl53fbXiAN3DMVESKxo9i4ZyoiRGJHsXHPVESIxI5i456piBDFU4B9hxAZvrEDQGRI7AAQGRI7kIadW6Qftkprp5Z/7tySdUQo8HvCGjuQFN2r4Sn4e8I3diCpDLdnxSkU/D0hsQNJ0b0anoK/JyR2ICm6V8NT8PeExA4kFVL3aoELhsOE9J5kgOIpkFSG27MOU/CC4TChvCcZsSzuq9ne3u7btm2r++sCUfth6ynu2DNL+vru+seD1JnZdndvH+95LMUAsSh4wRBDSOxALApeMMQQEjsQi4IXDDGExA7Egi2IMYCrYoCYsAUxxDd2AIgOiR0AIkNiLyo6FIFoscZeRHQoAlHjG3sRFXxLUyB2JPYiokMRiFoqid3MfmZmfzczNqTIAzoUgail9Y39F5IuT2ks1BodivFKWhR/6k7p22dJa88s/3zqztrEiZpKpXjq7s+ZWUsaY6EOCr6labSSFsWfulPatmnosfcPPb7ygXRjRU2ltm3vQGJ/yt1bx3su2/YCNZB0295vn1VO5iezBumeN5PHh8SC27bXzJab2TYz2/bGG2/U62WB4khaFB8tqY91HMGqW2J3943u3u7u7dOnT6/XywLFkbQobg3VHUewuNwRExdK92rSgl8o80gaR9Ki+Iduqe74qYRyPgssleKpmT0i6Z8kNZtZl6R73H3T2H8KuRZK92rSgl8o80gjjqRF8cHztf0X5fNoDeWkXk3hNJTzWXDc8xQTE8r9NZMW/EKZRyhxJBXLPAIVXPEUkQmlezVpwS+UeYQSR1KxzCPnSOyYmFC6V5MW/EKZRyhxJBXLPHKOxI6JCaV7NWnBL615JC3gLr1bajh9+LGG0/PXDRzK56LgSOyYmFDur3nlA1L7l4a+oVtD+XGlBb805jFYwB1c/hks4Fab3E+ud2VQ/0oslM9FwVE8BZJKo2OToiMqQPEUqJc0OjYpOiJFJHYgqTQ6Nik6IkUkdiCpNDo2l94tTWocfmxSYz6LjnSeZo57ngJJpdGxKUlmYz/OAzpPg0DxFAhBLMXTWOYRKIqnQJ7EUjyNZR45R2IHQhBL8TSWeeQciR3ZotBWllbHZtbnk87TIFA8RXYotA1J4z60IZxP7qcbBIqnyA6FtnRxPqNH8RTho9CWLs4nBpDYkR0KbenifGIAiR3ZianQlnXRUorrfCIRiqfITiyFthCKlie+Vt7PJxKjeAokRdESdULxFKgXipYIDIkdSIqiJQJDYkemDvz8Kzq69n3ye87U0bXv04Gff6X6QbIuXFK0RGBI7MjMgZ9/Ref+9VGdpmMyk07TMZ37
10erS+6Dhcu3DknyocJlPZM79/lEYCieIjNH175Pp+nYyOOapNPW/l9lg1C4RIFQPEXwGnxkUh/r+KgoXAIjkNiRmX4b/eN3quOjonAJjEBir1bWhbq0BDCPg+dcr5NXAt3LxysWSOGyY0e3lty/VXNWPa0l929Vx47uur4+cCISezVCKNSlIZB5zP3CT/Vqyw06qklyL6+tv9pyg+Z+4aeVDxJA4bJjR7dWP7FL3Yd75ZK6D/dq9RO7SO7IDMXTasRSqItlHoFYcv9WdR/uHXF85tQmPb/qoxlEhFjVtXhqZpeb2T4ze8XMVqUxZpBiKdTFMo9AvDZKUh/rOFBriRO7mTVI+rGkKyRdKOlGM7sw6bhBiqVQF8s8AnH21KaqjgO1lsY39oslveLur7r7e5IelXRNCuOGJ5BCXWJL79bRhsnDDh1tmJy/eQzKuBC8ctk8NTU2DDvW1NiglcvmVTUOBVikJY3EPlPSiQu2XQPH4hNAoS4NHf1LtKrvVnUda9YxN3Uda9aqvlvV0b8k69CqF0Ah+NqLZup7183XzKlNMpXX1r933Xxde1Hl/xlQgEWaEhdPzewzkpa5+60Dj2+SdLG7337S85ZLWi5Js2fP/tDBgwcTvS4mLqpiXySF4KjeE9RMPYunXZJmnfC4JOm1k5/k7hvdvd3d26dPn57Cy2Kioir2RVIIjuo9QebSSOx/knS+mc0xs9Ml3SDpyRTGRY1EVeyLpBAc1XuCzCVO7O5+VNJtkjol7ZW0xd33JB0XtZNWsS8IS++WGk4ffqzh9PoXghMWcCnAIk2p3PPU3Z+R9EwaY6H2Bot66zv36bXDvTp7apNWLptXVbEvKKPtS1BPKdzzNI33ZLAA29vXL2moAHvi+CgGOk+RbyEUT0OIQRRgi4Bte1EMIRRPQ4hBFGAxhMSOfAuheBpCDKIAiyEkduRbCF20KXUkJy18rlw2T42TbHgYk6yqAizF1ziQ2JFrQXTRptCRnFrnqY3zuB4xIHMUT5FrsRQM05hH0jFiOZcxo3iKQoilYJjGPJKOEcu5BIkdORdLwTCNeSQdI5ZzCRI7cm7lsnlqbDipYNiQv4JhGp2nSceIqiO54FLpPAUydXKZqIqyUSjdmml0niYdI7qO5AKjeIpco2CIIqF4ikKgYAiMRGJHrlEwBEZijb1KHTu641iD3LlF+sO68n4mZ5bKXZI5u8WfVC74rXz8f9V3bGhJsZpuy6R/ftBdHbv0yIuH1O+uBjPduGiW7r12flVjpCGEz2cIMRQdib0KoRTaEkthm9mgJOi2TOPP39WxS//2wt+OP+53P/64nsk9hM9nCDGApZiqrO/cd/wDO6i3r1/rO/dlFNEE/WHdUFIf1NdbPp4z6zv3qa9/+AUAff1e8XuS9M9L0iMvjrJl7xjHayWEz2cIMYDEXpVoCm2BbDObhhCKp/2nuLLsVMdrJYTPZwgxgMRelWgKbYFsM5uGEIqnDTb62s2pjtdKCJ/PEGIAib0q0dyXMpBtZtMQQrfljYtmVXW8VkLoHA0hBlA8rUo096UcLJAmuComiHkojG7L9nPO0r+/8DcdO+HYpIHj9RRC52gIMYDO07qLpdMxlnmkgXOBeqHzNFCxFJdimUcaOBcIDYm9zmIpLsUyjzRwLhAaEnudxVJcimUeUrnBaO7qZ9Sy6mnNXf2M7urYVdWfj+lcIA4UT+ssluJSLPNIo2s0lnOBeFA8RaHNXf3MqI1EDWY68L1PZBARcGoUT4EKhNI1CqSJxI5CC6VrFEgTiR2FllbXaAhduMAgEjsKrf2cs9Qwafi384ZJVlXX6GAXbvfhXrmGunBJ7sgKiR2Ftr5zn/qPDV9P7z9W3ba9bFWL0CRK7Gb2GTPbY2bHzGzcSi0QmjS6Ruk8RWiSfmPfLek6Sc+lEAtQd2l0jdJ5itAkSuzuvtfd+fsmJiyNomOSMdLoGl25bJ4aG4av0zc2VH/fVCAtdes8NbPlkpZL0uzZs+v1
sghYGlv/Jh0jta7Rky975zJ4ZGjczlMz+72k94/yqzXu/uuB5/ynpH9194raSek8hZTOdrchbJkbQgwohko7T8f9xu7uH0snJGC4WAqXIcQAnIjLHZGZWAqXIcQAnCjp5Y6fNLMuSYslPW1mnemEFa6si31pCWEeaRUus94yN4QYgBMlKp66+68k/SqlWIIXQrEvDaHMI43CZQhb5oYQA3Aitu2tAsW+dMcAUB227a0Bin3pjgGgNkjsVaDYl+4YAGqDxF4Fin3pjgGgNrjnaRUo9qU7BoDaoHgKADlB8RQACorEDgCRIbEDQGRI7AAQGRI7AESGxA4AkSGxA0BkaFDKqY4d3TQHARgViT2HQtj6F0C4WIrJofWd+44n9UG9ff1a37kvo4gAhITEnkNsmQtgLCT2HGLLXABjIbHnEFvmAhgLxdMcYstcAGMhsefUtRfNJJEDGBVLMQAQGRI7AEQmV0sxdFsCwPhyk9jptgSAyuRmKYZuSwCoTG4SO92WAFCZ3CR2ui0BoDK5Sex0WwJAZXJTPKXbEgAqkyixm9l6SVdJek/SAUlfcPfDaQQ2GrotAWB8SZdinpXU6u4LJP1F0urkIQEAkkiU2N39d+5+dODhC5JKyUMCACSR5hr7FyU9dqpfmtlyScslafbs2Sm+bDHRhQvgVMZN7Gb2e0nvH+VXa9z91wPPWSPpqKTNpxrH3TdK2ihJ7e3tPqFoIYkuXABjGzexu/vHxvq9mX1e0pWSlro7CbsOxurCJbEDSHpVzOWSviXpMnc/kk5IGA9duADGkvSqmAclTZH0rJm9ZGYbUogJ46ALF8BYkl4Vc567z3L3hQP//HNageHU6MIFMJbcdJ5iCF24AMZCYs8punABnEpuNgEDAFSGxA4AkSGxA0BkSOwAEBkSOwBEhsQOAJEhsQNAZLiOvaDY9heIF4m9gNj2F4gbSzEFNNa2vwDyj8ReQGz7C8SNxF5AbPsLxI3EXkBs+wvEjeJpAbHtLxA3EntBse0vEC+WYgAgMiR2AIgMiR0AIkNiB4DIkNgBIDIkdgCIjLl7/V/U7A1JB+v+whPTLOkfWQdRJ0WZa1HmKRVnrkWZ5znuPn28J2WS2PPEzLa5e3vWcdRDUeZalHlKxZlrUeZZKZZiACAyJHYAiAyJfXwbsw6gjooy16LMUyrOXIsyz4qwxg4AkeEbOwBEhsReATNbb2Yvm9lOM/uVmU3NOqZaMbPPmNkeMztmZtFdZWBml5vZPjN7xcxWZR1PrZjZz8zs72a2O+tYasnMZpnZH81s78Dn9o6sYwoBib0yz0pqdfcFkv4iaXXG8dTSbknXSXou60DSZmYNkn4s6QpJF0q60cwuzDaqmvmFpMuzDqIOjkr6hrtfIOkSSV+L+D2tGIm9Au7+O3c/OvDwBUmlLOOpJXff6+6x3tX6YkmvuPur7v6epEclXZNxTDXh7s9JejPrOGrN3V939z8P/Ps7kvZKKvyNBkjs1fuipP/IOghMyExJh0543CWSQDTMrEXSRZJezDaS7HEHpQFm9ntJ7x/lV2vc/dcDz1mj8l/9NtcztrRVMtdI2SjHuCwsAmZ2hqRfSvoXd38763iyRmIf4O4fG+v3ZvZ5SVdKWuo5v0Z0vLlGrEvSrBMelyS9llEsSImZNaqc1De7+xNZxxMClmIqYGaXS/qWpKvd/UjW8WDC/iTpfDObY2anS7pB0pMZx4QEzMwkbZK0190fyDqeUJDYK/OgpCmSnjWzl8xsQ9YB1YqZfdLMuiQtlvS0mXVmHVNaBgrgt0nqVLnItsXd92QbVW2Y2SOS/kfSPDPrMrMvZR1TjSyRdJOkjw78t/mSmX0i66CyRucpAESGb+wAEBkSOwBEhsQOAJEhsQNAZEjsABAZEjsARIbEDgCRIbEDQGT+HzaMz5IBVbXSAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Training split in the standardized feature space, one series per class\n",
    "for class_label, legend_name in ((0, 'class 0'), (1, 'class 1')):\n",
    "    in_class = y_train == class_label\n",
    "    plt.scatter(X_train[in_class, 0], X_train[in_class, 1], label=legend_name)\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAD8CAYAAABjAo9vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAEqtJREFUeJzt3X+MXWWdx/HP13FMJ6FpA9PEtne6rdZUcH6UeKU0jZhYTcFYwKqNYGC7qzRLVos/ti6kBLoVFVMDCcHYlGAMsYuOEUcMrrNq3ZA1gk4p23a2dgsktXcgsQ4pkHQIw/DdP+70x7TTzsw5597znOe8XwkZ7unlOd9zQz89fZ7vea65uwAA8Xhb3gUAALJFsANAZAh2AIgMwQ4AkSHYASAyBDsARIZgB4DIEOwAEBmCHQAi8/Y8Ttre3u6LFy/O49QAUFh79uz5m7vPm+p9uQT74sWLNTAwkMepAaCwzOzIdN7HVAwARIZgB4DIEOwAEJlc5tgnMzo6qlqtptdffz3vUoIxa9YsVSoVtba25l0KgAIJJthrtZpmz56txYsXy8zyLid37q7h4WHVajUtWbIk73IAFEgwUzGvv/66LrnkEkJ9nJnpkksu4W8wAGYsmGCXRKifhc8DQBJBBTsARGdfr3R/p7R1bv3nvt6Gn5Jgn8LWrVv1ne98pyFj79mzR11dXVq6dKk2bdokvn8WiMy+XukXm6RXjkry+s9fbGp4uBPsObr11lu1c+dOHT58WIcPH9avfvWrvEsCkKXfbpNGRyYeGx2pH2+gwgZ7394hrbp3t5bc/oRW3btbfXuHUo/5yCOPqLu7Wz09PbrpppvO+fWHHnpIH/jAB9TT06NPfvKTOnHihCTpJz/5iTo7O9XT06OrrrpKkjQ4OKgrrrhCy5cvV3d3tw4fPjxhrJdeekmvvvqqVq5cKTPTzTffrL6+vtTXACAgr9RmdjwjwbQ7zkTf3iHd8dh+jYyOSZKGjo/ojsf2S5Kuv3xhojEHBwf1jW98Q7///e/V3t6ul19++Zz3rFu3Trfccosk6c4779TDDz+sL37xi9q2bZv6+/u1cOFCHT9+XJK0Y8cO3XbbbfrsZz+rN954Q2NjYxPGGhoaUqVSOfW6UqloaCj9H04AAjKnMj4NM8nxBirkHfv2/kOnQv2kkdExbe8/lHjM3bt361Of+pTa29slSRdffPE57zlw4IA++MEPqqurS7t27dLg4KAkadWqVdqwYYMeeuihUwG+cuVKffOb39S3v/1tHTlyRG1tbRPGmmw+nS6YBshh4Qo4ZfVdUuvE3/tqbasfb6BCBvuLx0dmdHw63H3KYN2wYYMefPBB7d+/X3ffffepHvMdO3bonnvu0dGjR7V8+XINDw/rxhtv1OOPP662tjatWbNGu3fvnjBWpVJRrXb6r2O1Wk0LFixIXD8mkdPCFXBK93pp7QPSnA5JVv+59oH68QYqZLAvmNs2o+PTsXr1avX29mp4eFiSJp2Kee211zR//nyNjo5q165dp44///zzWrFihbZt26b29nYdPXpUL7zwgt71rndp06ZNuvbaa7Vv374JY82fP1+zZ8/WU089JXfXI488ouuuuy5x/ZhETgtXwATd66UvH5C2Hq//bHCoSwUN9s1rlqmttWXCsbbWFm1esyzxmO973/u0ZcsWfehDH1JPT4++8pWvnPOer3/961qxYoU++tGP6r3vfe/pejZvVldXlzo7O3XVVVepp6dHP/7xj9XZ2anly5frz3/+s26++eZzxvve976nz3/+81q6dKne/e5365prrklcPyaR08IVkDfLo3e6Wq362V+0cfDgQV166aXTHqNv75C29x/Si8dHtGBumzavWZZ44TRkM/1ccIb7O8+zcNVRv3MCCsbM9rh7dar3FbIrRqp3v8QY5MjQ6rvqc+pnTsc0YeEKyFvqqRgz6zCz35nZQTMbNLPb
sigMSC2nhSsgb1ncsb8p6avu/oyZzZa0x8x+7e7/m8HYQDrd6wlylE7qO3Z3f8ndnxn/99ckHZTEHAkA5CTTrhgzWyzpcklPZzkuAGD6Mgt2M7tI0k8lfcndX53k1zea2YCZDRw7diyr0wIAzpJJsJtZq+qhvsvdH5vsPe6+092r7l6dN29eFqdtikZu27tlyxZ1dHTooosuasj4AMopi64Yk/SwpIPufl/6kspj7dq1+uMf/5h3GQAik8Ud+ypJN0n6sJk9O/7PxzIY98IasLlTM7ftlaQrr7xS8+fPT103AJwpdbuju/+3pOZuS3hyc6eTD56c3NxJStza1uxtewGgUQq5V0wjNndq9ra9ANAoxQz2Bmzu1OxtewGgUYoZ7Of79pEU30rS7G17AaBRihnsDfhWkjy27f3a176mSqWiEydOqFKpaOvWrYnrB4CTCrttr/b11ufUX6nV79RX3xXlniBs2wvgpOi37WVzJwCYXDGnYgAA5xVUsOcxLRQyPg8ASQQT7LNmzdLw8DBhNs7dNTw8rFmzZuVdCoCCCWaOvVKpqFariZ0fT5s1a5YqleQtnADKKZhgb21t1ZIlS/IuAwAKL5ipGABANgh2AIgMwQ4AkSHYkU4D9sUHkE4wi6cooAbsiw8gPe7YkVwD9sUHkB7BjuQasC8+gPQIdiTXgH3xAaRHsCO5BuyLDyA9gh3Jda+X1j4gzemQZPWfax9g4TRPdClBdMUgLfbFDwddShjHHTsQC7qUMI5gB2JBlxLGEexALOhSwjiCHYgFXUoYR7AjbqF3iWRZH11KGEdXDOIVepdII+qjSwnijh0xC71LJPT6UFgEO+IVepdI6PWhsAh2xCv0LpHQ60NhEeyI1+q7dO7/4m8Lp0uELhY0CMGOeP3lKUlvnXXwrfHjAaCLBQ1i7t70k1arVR8YGGj6eVEy/3ax5GPnHrcW6e6Xm18PkJKZ7XH36lTv444d8Zos1C90HIgEwY54WcvMjgORyCTYzez7ZvZXMzuQxXhAJt6/YWbHgUhkdcf+A0lXZzQWkI2P3ydVP3f6Dt1a6q8/fl++dQENlsmWAu7+pJktzmIsIFMfv48gR+kwxw4AkWlasJvZRjMbMLOBY8eONeu0AFA6TQt2d9/p7lV3r86bN69ZpwWA0mEqBgAik1W746OS/iBpmZnVzOxzWYwLAJi5rLpibshiHABAekzFAEBkCHYAiAzBDgCRIdgBIDIEOwBEhmAHgMgQ7AAQGYIdACJDsANAZAh2AIgMwQ4AkSHYASAyBDsARIZgB4DIEOwAEBmCHQAiQ7ADQGQIdgCIDMEOAJEh2AEgMgQ7AESGYEfc9vVK93dKW+fWf+7rzbsioOHenncBQMPs65V+sUkaHam/fuVo/bUkda/Pry6gwbhjR7x+u+10qJ80OlI/DkSMYEe8XqnN7DgQCYId8ZpTmdlxIBIEO+K1+i6ptW3isda2+nEgYgQ70gm566R7vdRzo2Qt9dfWUn+dZuE06+sN+fNDYdEVg+RC7zrZ1yv9z79LPlZ/7WP114uuTFZf1tcb+ueHwuKOHcmF3nWSdX2hjweMI9iRXOhdJ1nXF/p4wDiCHcmF3nWSdX2hjweMI9izVLaFsNC7TrKuL/TxgHEEe1ZOLoS9clSSn14Iizncu9dLax+Q5nRIsvrPtQ+Es/CXdX2hjweMM3dv+kmr1aoPDAw0/bwNdX/neKifZU6H9OUDza8HQHTMbI+7V6d6H3fsWWEhDEAgMgl2M7vazA6Z2XNmdnsWYxYOC2EAApE62M2sRdJ3JV0j6TJJN5jZZWnHLRwWwgAEIos79iskPefuL7j7G5J+JOm6DMYtliIshJWta0cq5zWj9LLYUmChpDNXDWuSVmQwbvF0rw8ryM9UxsfXy3jNgLK5Y7dJjp3TamNmG81swMwGjh07lsFpMSNlfHy9jNcMKJtgr0nqOON1RdKLZ7/J3Xe6e9Xdq/PmzcvgtJiRMnbtlPGaAWUT
7H+S9B4zW2Jm75D0GUmPZzAuslTGrp0yXjOgDILd3d+U9AVJ/ZIOSup198G04yJjZezaKeM1A8poP3Z3/6WkX2YxFhrk5GLhb7fVpyLmVOoBl3IRsW/vkLb3H9KLx0e0YG6bNq9ZpusvX5hBwRlo0DUDoWNLASTWt3dIdzy2XyOjY6eOtbW26FvrusIJdyAibCmAhtvef2hCqEvSyOiYtvcfyqkiABLBjhRePD4yo+MAmoNgR2IL5rbN6DiA5iDYkdjmNcvU1toy4Vhba4s2r1mWU0UApIy6YlBOJxdIs+yKCbrLBigIgh2pXH/5wsyC9+wum6HjI7rjsf2nzgNgepiKQTDosgGyQbAjGHTZANkg2BEMumyAbBDsCEYjumz69g5p1b27teT2J7Tq3t3q2zuUtkwgeCyeIhhZd9mwGIuyItgRlCy7bC60GEuwI2ZMxSBaLMairAh2RIvFWJQVwY5oseUByoo5dkSrEVseAEVAsCNqWS7GAkXBVAwARIZgB4DIEOwAEBmCHQAiw+IpMAN8EQiKgGAHpom9Z1AUTMUA08QXgaAoCHZgmth7BkVBsAPTxN4zKAqCHZgm9p5BUbB4CkwTe8+gKAh2YAbYewZFwFQMAESGYAeAyBDsABAZ5thLpIyPw5fxmgGCvSTK+Dh8Ga8ZkJiKKY0yPg5fxmsGpJTBbmafNrNBM3vLzKpZFYXslfFx+DJeMyClv2M/IGmdpCczqAUNVMbH4ct4zYCUMtjd/aC78/faAijj4/BlvGZAYvG0NMr4OHwZrxmQJHP3C7/B7DeS3jnJL21x95+Pv+e/JP2Luw9cYJyNkjZK0qJFi95/5MiRpDUDQCmZ2R53n3I9c8o7dnf/SBYFuftOSTslqVqtXvhPEwBAYrQ7AkBk0rY7fsLMapJWSnrCzPqzKQsAkFSqxVN3/5mkn2VUCwAgA3TFIBX2YgHCQ7AjMfZiAcLE4ikSYy8WIEwEOxJjLxYgTAQ7EmMvFiBMBDsSYy8WIEwsnmYo9A6RO/v269Gnj2rMXS1mumFFh+65vivxeOzFAoSJYM9I6B0id/bt1w+f+sup12Pup16nDfcQrg/AaUzFZCT0DpFHnz46o+MAiotgz0joHSJj59nF83zHARQXwZ6R0DtEWsxmdBxAcRHsGQm9Q+SGFR0zOg6guEq9eJplF0voHSInF0iz7IoBEKYpv0GpEarVqg8MnPfLlpri7C4WqX6H/a11XcGEMQCcabrfoFTaqZjQu1gAIKnSBnvoXSwAkFRpgz30LhYASKq0wR56FwsAJFXarpjQu1gAIKnSBrvEPicA4lTaqRgAiBXBDgCRIdgBIDKlnmNHeEL/shKgCAh2BCP0LysBioKpGASDbR6AbBDsCAbbPADZINgRDLZ5ALJBsCMYbPMAZIPF04CVrUOEbR6AbBDsgSprhwjbPADpMRUTKDpEACRFsAeKDhEASRHsgaJDBEBSBHug6BABkBSLp4GiQwRAUqmC3cy2S1or6Q1Jz0v6B3c/nkVhoEMEQDJpp2J+LanT3bsl/Z+kO9KXBABII1Wwu/t/uvub4y+fklRJXxIAII0sF0//UdJ/ZDgeACCBKefYzew3kt45yS9tcfefj79ni6Q3Je26wDgbJW2UpEWLFiUqtmyP2ANAEubu6QYw+3tJ/yRptbufmM5/U61WfWBgYEbnOfsRe6ne/vetdV2EO4BSMLM97l6d6n2ppmLM7GpJ/yrp2umGelI8Yg8A05N2jv1BSbMl/drMnjWzHRnUNCkesQeA6UnVx+7uS7MqZCoL5rZpaJIQ5xF7AJioMFsK8Ig9AExPYbYU4BF7AJiewgS7xCP2ADAdhZmKAQBMD8EOAJEh2AEgMgQ7AESGYAeAyBDsABCZ1JuAJTqp2TFJR5p+4ny0S/pb3kUEiM9lcnwu58dnI/2du8+b6k25BHuZmNnAdHZjKxs+l8nxuZwfn830MRUDAJEh2AEgMgR74+3Mu4BA8blMjs/l/Phs
pok5dgCIDHfsABAZgr0JzOzTZjZoZm+ZWelX9c3sajM7ZGbPmdntedcTAjP7vpn91cwO5F1LSMysw8x+Z2YHx38P3ZZ3TUVAsDfHAUnrJD2ZdyF5M7MWSd+VdI2kyyTdYGaX5VtVEH4g6eq8iwjQm5K+6u6XSrpS0j/z/8vUCPYmcPeD7s63btddIek5d3/B3d+Q9CNJ1+VcU+7c/UlJL+ddR2jc/SV3f2b831+TdFASX8owBYIdzbZQ0tEzXtfEb1RMg5ktlnS5pKfzrSR8hfoGpZCZ2W8kvXOSX9ri7j9vdj0Bs0mO0ZqFCzKziyT9VNKX3P3VvOsJHcGeEXf/SN41FERNUscZryuSXsypFhSAmbWqHuq73P2xvOspAqZi0Gx/kvQeM1tiZu+Q9BlJj+dcEwJlZibpYUkH3f2+vOspCoK9CczsE2ZWk7RS0hNm1p93TXlx9zclfUFSv+oLYb3uPphvVfkzs0cl/UHSMjOrmdnn8q4pEKsk3STpw2b27Pg/H8u7qNDx5CkARIY7dgCIDMEOAJEh2AEgMgQ7AESGYAeAyBDsABAZgh0AIkOwA0Bk/h9KNAHoCqXDewAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Test split in the standardized feature space, one series per class\n",
    "for class_label, legend_name in ((0, 'class 0'), (1, 'class 1')):\n",
    "    in_class = y_test == class_label\n",
    "    plt.scatter(X_test[in_class, 0], X_test[in_class, 1], label=legend_name)\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<br>\n",
    "<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Implement ADALINE Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Adaline1():\n",
    "    def __init__(self, num_features):\n",
    "        self.num_features = num_features\n",
    "        self.weights = torch.zeros(num_features, 1, \n",
    "                                   dtype=torch.float)\n",
    "        self.bias = torch.zeros(1, dtype=torch.float)\n",
    "\n",
    "    def forward(self, x):\n",
    "        netinputs = torch.add(torch.mm(x, self.weights), self.bias)\n",
    "        activations = netinputs\n",
    "        return activations.view(-1)\n",
    "        \n",
    "    def backward(self, x, yhat, y):  \n",
    "        \n",
    "        grad_loss_yhat = y - yhat\n",
    "        \n",
    "        grad_yhat_weights = x\n",
    "        grad_yhat_bias = 1.\n",
    "        \n",
    "        # Chain rule: inner times outer\n",
    "        grad_loss_weights = 2* -torch.mm(grad_yhat_weights.t(),\n",
    "                                         grad_loss_yhat.view(-1, 1)) / y.size(0)\n",
    "\n",
    "        grad_loss_bias = 2* -torch.sum(grad_yhat_bias*grad_loss_yhat) / y.size(0)\n",
    "        \n",
    "        return (-1)*grad_loss_weights, (-1)*grad_loss_bias"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<br>\n",
    "<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Training and Evaluation Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "####################################################\n",
    "##### Training and evaluation wrappers\n",
    "###################################################\n",
    "\n",
    "def loss(yhat, y):\n",
    "    return torch.mean((yhat - y)**2)\n",
    "\n",
    "\n",
    "def train(model, x, y, num_epochs,\n",
    "          learning_rate=0.01, seed=123, minibatch_size=10):\n",
    "    cost = []\n",
    "    \n",
    "    torch.manual_seed(seed)\n",
    "    for e in range(num_epochs):\n",
    "        \n",
    "        #### Shuffle epoch\n",
    "        shuffle_idx = torch.randperm(y.size(0), dtype=torch.long)\n",
    "        minibatches = torch.split(shuffle_idx, minibatch_size)\n",
    "        \n",
    "        for minibatch_idx in minibatches:\n",
    "\n",
    "            #### Compute outputs ####\n",
    "            yhat = model.forward(x[minibatch_idx])\n",
    "\n",
    "            #### Compute gradients ####\n",
    "            negative_grad_w, negative_grad_b = \\\n",
    "                model.backward(x[minibatch_idx], yhat, y[minibatch_idx])\n",
    "\n",
    "            #### Update weights ####\n",
    "            model.weights += learning_rate * negative_grad_w\n",
    "            model.bias += learning_rate * negative_grad_b\n",
    "            \n",
    "            #### Logging ####\n",
    "            minibatch_loss = loss(yhat, y[minibatch_idx])\n",
    "            print('    Minibatch MSE: %.3f' % minibatch_loss)\n",
    "\n",
    "        #### Logging ####\n",
    "        yhat = model.forward(x)\n",
    "        curr_loss = loss(yhat, y)\n",
    "        print('Epoch: %03d' % (e+1), end=\"\")\n",
    "        print(' | MSE: %.5f' % curr_loss)\n",
    "        cost.append(curr_loss)\n",
    "\n",
    "    return cost"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<br>\n",
    "<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train ADALINE Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    Minibatch MSE: 0.500\n",
      "    Minibatch MSE: 0.341\n",
      "    Minibatch MSE: 0.220\n",
      "    Minibatch MSE: 0.245\n",
      "    Minibatch MSE: 0.157\n",
      "    Minibatch MSE: 0.133\n",
      "    Minibatch MSE: 0.144\n",
      "Epoch: 001 | MSE: 0.12142\n",
      "    Minibatch MSE: 0.107\n",
      "    Minibatch MSE: 0.147\n",
      "    Minibatch MSE: 0.064\n",
      "    Minibatch MSE: 0.079\n",
      "    Minibatch MSE: 0.185\n",
      "    Minibatch MSE: 0.063\n",
      "    Minibatch MSE: 0.135\n",
      "Epoch: 002 | MSE: 0.09932\n",
      "    Minibatch MSE: 0.093\n",
      "    Minibatch MSE: 0.064\n",
      "    Minibatch MSE: 0.128\n",
      "    Minibatch MSE: 0.099\n",
      "    Minibatch MSE: 0.079\n",
      "    Minibatch MSE: 0.157\n",
      "    Minibatch MSE: 0.080\n",
      "Epoch: 003 | MSE: 0.09693\n",
      "    Minibatch MSE: 0.131\n",
      "    Minibatch MSE: 0.146\n",
      "    Minibatch MSE: 0.050\n",
      "    Minibatch MSE: 0.095\n",
      "    Minibatch MSE: 0.106\n",
      "    Minibatch MSE: 0.072\n",
      "    Minibatch MSE: 0.102\n",
      "Epoch: 004 | MSE: 0.09658\n",
      "    Minibatch MSE: 0.107\n",
      "    Minibatch MSE: 0.204\n",
      "    Minibatch MSE: 0.149\n",
      "    Minibatch MSE: 0.054\n",
      "    Minibatch MSE: 0.060\n",
      "    Minibatch MSE: 0.056\n",
      "    Minibatch MSE: 0.069\n",
      "Epoch: 005 | MSE: 0.09657\n",
      "    Minibatch MSE: 0.068\n",
      "    Minibatch MSE: 0.111\n",
      "    Minibatch MSE: 0.092\n",
      "    Minibatch MSE: 0.115\n",
      "    Minibatch MSE: 0.157\n",
      "    Minibatch MSE: 0.074\n",
      "    Minibatch MSE: 0.087\n",
      "Epoch: 006 | MSE: 0.09650\n",
      "    Minibatch MSE: 0.057\n",
      "    Minibatch MSE: 0.070\n",
      "    Minibatch MSE: 0.133\n",
      "    Minibatch MSE: 0.127\n",
      "    Minibatch MSE: 0.062\n",
      "    Minibatch MSE: 0.153\n",
      "    Minibatch MSE: 0.103\n",
      "Epoch: 007 | MSE: 0.09683\n",
      "    Minibatch MSE: 0.102\n",
      "    Minibatch MSE: 0.110\n",
      "    Minibatch MSE: 0.101\n",
      "    Minibatch MSE: 0.065\n",
      "    Minibatch MSE: 0.126\n",
      "    Minibatch MSE: 0.124\n",
      "    Minibatch MSE: 0.076\n",
      "Epoch: 008 | MSE: 0.09648\n",
      "    Minibatch MSE: 0.120\n",
      "    Minibatch MSE: 0.056\n",
      "    Minibatch MSE: 0.100\n",
      "    Minibatch MSE: 0.102\n",
      "    Minibatch MSE: 0.106\n",
      "    Minibatch MSE: 0.075\n",
      "    Minibatch MSE: 0.144\n",
      "Epoch: 009 | MSE: 0.09740\n",
      "    Minibatch MSE: 0.073\n",
      "    Minibatch MSE: 0.071\n",
      "    Minibatch MSE: 0.084\n",
      "    Minibatch MSE: 0.152\n",
      "    Minibatch MSE: 0.099\n",
      "    Minibatch MSE: 0.108\n",
      "    Minibatch MSE: 0.118\n",
      "Epoch: 010 | MSE: 0.09636\n",
      "    Minibatch MSE: 0.058\n",
      "    Minibatch MSE: 0.070\n",
      "    Minibatch MSE: 0.145\n",
      "    Minibatch MSE: 0.081\n",
      "    Minibatch MSE: 0.093\n",
      "    Minibatch MSE: 0.127\n",
      "    Minibatch MSE: 0.115\n",
      "Epoch: 011 | MSE: 0.09638\n",
      "    Minibatch MSE: 0.123\n",
      "    Minibatch MSE: 0.091\n",
      "    Minibatch MSE: 0.085\n",
      "    Minibatch MSE: 0.093\n",
      "    Minibatch MSE: 0.091\n",
      "    Minibatch MSE: 0.143\n",
      "    Minibatch MSE: 0.081\n",
      "Epoch: 012 | MSE: 0.09718\n",
      "    Minibatch MSE: 0.096\n",
      "    Minibatch MSE: 0.076\n",
      "    Minibatch MSE: 0.149\n",
      "    Minibatch MSE: 0.092\n",
      "    Minibatch MSE: 0.116\n",
      "    Minibatch MSE: 0.093\n",
      "    Minibatch MSE: 0.091\n",
      "Epoch: 013 | MSE: 0.09638\n",
      "    Minibatch MSE: 0.095\n",
      "    Minibatch MSE: 0.104\n",
      "    Minibatch MSE: 0.107\n",
      "    Minibatch MSE: 0.120\n",
      "    Minibatch MSE: 0.102\n",
      "    Minibatch MSE: 0.045\n",
      "    Minibatch MSE: 0.124\n",
      "Epoch: 014 | MSE: 0.09685\n",
      "    Minibatch MSE: 0.121\n",
      "    Minibatch MSE: 0.051\n",
      "    Minibatch MSE: 0.095\n",
      "    Minibatch MSE: 0.122\n",
      "    Minibatch MSE: 0.030\n",
      "    Minibatch MSE: 0.158\n",
      "    Minibatch MSE: 0.121\n",
      "Epoch: 015 | MSE: 0.09745\n",
      "    Minibatch MSE: 0.080\n",
      "    Minibatch MSE: 0.119\n",
      "    Minibatch MSE: 0.091\n",
      "    Minibatch MSE: 0.095\n",
      "    Minibatch MSE: 0.044\n",
      "    Minibatch MSE: 0.092\n",
      "    Minibatch MSE: 0.180\n",
      "Epoch: 016 | MSE: 0.09693\n",
      "    Minibatch MSE: 0.054\n",
      "    Minibatch MSE: 0.075\n",
      "    Minibatch MSE: 0.184\n",
      "    Minibatch MSE: 0.105\n",
      "    Minibatch MSE: 0.121\n",
      "    Minibatch MSE: 0.066\n",
      "    Minibatch MSE: 0.096\n",
      "Epoch: 017 | MSE: 0.09699\n",
      "    Minibatch MSE: 0.076\n",
      "    Minibatch MSE: 0.050\n",
      "    Minibatch MSE: 0.198\n",
      "    Minibatch MSE: 0.105\n",
      "    Minibatch MSE: 0.054\n",
      "    Minibatch MSE: 0.136\n",
      "    Minibatch MSE: 0.099\n",
      "Epoch: 018 | MSE: 0.09672\n",
      "    Minibatch MSE: 0.158\n",
      "    Minibatch MSE: 0.099\n",
      "    Minibatch MSE: 0.087\n",
      "    Minibatch MSE: 0.070\n",
      "    Minibatch MSE: 0.103\n",
      "    Minibatch MSE: 0.112\n",
      "    Minibatch MSE: 0.081\n",
      "Epoch: 019 | MSE: 0.09638\n",
      "    Minibatch MSE: 0.095\n",
      "    Minibatch MSE: 0.093\n",
      "    Minibatch MSE: 0.111\n",
      "    Minibatch MSE: 0.147\n",
      "    Minibatch MSE: 0.083\n",
      "    Minibatch MSE: 0.102\n",
      "    Minibatch MSE: 0.065\n",
      "Epoch: 020 | MSE: 0.09635\n"
     ]
    }
   ],
   "source": [
    "model = Adaline1(num_features=X_train.size(1))\n",
    "cost = train(model, \n",
    "             X_train, y_train.float(),\n",
    "             num_epochs=20,\n",
    "             learning_rate=0.1,\n",
    "             seed=123,\n",
    "             minibatch_size=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<br>\n",
    "<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate ADALINE Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot Loss (MSE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZIAAAEKCAYAAAA4t9PUAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3XucXHV9//HXZ2Z3J8nO5jqzCeQeCMREASFExIIKFdAK+KtYCV4AaVFbFH991J/8HlpU2v4epfXSWnm0oqAoJVitSFTkUlCp5ZIsEEJCSLIJuWyuu7nuJdnr5/fHObuZDLO7Jzs7M7uZ9/PxmMecObf5zNmZ/Zzv+Z7v92vujoiIyFDFSh2AiIiMbkokIiKSFyUSERHJixKJiIjkRYlERETyokQiIiJ5USIREZG8KJGIiEhelEhERCQvFaUOoBhSqZTPmTOn1GGIiIwqL7zwQpO7pwdbrywSyZw5c6irqyt1GCIio4qZbY2yni5tiYhIXpRIREQkL0okIiKSFyUSERHJixKJiIjkRYlERETyokQiIiJ5USIZwEMvNXD/c5FuoxYRKVtKJAP41erdSiQiIoNQIhlAuiZBU0t7qcMQERnRlEgGkE5Wsa+1g67unlKHIiIyYimRDCBdk8Ad9rd2lDoUEZERS4lkAKlkAoBGXd4SEemXEskA0jVhImlWIhER6Y8SyQB6SyRNLbq0JSLSHyWSAahEIiIyuIImEjO7wszWm1m9md2WY/nFZvaimXWZ2TUZ888xs2fNbK2ZrTazD2csm2tmz5vZRjP7sZlVFSr+6kQF46riugVYRGQABUskZhYH7gLeCywElprZwqzVtgE3AA9kzW8DPu7ui4ArgH8ys4nhsjuBb7r7fOAAcFNhPkEglUyoRCIiMoBClkiWAPXuvtndO4AHgaszV3D3Le6+GujJmr/B3TeG0zuBvUDazAy4BPhpuOp9wAcK+BnUKFFEZBCFTCTTge0ZrxvCeSfEzJYAVcAmYApw0N27Btunmd1sZnVmVtfY2Hiib9snlaxSiUREZACFTCSWY56f0A7MTgF+BNzo7j0nsk93v9vdF7v74nQ6fSJvexyVSEREBlbIRNIAzMx4PQPYGXVjMxsP/Ar4krs/F85uAiaaWcVQ9jkUqWSCA22ddKqbFBGRnAqZSFYC88O7rKqAa4HlUTYM138I+KG7/6R3vrs78Bug9w6v64GHhzXqLL23AO9TWxIRkZwKlkjCeoxbgMeAdcB/uPtaM7vDzK4CMLPzzawB+BDwHTNbG27+J8DFwA1mtip8nBMu+wLwl2ZWT1Bnck+hPgNkdJOiehIRkZwqBl9l6Nz9EeCRrHm3Z0yvJLg8lb3d/cD9/exzM8EdYUXR1yix5SgwoVhvKyIyaqhl+yDSvd2kNOvSlohILkokg1APwCIiA1MiGcTYqjg1iQrVkYiI9EOJJIJUTUIlEhGRfiiRRJBOJmhSiUREJCclkghSNVUqkYiI9EOJJAKVSERE+qdEEkEqmeDw0S6OdnaXOhQRkRFHiSSCvm5SWtWWREQkmxJJBOomRUSkf0okEfSWSFRPIiLyRkokEaRq1LpdRKQ/SiQRpJJVgC5tiYjkokQSQaIizvgxFRopUUQkByWSiNI1CZVIRERyUCKJSGO3i4jkpkQSUSqpEomISC5KJBEFJRI1SBQRyaZEElEqmaClvYsjHeomRUQkkxJJRH2NElVPIiJyHCWSiHrHbt+rehIRkeMokUSkEomISG5KJBGp40YRkdyUSCKaEnaTohKJiMjxlEgiqozHmDSuUiUSEZEsSiQnQN2kiIi8kRLJCUgl1U2KiEg2JZITkK5JaEwSEZEsSiQnIJ1M0NTcgbuXOhQRkRFDieQEpGoSHOnsplXdpIiI9CloIjGzK8xsvZnVm9ltOZZfbGYvmlmXmV2TtexRMztoZr/Mmv8DM3vdzFaFj3MK+Rky9bZu
19jtIiLHFCyRmFkcuAt4L7AQWGpmC7NW2wbcADyQYxf/CHysn91/3t3PCR+rhinkQWnsdhGRNypkiWQJUO/um929A3gQuDpzBXff4u6rgZ7sjd39SaC5gPGdMJVIRETeaMBEYmZxM/uvIe57OrA943VDOG84/J2ZrTazb5pZYpj2OahUTdC6XSUSEZFjBkwk7t4NtJnZhCHs23Ltcgj7yfZ/gQXA+cBk4As539zsZjOrM7O6xsbGYXhbmFKdIGYqkYiIZKqIsM5R4BUzewJo7Z3p7p8dZLsGYGbG6xnAzhOOMIu77won283s+8Bf9bPe3cDdAIsXLx6W+3XjMWNydZVKJCIiGaIkkl+FjxO1EphvZnOBHcC1wHVD2M9xzOwUd99lZgZ8AFiT7z5PRDB2u4bcFRHpNWgicff7zKwKOCOctd7dOyNs12VmtwCPAXHgXndfa2Z3AHXuvtzMzgceAiYBV5rZV919EYCZ/TfBJaykmTUAN7n7Y8C/m1ma4NLZKuBTJ/qh86HW7SIixxs0kZjZu4D7gC0E/7xnmtn17v70YNu6+yPAI1nzbs+YXklwySvXthf1M/+Swd63kNLJBJsbWwdfUUSkTES5tPV14DJ3Xw9gZmcAy4DzChnYSJUKSyTuTnB1TUSkvEVpR1LZm0QA3H0DUFm4kEa2dDJBR1cPh492lToUEZERIUqJpM7M7gF+FL7+CPBC4UIa2TLHbp8wtmzzqYhInyglkk8Da4HPArcCr1LkCu6RRGO3i4gcb8ASSdhf1j3u/lHgG8UJaWTLLJGIiEi0lu3p8PZfAVLJsJsUlUhERIBodSRbgP8xs+Uc37K9LEsok8ZVEY+ZSiQiIqEoiWRn+IgBNYUNZ+SLxYwp1VUqkYiIhKLUkSTd/fNFimdUSNckaGpRNykiIhCtjuTcIsUyagT9balEIiIC0S5trQrrR37C8XUkPytYVCNcuibBhj0jaswtEZGSiZJIJgP7gMw+rhwo20SSSiZoUjcpIiJAtN5/byxGIKNJuiZBZ7dz6EgnE8fpzmgRKW/91pGY2X9kTN+ZtezxQgY10qktiYjIMQNVts/PmH5P1rJ0AWIZNXpbtyuRiIgMnEgGGp52WIauHa1qexOJGiWKiAxYRzLOzN5KkGzGhtMWPsYWI7iRSh03iogcM1Ai2cWxjhp3c3ynjbsLFtEoMGFsJZVxU6NEEREGSCTu/u5iBjKamJkaJYqIhKKMRyI5BN2kKJGIiCiRDJFKJCIiASWSIUonVSIREYEB6kjMbMDOGt39xeEPZ/RI1VSxr7WDnh4nFlM3KSJSvga6a+vr4fMYYDHwMsGtv2cBzwN/UNjQRrZ0MkF3j3OgrYMp4e3AIiLlqN9LW+7+7vDOra3Aue6+2N3PA94K1BcrwJEqpUaJIiJAtDqSBe7+Su8Ld18DnFO4kEaHdFgKaWpWWxIRKW9RupFfZ2bfA+4n6Brlo8C6gkY1CvT1t9VytMSRiIiUVpREciPwaeDW8PXTwL8WLKJRIqWOG0VEgGjjkRw1s38DHnH39UWIaVSoSVSQqIipmxQRKXuD1pGY2VXAKuDR8PU54dC7ZU3dpIiIBKJUtn8ZWAIcBHD3VcCcAsY0aqibFBGRaImky90PDWXnZnaFma03s3ozuy3H8ovN7EUz6zKza7KWPWpmB83sl1nz55rZ82a20cx+bGYlG+tWJRIRkWiJZI2ZXQfEzWy+mf0L8MxgG5lZHLgLeC+wEFhqZguzVtsG3AA8kGMX/wh8LMf8O4Fvuvt84ABwU4TPUBAqkYiIREsknwEWAe0E//APAZ+LsN0SoN7dN7t7B/AgcHXmCu6+xd1XAz3ZG7v7k0Bz5jwzM+AS4KfhrPuAD0SIpSDSyaCblK7uN4QvIlI2BrxrKyxVfNXdPw988QT3PR3YnvG6AXjbCe4j2xTgoLt3Zexzeq4Vzexm4GaAWbNm5fm2uaVrErjD/rYOamvGFOQ9RERGugFLJO7eDZw3xH3n
6skw37HeI+/T3e8Ou3VZnE6n83zb3DTkrohItAaJL4W3+/4EaO2d6e4/G2S7BmBmxusZwM4TjvB4TcBEM6sISyXDsc8h623drrYkIlLOoiSSycA+grqJXg4MlkhWAvPNbC6wA7gWuG4oQfa9qbub2W+AawjqXK4HHs5nn/lQiUREJFrL9huHsmN37zKzW4DHgDhwr7uvNbM7gDp3X25m5wMPAZOAK83sq+6+CMDM/htYACTNrAG4yd0fA74APGhmfwu8BNwzlPiGw7ESiRKJiJSvQROJmY0huMV2EcHYJAC4+ycG29bdHwEeyZp3e8b0SoLLU7m2vaif+ZsJ7ggruepEBeOq4iqRiEhZi3L774+AacDlwO8I/vE3D7hFGVGjRBEpd1ESyenu/tdAq7vfB/wR8JbChjV6qFGiiJS7KImkM3w+aGZvBiagvrb6pJJVKpGISFmLkkjuNrNJwF8Dy4FXgX8oaFSjiEokIlLuoty19b1w8nfAvMKGM/qkkgkOtHXS2d1DZTxKXhYROblEuWvr9lzz3f2O4Q9n9Om9BXhfSwfTJqibFBEpP1FOoVszHt0EvfnOKWBMo4oaJYpIuYtyaevrma/N7GsEdSWCGiWKiAzlov44VFfSJ60SiYiUuSh1JK9wrIfdOJAGVD8S6ru0pRKJiJSpKJ02vj9jugvYkzEeSNkbWxUnmahQiUREylaURJLdHcr4YKDCgLvvH9aIRiG1JRGRchYlkbxIMK7IAYKBpSYSjLUOwSWvsq8vSau/LREpY1Eq2x8FrnT3lLtPIbjU9TN3n+vuZZ9EAFI1VaojEZGyFSWRnB92Bw+Au/8aeGfhQhp90skETSqRiEiZipJImszsS2Y2x8xmm9kXCUZMlFAqmeDw0S6OdnaXOhQRkaKLkkiWEtzy+xDwc6A2nCehvm5SWjV2u4iUnygt2/cDtwKEvQAfdHcfeKvyktlNyvSJY0scjYhIcfVbIjGz281sQTidMLOngHpgj5n9YbECHA36uklRPYmIlKGBLm19GFgfTl8frltLUNH+/woc16iSqlHrdhEpXwMlko6MS1iXA8vcvdvd1xGt/UnZSCWrAJVIRKQ8DZRI2s3szWaWBt4NPJ6xbFxhwxpdEhVxxo+pUIlERMrSQCWLW4GfEtyx9U13fx3AzN4HvFSE2EYVdZMiIuWq30Ti7s8DC3LMfwR45I1blLeUukkRkTKlQcaHSVAiUTsSESk/SiTDJF2jEomIlCclkmGSSiZoae/iSIe6SRGR8hLpNl4zuxCYk7m+u/+wQDGNSpljt8+crJvaRKR8RBlq90fAacAqoPd02wElkgy9Y7fvbVYiEZHyEqVEshhYqP61BpZZIhERKSdR6kjWANOGsnMzu8LM1ptZvZndlmP5xWb2opl1mdk1WcuuN7ON4eP6jPm/Dfe5KnzUDiW24ZbZcaOISDmJUiJJAa+a2Qqg77+ku1810EZmFgfuAt4DNAArzWy5u7+asdo24Abgr7K2nQx8maA05MAL4bYHwlU+4u51EWIvmim93aSoRCIiZSZKIvnKEPe9BKh3980AZvYgcDXQl0jcfUu4rCdr28uBJ8Iu7DGzJ4ArgGVDjKXgKuMxJo2rVIlERMpOlPFIfjfEfU8Htme8bgDelse20zNef9/MuoH/BP52pNTfqJsUESlHg9aRmNkFZrbSzFrMrMPMus3scIR9W455Uf/hD7TtR9z9LcBF4eNjOXdgdrOZ1ZlZXWNjY8S3zY+6SRGRchSlsv3bBEPrbgTGAn8azhtMAzAz4/UMYGfEuPrd1t13hM/NwAMEl9DewN3vdvfF7r44nU5HfNv8qJsUESlHkVq2u3s9EA/HI/k+8K4Im60E5pvZXDOrAq4FlkeM6zHgMjObFA7vexnwmJlVmFkKwMwqgfcT3FU2IqhEIiLlKEple1uYCFaZ2T8Au4DqwTZy9y4zu4UgKcSBe919rZndAdS5+3IzOx94CJgEXGlmX3X3Re6+38z+
hiAZAdwRzqsmSCiV4T7/C/juCX7mgknXJDjS2U1rexfVCY39JSLlIcp/u48RlFxuAf43wSWnD0bZea4u59399ozplQSXrXJtey9wb9a8VuC8KO9dCumMtiRKJCJSLqLctbXVzMYCp7j7V4sQ06iVymjdPic1aKFNROSkEOWurSsJ+tl6NHx9jplFresoK2m1bheRMhSlsv0rBHdGHQRw91UEPQFLllRN0LpdY7eLSDmJkki63P1QwSM5CUypThAzaFKJRETKSJQa4TVmdh0QN7P5wGeBZwob1ugUjxmTq6tUIhGRshKlRPIZYBFBh43LgMPA5woZ1GgWtCVRo0QRKR9R7tpqA74YPmQQ6ZqESiQiUlb6TSSD3Zk1WDfy5SqdTLC5sbXUYYiIFM1AJZK3E/TAuwx4ntwdKUqWVFgicXfMdMhE5OQ3UCKZRjAo1VLgOuBXwDJ3X1uMwEardDJBR1cPze1djB9TWepwREQKrt/K9rCDxkfd/XrgAqAe+K2ZfaZo0Y1CvWO3q1GiiJSLASvbzSwB/BFBqWQO8C3gZ4UPa/TqHbu9qbmd09LJEkcjIlJ4A1W23we8Gfg18FV3HzHdtY9kfSUS3bklImVioBLJx4BW4AzgsxkVxwa4u48vcGyjUioZdJOi1u0iUi76TSTuHmnQKznepHFVxGOmEomIlA0li2EWixlTqqtU2S4iZUOJpAA0druIlBMlkgLQ2O0iUk6USAogKJEokYhIeVAiKYBUMkgk7l7qUERECk6JpADSNQk6u51DRzpLHYqISMEpkRRAb1sS1ZOISDlQIikAtW4XkXKiRFIAteq4UUTKiBJJAfR13Ki2JCJSBpRICmDC2Eoq46YSiYiUBSWSAjCzvluARUROdkokBZKuUet2ESkPSiQFom5SRKRcKJEUSFqXtkSkTBQ0kZjZFWa23szqzey2HMsvNrMXzazLzK7JWna9mW0MH9dnzD/PzF4J9/ktyxhxayRJ1VSxr7WDnh51kyIiJ7eCJRIziwN3Ae8FFgJLzWxh1mrbgBuAB7K2nQx8GXgbsAT4splNChf/K3AzMD98XFGgj5CXdDJBd49zoE23AIvIya2QJZIlQL27b3b3DuBB4OrMFdx9i7uvBnqytr0ceMLd97v7AeAJ4AozOwUY7+7PetAj4g+BDxTwMwxZSq3bRaRMFDKRTAe2Z7xuCOfls+30cHoo+yyqdG+jxGaVSETk5FbIRJKr7iJqhUF/20bep5ndbGZ1ZlbX2NgY8W2Hz7ESydGiv7eISDEVMpE0ADMzXs8Adua5bUM4Peg+3f1ud1/s7ovT6XTkoIdLb8eNKpGIyMmukIlkJTDfzOaaWRVwLbA84raPAZeZ2aSwkv0y4DF33wU0m9kF4d1aHwceLkTw+apJVJCoiKmOREROegVLJO7eBdxCkBTWAf/h7mvN7A4zuwrAzM43swbgQ8B3zGxtuO1+4G8IktFK4I5wHsCnge8B9cAm4NeF+gz56OsmRY0SReQkV1HInbv7I8AjWfNuz5heyfGXqjLXuxe4N8f8OuDNwxtpYaRrEiqRiMhJTy3bC2jGpLGs2n6Q15taSx2KiEjBKJEU0OcvP5OKmHHTD1ZyUA0TReQkpURSQLOnVPOdjy1m+4E2Pn3/i3R0Zbe7FBEZ/ZRICmzJ3Mn8/R+fxbOb9/HXP19D0CBfROTkUdDKdgl88LwZvN7Uyrd/U8+8dDWffOdppQ5JRGTYKJEUyV++5wxeb2rl7x99jTmpai5fNK3UIYmIDAtd2iqSWMz4+p+czVkzJvK5B1exZsehUockIjIslEiKaExlnO9+/DwmjavkpvtWsvuQ+uESkdFPiaTIamvGcM8N59NytIub7ltJW0dXqUMSEcmLEkkJvOmU8fzLdW9l3a7DfO7BVRpFUURGNSWSErlkwVS+9EcLefzVPdz56GulDkdEZMh011YJ3fiOOWxuauE7T29mbqqaa5fMKnVIIiInTCWSEjIzvnLlIi6a
n+JLP1/DM/VNpQ5JROSEKZGUWEU8xl0fOZe5qWo+df8LbGpsKXVIIiInRIlkBBg/ppJ7bzifyniMT/xgJQda1cHjSNDZ3UNru+6qK4Xmo5088sou7vpNPc9sauJoZ3epQ5IBqI5khJg5eRx3f/w8ln73eT55/wv86KYlJCripQ6rLDUf7WTZim3c+/st7G/tYOmSmfzFu0+ndvyYUod2Unu9qZUn1+3hqdf2suL1/XRl3M1YVRFj8exJXHjaFN5+WoqzZ0ygIq7z4JHCyqETwcWLF3tdXV2pw4jk4VU7uPXBVXzw3Bl87UNnEYwoLMWw+9BRvv8/r/PA89tobu/i7fOmMGPSWB56aQfxmHH9hXP45MXzmJJMlDrUSF7bfZgtTa1cfEaacVUj75yxs7uHlVv289S6vTz12l42h+P2nDE1ySULpnLpm2o5o7aGuq37eWbTPp7ZtI91uw4DkExUsGTu5DCxTOFN08YTi+m3MtzM7AV3XzzoekokI883n9jAPz+5kf9zxZn8+btOL3U4J731u5u5++nNLH95B909zvvecgo3XzyPs2ZMBGDrvlb++cmN/PylHYypjPOJd8zlzy6ax4RxlSWO/I1a27v45eqdLFuxnVXbDwJQk6jg6reeytIls1h06oSSxrevpZ3frm/kqdf28vSGRprbu6iKx3j7aVO49E21vPvMWmZOHtfv9vtbO3hu8z6e2dTEM5v2sbkxSD6TxlVywbwpXHh6igtPm8K8VLVOwoaBEkmG0ZZI3J3PPriKX7y8k69cuZB3nlnLrMnjiOuMa9i4O89u3sfdT2/mt+sbGVsZ58Pnz+SmP5jb7z+y+r0t/NN/beCXq3dRM6aCP7toHje+Yw41Y0qfUNbsOMQDK7axfNVOWtq7OL02ydIlszhzag0/e7GBX76yi46uHs6eMYFrl8ziyrNPJZkofCnF3Vm3q5mnXgsuWb20/SDuwTDUly6o5ZIFtbzj9BTVQ4xl96GjPLu5iWfqgxLLjoNHAJg6PsGFp6U4a8YE3KG9q4f2ru7guTNjuquHjn7mt3cF9TIzJo5j1uRxzJoSPM+eMo7Zk6tH5InEcFMiyTDaEgnA0c5uPvq956nbegCAREWM09JJzpia5IxpNZxRW8MZU2uYMWnsSVWkP9LRzabGFur3trBxbzNtHd3Mr63hzGnBI99/fl3dPTy6djd3P72Z1Q2HmFJdxQ0XzuGjF8xmUnVVpH2s23WYbzyxgSde3cOkcZV86p2n8fG3z2FsVXHrtJqPdrL85Z0sW7GNNTsOk6iI8f6zTmXpkpmcN3vScWfkB9s6eOilHTy4Yjvr9zRTXRXnqnOCUspbpk8YtrP3nh5nU2MLdVsPsHLLfp7btI+dYZ9yZ82YwCULarl0wVQWnTr8l6LcnW372/ougz27qYmmluNvXKmKx0hUxEhUxkhUxElUxKiqiJGojJOI984/tqzbnYYDR9i6r42mlvbj9jV+TAWzp1QfSzCTjyWcUyaMPSlO/JRIMozGRALBNeS1Ow+zYU8zG3Y3s2FvCxv3NLMro7PHsZVxTq9NMn9qkjOm1nDG1CTza2uYPnFkJ5iW9i427W1hY5gw6vcE09sPtNH7layIGZXxGEcy7tiZPnEsC6YdSyxnTqthXipJVcXAFa9tHV38pK6B7/1+M9v3H2Fuqpo/vWguHzx3BmMqh5YAXt5+kG88sYHfbWgklUzwF+8+jaVLZg15f1G4O6u2H2TZim384uVdHOnsZsG0Gq572yyuPmc6E8YOfJbs7ry47SAPrtjGL1bv5GhnD4tOHc+1S2Zx9TmnMv4ES1dHO7tZs+MQK7ccoG7Lfl7YdoCDbZ0ATKmu4vw5k7lkQS3vOjNd9JsV3J19rR1UhsmjKh7L6zfR2t7F9gNtbN3XxrZ9bWzb38bW/W1s29dKw4Ejx90cUBk3ZkwKSi+np4Pf5ulTk5xemzzhY1xKSiQZRmsi6c/ho51s3BMklQ17gn/EG/Y0s+fw
sTOm6qogwZwyYSyxWND40YCYGWbhM+F8g5iBYcRiABa8NqiIxRhTGWdMZYyxlXHGVMYZWxknURnrmz723LvusfVbO7qp39scxBsmjvo9zX1nqRCcJc5LVwcJsbaG+VOTzK9NMntKNRUxY8fBI7y2O/iMr+0Okuqmxpa+H25FzJiXrubMaeNZMC0oqS2YFiTT/W0d/PCZLfzwua0cbOvk3FkTufni03jPwqnDdsZYt2U/X3t8Pc9t3s8pE8bwmUvm86HFM6gcxruKDh3p5Ocv7WDZim28truZcVVxrjzrVJa+bRZnzxhaieLw0U4efmkHy1Zs59VdhxlbGef9Z53C0rfN4q0zJ+bc54HWDl7YeoC6rUHiWN1wiI7uYAjpeelqFs+exOI5k1k8exJzy6ieorvH2XnwCNv2hwlmXxvb9rfyelMbmxtbaM8YZnva+DHhd/zYd31+bc2IvFSmRJLhZEsk/TnU1hkmlZagFLOnmcbmdpzg7MwdHOgJp3uf3T1rPoDT40GpqL2zp++fRT7GVMY4vTbJ6ekk86fWhIkjyazJ4074Vs6Orh42N7Wwfncz6zOSTMOBI33rVFfF6exxOrp6eM/CqXzy4nksnjM578/Rn2fqm/ja4+t5cdtBZk4ey62XnsEHzjl10M/m7nT3BMe7x52ejNcb9jSzbMU2frV6F+1dPbxl+gSWLpnFlWefMmx1M+7OKzsOsWzFNh5etZO2jm7OnFrD0iUzeftpKdbsOETd1v2s3HKA+r1Bg9nKuPHm6RM4P0wa582eNGruZiu27h6n4UAbG/e0sCGj9F2/t+W40na6JtF3RaH3tzE3VU2iIk48blTEjHjMiJsV7WqDEkmGckkkhdTd4xzt7A4eXT0c6Qim27u6OdLRw9HObo5kLD8aLq+qiPWdfRXjclvz0U427j2WYAA+esFsTq9NFvR9e7k7v13fyNefWM+aHYepSVQQjxs9GYmiuydI2N1h0hjsJ5hMVHB1WJ/x5umFveuqpb2LX4T1Lqsbjg2+Nn5MBedllDbOnjmxoJfwykFPj7Pj4BHq9wYnfpkl9taOgRtgxsKrBfFYmGAyEs1x82PGPdefz6wp/d8JNxAlkgxKJFJs7s5ja/fw+/pGYmYZD4jHgjNZwVYjAAAHLUlEQVTKmNF3dhmz4Edv4bxg2kglq3jPwqklaQeyZsch1u06zFkzJjK/Njmi69xOJu7OrkNH2bi3hW372+jq7qG7x+nqCU5Curqd7p6evtedWa+Pf+7hy1cuYuoQ66eUSDIokYiInLioiUR9DIiISF6USEREJC9KJCIikhclEhERyYsSiYiI5EWJRERE8qJEIiIieVEiERGRvJRFg0QzawS2DnHzFNA0jOEMN8WXH8WXH8WXn5Ee32x3Tw+2UlkkknyYWV2Ulp2lovjyo/jyo/jyM9Lji0qXtkREJC9KJCIikhclksHdXeoABqH48qP48qP48jPS44tEdSQiIpIXlUhERCQvSiQhM7vCzNabWb2Z3ZZjecLMfhwuf97M5hQxtplm9hszW2dma83s1hzrvMvMDpnZqvBxe7HiC99/i5m9Er73GwZ/scC3wuO32szOLWJsZ2Ycl1VmdtjMPpe1TlGPn5nda2Z7zWxNxrzJZvaEmW0Mnyf1s+314Tobzez6Isb3j2b2Wvj3e8jMJvaz7YDfhQLG9xUz25HxN3xfP9sO+FsvYHw/zohti5mt6mfbgh+/YReM5V3eDyAObALmAVXAy8DCrHX+HPi3cPpa4MdFjO8U4NxwugbYkCO+dwG/LOEx3AKkBlj+PuDXgAEXAM+X8G+9m+D++JIdP+Bi4FxgTca8fwBuC6dvA+7Msd1kYHP4PCmcnlSk+C4DKsLpO3PFF+W7UMD4vgL8VYS//4C/9ULFl7X868DtpTp+w/1QiSSwBKh3983u3gE8CFydtc7VwH3h9E+BS82sKGOPuvsud38xnG4G1gHTi/Hew+hq4IceeA6YaGanlCCOS4FN
7j7UBqrDwt2fBvZnzc78jt0HfCDHppcDT7j7fnc/ADwBXFGM+Nz9cXfvCl8+B8wY7veNqp/jF0WU33reBoov/L/xJ8Cy4X7fUlEiCUwHtme8buCN/6j71gl/TIeAKUWJLkN4Se2twPM5Fr/dzF42s1+b2aKiBgYOPG5mL5jZzTmWRznGxXAt/f+AS3n8AKa6+y4ITh6A2hzrjJTj+AmCEmYug30XCumW8NLbvf1cGhwJx+8iYI+7b+xneSmP35AokQRylSyyb2eLsk5BmVkS+E/gc+5+OGvxiwSXa84G/gX4eTFjA97h7ucC7wX+wswuzlo+Eo5fFXAV8JMci0t9/KIaCcfxi0AX8O/9rDLYd6FQ/hU4DTgH2EVw+ShbyY8fsJSBSyOlOn5DpkQSaABmZryeAezsbx0zqwAmMLSi9ZCYWSVBEvl3d/9Z9nJ3P+zuLeH0I0ClmaWKFZ+77wyf9wIPEVxCyBTlGBfae4EX3X1P9oJSH7/Qnt7LfeHz3hzrlPQ4hpX77wc+4uEF/WwRvgsF4e573L3b3XuA7/bzvqU+fhXAHwM/7m+dUh2/fCiRBFYC881sbnjWei2wPGud5UDvHTLXAE/190MabuE11XuAde7+jX7WmdZbZ2NmSwj+tvuKFF+1mdX0ThNUyq7JWm058PHw7q0LgEO9l3GKqN8zwVIevwyZ37HrgYdzrPMYcJmZTQov3VwWzis4M7sC+AJwlbu39bNOlO9CoeLLrHP7X/28b5TfeiH9IfCauzfkWljK45eXUtf2j5QHwV1FGwju6PhiOO8Ogh8NwBiCSyL1wApgXhFj+wOC4vdqYFX4eB/wKeBT4Tq3AGsJ7kJ5DriwiPHNC9/35TCG3uOXGZ8Bd4XH9xVgcZH/vuMIEsOEjHklO34ECW0X0ElwlnwTQZ3bk8DG8HlyuO5i4HsZ234i/B7WAzcWMb56gvqF3u9g712MpwKPDPRdKFJ8Pwq/W6sJksMp2fGFr9/wWy9GfOH8H/R+5zLWLfrxG+6HWraLiEhedGlLRETyokQiIiJ5USIREZG8KJGIiEhelEhERCQvSiQiw8DMurN6GB62XmXNbE5mL7IiI01FqQMQOUkccfdzSh2ESCmoRCJSQOHYEnea2YrwcXo4f7aZPRl2MPikmc0K508Nx/p4OXxcGO4qbmbftWA8msfNbGzJPpRIFiUSkeExNuvS1oczlh129yXAt4F/Cud9m6Bb/bMIOj/8Vjj/W8DvPOg88lyC1s0A84G73H0RcBD4YIE/j0hkatkuMgzMrMXdkznmbwEucffNYcebu919ipk1EXTh0RnO3+XuKTNrBGa4e3vGPuYQjEEyP3z9BaDS3f+28J9MZHAqkYgUnvcz3d86ubRnTHej+k0ZQZRIRArvwxnPz4bTzxD0PAvwEeD34fSTwKcBzCxuZuOLFaTIUOmsRmR4jDWzVRmvH3X33luAE2b2PMGJ29Jw3meBe83s80AjcGM4/1bgbjO7iaDk8WmCXmRFRizVkYgUUFhHstjdm0odi0ih6NKWiIjkRSUSERHJi0okIiKSFyUSERHJixKJiIjkRYlERETyokQiIiJ5USIREZG8/H/4Dst403nliAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Learning curve: the cost recorded at each epoch, indexed from 0.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(range(len(cost)), cost)\n",
    "ax.set(xlabel='Epoch', ylabel='Mean Squared Error')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<br>\n",
    "<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compare with analytical solution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weights tensor([[-0.0763],\n",
      "        [ 0.4181]])\n",
      "Bias tensor([0.4888])\n"
     ]
    }
   ],
   "source": [
    "# Parameters learned by the ADALINE model, for comparison with the closed-form fit.\n",
    "for label, value in [('Weights', model.weights), ('Bias', model.bias)]:\n",
    "    print(label, value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Analytical weights tensor([[-0.0703],\n",
      "        [ 0.4219]])\n",
      "Analytical bias tensor([0.4857])\n"
     ]
    }
   ],
   "source": [
    "def analytical_solution(x, y):\n",
    "    \"\"\"Closed-form least-squares fit via the normal equations.\n",
    "\n",
    "    Prepends a column of ones to `x` (for the bias term) and computes\n",
    "    params = (Xb^T Xb)^-1 Xb^T y.\n",
    "\n",
    "    Returns (w, b): w reshaped to (x.size(1), 1) and b as a 1-element\n",
    "    vector (assumes y is a single target vector -- confirm for other shapes).\n",
    "    \"\"\"\n",
    "    # build the bias column with x's dtype/device so torch.cat also works\n",
    "    # for float64 or GPU inputs\n",
    "    ones = torch.ones((x.size(0), 1), dtype=x.dtype, device=x.device)\n",
    "    Xb = torch.cat((ones, x), dim=1)\n",
    "    z = torch.inverse(torch.matmul(Xb.t(), Xb))\n",
    "    params = torch.matmul(z, torch.matmul(Xb.t(), y))\n",
    "    # row 0 comes from the ones column (bias); the remaining rows are the weights\n",
    "    b = params[:1].view(-1)\n",
    "    w = params[1:].view(x.size(1), 1)\n",
    "    return w, b\n",
    "\n",
    "w, b = analytical_solution(X_train, y_train.float())\n",
    "print('Analytical weights', w)\n",
    "print('Analytical bias', b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<br>\n",
    "<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation Metric: Prediction Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def custom_where(cond, x_1, x_2):\n",
    "    \"\"\"Elementwise select: x_1 where `cond` is true, x_2 elsewhere.\n",
    "\n",
    "    Arithmetic equivalent of a where-select for a 0/1 mask; x_1 and x_2 may\n",
    "    be Python scalars, as in custom_where(pred > 0.5, 1, 0).\n",
    "    \"\"\"\n",
    "    # Comparison results are torch.bool in PyTorch >= 1.2, where `1 - cond`\n",
    "    # raises; cast non-float masks to an integer dtype so the arithmetic works\n",
    "    # on both old (uint8 masks) and new (bool masks) versions.\n",
    "    if not cond.dtype.is_floating_point:\n",
    "        cond = cond.long()\n",
    "    return (cond * x_1) + ((1-cond) * x_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training Accuracy: 90.00\n",
      "Test Accuracy: 96.67\n"
     ]
    }
   ],
   "source": [
    "def binary_accuracy(predictions, targets, threshold=0.5):\n",
    "    \"\"\"Fraction of thresholded `predictions` that equal `targets`.\n",
    "\n",
    "    Outputs above `threshold` are labeled 1, all others 0.\n",
    "    NOTE(review): assumes predictions and targets have the same shape;\n",
    "    an (n, 1) vs (n,) mismatch would broadcast to (n, n) -- confirm.\n",
    "    \"\"\"\n",
    "    predicted_labels = custom_where(predictions > threshold, 1, 0).int()\n",
    "    return torch.mean((predicted_labels == targets).float())\n",
    "\n",
    "train_pred = model.forward(X_train)\n",
    "train_acc = binary_accuracy(train_pred, y_train)\n",
    "\n",
    "test_pred = model.forward(X_test)\n",
    "test_acc = binary_accuracy(test_pred, y_test)\n",
    "\n",
    "print('Training Accuracy: %.2f' % (train_acc*100))\n",
    "print('Test Accuracy: %.2f' % (test_acc*100))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<br>\n",
    "<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Decision Boundary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAa4AAADFCAYAAAAMsRa3AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3XtwVNedJ/DvTy8k0BMkQFKjB7Z4SgghBezg2JNABtsT24QknrEdO9hJqEmtE2d3TRIHKmEd28HrrZlNlXfjxeOMhzJxJtmxSXbthNjBs56kAmsJWCTxko0RdIuHLPQECT36t3+01EiW1N3Svbfvvd3fTxVVqK+490j06d+9v/M754iqgoiIyC0S7G4AERHRVDBwERGRqzBwERGRqzBwERGRqzBwERGRqzBwERGRqzBwERGRqzBwERGRqzBwERGRqyTZcdHc3FwtKSmx49JEpqirq/tIVfPsbscI9imKBZH2K1sCV0lJCWpra+24NJEpRKTZ7jaMxj5FsSDSfsVUIRERuQoDFxERuQoDFxERuYotY1xEZL2BgQF4vV709fXZ3RRHSU1NhcfjQXJyst1NoWli4LLLM4VAf8/411PSge/7ot8eijlerxcZGRkoKSmBiNjdHEdQVbS1tcHr9aK0tNTu5rifTZ9jTBXaZaL/7FCvE01RX18f5syZw6A1iohgzpw5fAo1i02fYwxcRDGMQWs8/k7cj4GLiIhchWNcRBRVO3bsQHp6Oh5//HHTz11XV4fNmzejt7cXd955J37yk59E5wmLY9ZRxScuIgIA7D3sw9qd+1H6vTewdud+7D3svg/cb3zjG9i1axeamprQ1NSE3/3ud9G5MMeso4qByy4p6VN7nchCew/78MRr9fB19EIB+Dp68cRr9YaD1+7du7FixQpUVlbiwQcfHHf8xRdfxCc+8QlUVlbiC1/4Aq5evQoA+NWvfoXy8nJUVlbi1ltvBQA0NjZi9erVWLlyJVasWIGmpqYx5zp//jy6urpw8803Q0Tw0EMPYe/evYbaT2HY9DnGVKFdjKYPmJogEz237yR6B4bGvNY7MITn9p3ExqrCaZ2zsbERTz/9NP70pz8hNzcXly9fHvc9mzZtwte//nUAwPbt2/HSSy/hm9/8Jp588kns27cPhYWF6OjoAAC88MILeOyxx/DAAw+gv78fQ0Nj2+vz+eDxeIJfezwe+HzsC5ay6bOGT1xuxdQEmailo3dKr0di//79+OIXv4jc3FwAwOzZs8d9T0NDAz71qU+hoqICe/bsQWNjIwBg7dq12Lx5M1588cVggLr55pvxzDPP4Nlnn0VzczPS0tLGnEtVx52fFYSxiYGLyGFEZIGIvCMix0WkUUQes/qaBdlpU3o9EqoaNnBs3rwZzz//POrr6/HDH/4wOL/qhRdewFNPPYVz585h5cqVaGtrw/3334/f/OY3SEtLw4YNG7B///4x5/J4PPB6vcGvvV4vCgoKpt1+ci4GLiLnGQTwH1V1KYCbAPw7EVlm5QW3bliMtOTEMa+lJSdi64bF0z7nunXr8Mtf/hJtbW0AMGGqsLu7G/n5+RgYGMCePXuCr3/wwQdYs2YNnnzySeTm5uLcuXM4ffo0Fi5ciG9961u4++67cfTo0THnys/PR0ZGBg4cOABVxe7du3HPPfdMu/1TwjHrqOIYF5HDqOp5AOeH/94tIscBFAI4ZtU1R8axntt3Ei0dvSjITsPWDYunPb4FAMuXL8e2bdtw2223ITExEVVVVXj55ZfHfM+PfvQjrFmzBsXFxaioqEB3dzcAYOvWrWhqaoKqYt26daisrMTOnTvxyiuvIDk5GfPnz8cPfvCDcdf86U9/GiyHv+OOO3DHHXdMu/1TwnHlqJKJ8sJWq6mpUW56Z9COrBDHOqPXjjglInWqWhOF65QAeBdAuap2fezYFgBbAKCoqKi6uXnsHnzHjx/H0qVL
rW6iK/F340yR9iumCt2KqYmYJyLpAP4FwLc/HrQAQFV3qWqNqtbk5YXd7ZwoZjBV6FZMTcQ0EUlGIGjtUdXX7G4PkZMYfuKyowKKKJZJoBTvJQDHVfXv7G4PkdOYkSqMegUUUYxbC+BBAJ8RkSPDf+60u1FETmE4VWhHBRRRLFPVPwLgzFmiSZhanDFcAVUF4OAEx7aISK2I1La2tpp5WSIiiiOmFWdEUgEFYBcQKIc367pEEeP6jo5g5bYm27Ztw+7du9He3o6eHi5/FqtMCVysgCJX4PqOk4uRoH7XXXfh0UcfRVlZmd1NIQuZUVXICigit7MoqEdzWxMAuOmmm5Cfn2+ozeR8ZjxxjVRA1YvIkeHXvq+qb5pwbpquGLmDJveK9rYmFD/MqCpkBZQTMS1GNot0W5Pt27ejo6MDPT092LBhA4Dr25rce++92LRpE4DAtiZPP/00vF4vNm3axHRgHOOST0RkiWhva0Lxg4GL4gfXd4yqaG9rQvGDaxVS/ODY3uRS0icfE50mO7Y1+c53voOf//znuHr1KjweD772ta9hx44dkTWY48KuwW1NYhW3PbFUtLY1idREfYpbd0xuwt8N+4ztuK1JvGNajIhiFFOFsYqpDSKKUQxc8cjNuXw3t90GkVT2xRs7hkcM4Xt+HKYK45Gb53i5ue1Rlpqaira2Nvd9UFtIVdHW1obU1FS7mxI5vufH4RMXUYzyeDzwer3gbgxjpaamwuPxjD9gQWUlWYOBiyhGJScno7S01O5muEecpt3ciIGLpo45d6KxjPaJUP+exuEYF00dc+5EYxntE+xTU8LAFY/cPMfLzW0nmg6+58dhqjAeuTmd5+a2E00H3/Pj8ImLiIhchYGLiIhchYGLpo45dyJzsU9NCce4rBSrZeNubjuRE7FPTQmfuKzEElciItMxcBE5kIj8TEQuiUiD3W0hchoGLiJnehnA7XY3gsiJTBnjEpGfAfgcgEuqWm7GOcnBQo3dhUqDchfZiKnquyJSYnc7iJzIrOKMlwE8D2C3SecjJ+PYnSOIyBYAWwCgqKjI5tbEOd6URZUpqUJVfRfAZTPOFVNY4koWUtVdqlqjqjV5eXl2N4coaqJWDh+Xd4cscSUiMl3UApeq7gKwCwBqamq4JavROV47skIc6zR2fivnn8Xq3DYiihpWFdrF6nEiI+e3sm0cH4uIiLwK4M8AFouIV0S+anebiJyCK2fQ1IXa4pwByBSqep/dbSByKrPK4V8F8BcAckXEC+CHqvqSGecmC4RLM4bDlB4R2ciUwMW7QyIiihaOcRERkaswcNnF6jleTp1D5tR2EZFrsDjDLkbHicKNRYU6f6gxLiIih+MTF0UXy+GJyCAGLiIispyqQtWctSeYKoxHXBB0HFXFha4+NPi6UO/rRIOvE//lS5WYPSvF7qYRudK1wSE0tnSh7kw76prbUXe2HT//2hqUzcswfG4GLqdy8tJITm5bBFQVLZ19qPd2orGlMxioPurpBwAkCHBDXjpau68xcBFFqLX7Gg6dbceh5kCgOurrRP+gHwBQNHsmbrkxFyLmXIuBy6mcPBbk5LZ9jKrC296LBl8gQNX7OtHY0oXLVwJBKjFBUDY3HbctmouKwkxUeLKwND8TM1PYNYgmM+RXNF3qDjxJDf9pbrsKAEhJTECFJwubP1mCVUU5WFWcjbkZqaZen72Txgu1pJODz62qOHv56ph0X0NLJzquDgAAkhIEZfMysH7pXJQXZqG8MAtL52ciLSXR8LWJYll33wD+37lO1DZfRl1zO46c7UD3tUEAQG56CqqLc/DAmiJUF+dgeUEWUpOt7VMMXPEoXKrPynSfSef2+xXNl68GnqB819N9XX2BzpScKFg0LwO3L58fDFJL5mdY3qGI3E5Vce5yL+rOXh5+murAyQtd8CsgAiyel4G7VxagpiQH1UWzsWB2GsSsHGCEGLjikYtSfUAgSH3YdiWQ7vMGnqIafV3BO76UxAQs
yc/AX60oQEVhFioKs7BofjpmJDFIUZS4eNz32uAQGnxdqGu+Hqg+6rkGAEifkYSqomxsWFeG6uIcrFyQjYzUZJtbzMBFDjPkV5xu7UFDSyfqvV1o8AUKKK70DwEAUpISsDQ/E/dUFaC8IPAktWheBlKSOLODbOSim8HW7muoa27HobOBsal6byf6hwJFFMVzZuLWslysKs5BdXEOFs3LQGJCdJ+mIsHA5VRWjjMZZVLbBof8+KD1yvXxKF8njp3vwtXhIJWaHAhSX6j2oHz4SerGuelITmSQIorEkF9x6mKgiOLQcEn6x4soHl5bglXFOVhVlIO8jBk2tzgyDFxOFe4Ozs7UxDTOPzDkx/uXeoJBqt7XiePnu9A3ELjTS0tOxPKCTNxbsyAYpG7Im4UkBimiiHX3DeDIuQ7Ungk8UR0+24GeYBHFDNQU5+DLa4qxqjgH5YWZrk2nM3C5lYNTE/2Dfpy62B2cI1Xv68KJ8124NjynY1ZKIpYXZOH+1cWo8GSivCALC/PSHZmSIHKqkSra0SXpJy92QzUwF3Hx/ExsrCpAdXEOaopnw5MT/SIKqzBwxSMT05DXBodw6sLwk1RL4GnqxPnuYM48Y0YSlhdm4sGbilHhCYxJlc6ZhQQGKXIaK7MY4c4dwbX7BobQ2NI5KlBdL6LImJGEquIc3FGej+riHFQuyHJEEYVVGLji0TQ7Yd/AEE5e6B6T7jt1sRsDQ4H1xzJTk1BeGMiZLx9O9xXPnskgRe5gZRYj3LknOH5Js3Cotwx1bxxDXXM7GnxdwRvCkjkzceuiXFQPF1GUzXVmEYVVGLjsvMuy0d7DPjy37yRaOnpRkJ2GrRsWY2NVYfB438AQjp3vCs6Rqvd1oeliNwb9gSCVlZaMisIsfPWWhagozEJ5YSaKZs+MmVQEUTQNqeCkLkCdfxEO+ctQp4twVucBAFL+3IwVhdeLKKqLc5Cb7o4iCqswcNl5l2WTvYd9eOK1evQOBKr3fB29+M7/PIp3T7VCRNDY0ommSz0YGg5Ss2eloLwwC59enDccpLJiKl9OFG1dmoYj/htR51+Eun84iCPXXkQPZgIA8tCOmoRTeCjxLaxKOIXlP/i/ri2isAoDl1tNc5zqyrVBPPXGsWDQGtE/5Mdrh33ITQ8Eqc8um4flBVmo8GShICuVQYqcwcFZjMmoKpr981Cni1DnL8Mh/yKcVA8UCUiAH0uu9OPziX9EdUITquUUPNI6djFaBq1xTAlcInI7gJ8ASATwD6q604zzxjUjOxwP6+4bQGNLV3COVL2vE6c/uoJQW+K8t209gxQ5l0OzGKP1aTIatDTwNOUvw6Gn38ZH/X8PAMjAVVQlNOGOxIOokVOoTPgA6Y9dBHZ8zuZWu4vhwCUiiQD+G4DPAvACeE9EfqOqx4yem0L42J1nl6ahwV+KhoTFqF/yGBqHg9SI+ZmpKC/Mwt2VhfinP58Jro4+WmH29fRfuDEwopgzzSzGpa4+HJJPoba/CHX+RWjQUgwMf7SWJFzEbYvmovrYj1Htr0eZ+JAgo+4cR87t5AUHHMiMJ67VAN5X1dMAICK/AHAPAAYui3Rc7UdjbzHqtRT1/lI0aCmadX7weMGZyygvzMLnqwpR7slCeUHWmBnxxXNmjhnjAgITgLduWAxg4jGwJ16rBwAGryhhFsMGEWQxhvyKExe6gntO1Z1tx7nLvQC+gZSkBFQWZ+GR4hxUF+Vg1Zgiil8YvjZdZ0bgKgRwbtTXXgBrPv5NIrIFwBYAKCoqMuGyJrHwTucK0jALvZO8Hpn2K/1j5kjV+zqHO8o2AIBHLqFCPsS9if8H5fIhyhM+xJwnvCHPORJ8Jnuiem7fyXFjYL0DQ3hu30kGrihgFsMi0xgf6+obwOGzHcElkw6fbQ+um5mXEViJ4is3l6D6nYewfOgYUi4MARcAHAx/bpo+MwLXRAMi40ZRVHUXgF0AUFNTE2KUJcosfFOV9700/heBwC/swwleb+u5
NmrdvsCeUr6O64GvaPZMrCjMxv2ri1H+h4dQnnAGOTKN3P4zhdjY34ONAJAKoA/ArwH8NtDJWjrGB1sAY15nKtFSzGJYIcz4mKqiue0qaoefpg41t+PUpesrUSyZH1g3s3p4Xb8xlbV/qJ/4k9BBY2+xxIzA5QWwYNTXHgAtJpzX9Qqy08YEntGvX+ruQ+NwcBoJVuc7+4LfU5o7C1VF2Xjo5mJUFGZheUEWsmaOmgn/rw3Tb1iYDhyq3QBTiVHg7iyGlUzMkPRpMup1Ier8ZajbXYtDze1oGx77zUhNwqqiHPzVipGVKLKRPoNF2E5hxv/EewDKRKQUgA/A3wC434Tzut7WDYvHjSUlCNBzbRCrn/4DgMDGbKW5s7C6dHZwm47lhZnItHG5lonaPXoMjKlEy7k7i2ElAxmSS5odCFL+Raj1L0LjqCKK0ks9+IvFcwPr+pXk4Ma8dK744mCGA5eqDorIowD2ITCQ/DNVbTTcMhdSVVzo6gtsdjj8JJWcKOgduP49eRkz8MkbcoMroC8ryHTcnVy4MbBIUolkCLMYBg0O+XHiQndwz6m65nZ4r/13AMAM9KNSPsBXE99EdcIprEpowpzHA+PCew/78PA/vscUuMOZ8ompqm8CeNOMc7mFqsLX0YsGX1cwSDX4OoOphgQByuZmYP2yecFdeZfmZ2KWw4LUdESSSuT4lyHMYkxRZ+8ADp+9vufUkbMdwSKKuRkzUFOSg83du1CdcArL5QxSZGjcOZgCdw/3f4pGgarC2947ZnHZxpau4FyoxARB2dx0fGbJXFR4AuNRy/IzkZbizhnv4Trwp5fk4ZUDZ8f9u08vyWPnNwGzGKGpKs60jd7O4zKaLvUEiyhGNh8dWYA2OD/xmX8LOT5mKAXOeVhRxcD1MSN73IwummjwdaFzON+XlCBYNC8Dn106D+WewJPUkvkZSE2ObpAyVGofppOF68DvnGid8LTvnGjFOydaOf5lgnjMYkymb2AIR73Xt/M4dLY9eNOYmZqEVcU5uGtFQbCIYtKsRpjxMUMpcJa8R1VcBy6/X3Gm7QoahpdFqvcG5kt19wV2DE1JTMDi+Rm4syIf5YWZqCjMwuL5GY5Y8HKqpfZjGOzA0+ngHP+iSF3s6huzOWJjS2dw65yFubPwmSVzg09TZhZRhEuBk3PETeDy+xWnP7oyZjzqWEsXuoe3tU5JSsDS+Rm4u7IguAL6onkZSEly5tbxkXSy6Y41hTt3uOPs/BSp0UUUtWcCgWrk/TMjKQGVnmx89ZaFqCnOQVVRNuZYuJ1HuGpagOO3ThGTgWvIrzjd2jMm3XespSs4WDsjKQHLCjKxsaowGKTK5qUjOdGZQWoiocaZAGMDzeE6cLjj4To/xa/OqwM4dK49uGTSkXMduDrcL+dlzkBN8Ww8ckspqotzsCw/M6o3juGqaTl+6xyuD1yDQ36839ozprrvWEtX8M2VlpyIZQWZ+FLNAiwvyESFJws35qUjyUVBaiKhxpkAYwPN4TpwuOPhjlF8UFV8+NGV4LhU7Zl2NF0KjK0mJgiW5mfgS9We4OaIoxd5tsvGqsJJ36ucv+gcrgpcA0N+NF3sCRRMtASC1PHzXegbCGxnPTMlEcsLMvE3qxcES9AX5qXH5JbWVoxDjRaqA4c7Hu7fUmwaKaKobb4cfKJqvxooahopori7sgDVJTmo9IQoonAozl90Dse+c/oH/Th1sXvMmNTxC93oHwwEqfQZSVhekIkH1hQH032lubOmFaQsy1tbuOmd0XGo7Xvr8erBcxhSRaII7luzAE9trIj4+kZ+ZxwniA0XOkcVUZxtR6OvE4P+60UU65fOCxZR3BADK1Fw/qJzOC5w/f1bp7D/xCWcvNCN/qFAkMpITUJ5QRY2f7IE5YVZKC/IRMmcWaZ0BEvz1hZuehdujCvU8e1768ccG1INfh1J8DLyO+M4gTuNFFGMrvYbU0SxIBtfv3VhcDuP2bNSbG6x+UKN7fJ9HV2O
C1yXuq8hMy0JD99SEkz3Fc2eaVnu261563BjXKGOXxi1mO9orx48F1HgMvI7c+vvO96MFFHUnbleRDHy/za6iKKmOAdLo1xEYZdQY7trd+7n+zqKHBe4frwp8nSVGcLlra18/DdybiNjXJOtxjqkka3TaiTXz3EC51ENTBUZ2cqjrnlsEcWy/Ez89ScWBIsoCrJSbS+isMtk47d8X0eX4wJXtIXKW1v5+G/03EbGuM539sI/QYyKNPNqZKImJ3nar7d/CEe9Hag72z5hEUV1cQ7uWVmAVcXuLKKwA9/X0RX7z/dhbN2wGGkfW65pJG8dKq1llNFzh2p3uOMzJknrTPb6VK9t1b+l6bnQ2Yf/fbQF/+l/NeKe5/+Iih378Ne7DuA//+4kTn90BeuXzsPOTRV4+z/ciiM/+Ev848Or8ehnyvDJG3IZtCLE93V0xf27MlTe+t//85EJ/03Ej/8h1gQ0o1x9snaHOz7ZzzUyrcDota36tzQ9T7x2FO+cbEVqcmAlii23LkR1cQ6qimKziMIOfF9HV9wHLmDyvLXhx/8QJe8FO/cbTi2Emy9V23wZFzr7oAjcddc2X8bGqkJT0hpG5mpxnld0PbZ+Eb69fhGWFWS6anUYt+H7Onr4Lg7Bysd/q1MLIyXvIwUXIyXv2/fWM60RZ1YuyEblgmwGLYoZfOIKwcrHf6tTC68ePDfp6yMl70xrEEUPJyibh4ErDCsf/60892Sl7SOvM61BFD2coGwuBq4YlSgyYfBKHJ5/w7s/oujhxHtzMekdo25amDPp6yN3f77hycgjd397D3MXVyIrcIKyuQwFLhH5kog0iohfRGrMahQZd6Zt4g5xpq3X0vlpRDTeZBW7nKA8PUafuBoAbALwrgltIROFusPj3R9RdLGS11yGxrhU9TiAuF23zGrhxqFCHQ83V4vL0xBFDycomytqxRkisgXAFgAoKiqK1mVdK1wVUrjjobZgqG2+HHJLFCIyHyt5zRM2VSgib4tIwwR/7pnKhVR1l6rWqGpNXh4/IMMJNw4V7vjGqkL8eFNFYDt0AIXZafjxpgpsrCoMuyUKUazae9iHtTv3o/R7b2Dtzv2mFiRZeW4aK+wTl6quj0ZD7BQq5Wa0bHy65zaybckIbsFAdJ2Td3ugqYn7eVyh3nAADL0ZjZzbyLYl4XALBopHVs6l4jyt6DJaDv95EfECuBnAGyKyz5xmRU+oN5zRsnEj5zaybUk4rHByLk4xsY6VmQZmMaLLaFXh6wBeN6kttpjOGy7SN6ORcxvZtiQcVjg52sgUk/9hd0NijZWZBmYxoivuU4VWlo0bPXe4KiRuLRJ7OMXEOqEqbZ18bhov7pd8CpU2M5pSs/LcRCKyRURqRaS2tZUVoeGEqrR18rlpvLh/4ookbTbdlJqV5yb3EpG3Acyf4NA2Vf11pOdR1V0AdgFATU3NxNsB0Bhu3e2BxhKdZPsLK9XU1GhtbW3Ur0tkFhGpU1XLiidE5F8BPK6qEXUU9imKBZH2q7h/4iIiGsHtftwh7se4iJwkFqaYuBW3+3EPBi4iB1HV11XVo6ozVHWeqm6wu03xgtv9uAcDFxEROInYTTjGZRBz4kTRZaTPbd9bj1cPnsOQKhJFcN+aBXhqYwUATiJ2Ez5xGcCcOFF0Gelz2/fW45UDZzE0XEk9pIpXDpzF9r2BNUI5t9I9GLgMYE6cKLqM9LlXD54L+TonEbsHU4UGMCdOFF1G+tzQJHNWR7/OScTuwMBlAHPiROYLNYZlpM8likwYvBK5LqTrMFVoAHPiROYKN4ZlpM/dt2bBlF4n52LgMoA5cSJzhRvDMtLnntpYgS/fVBR8wkoUwZdvKgpWFZJ7xESq0M6SdObEicwTyRiWkT731MYKBqoY4PonLpakE8WOycaqOG5Mo7k+cLEknSh2cNyYIuH6VCFL0oliRyR72BG5PnCxJJ0otnDcmMIxlCoUkedE5ISIHBWR10Uk26yGRYqpBSKi+GJ0
jOstAOWqugLAKQBPGG/S1LAknYiiZe9hH9bu3I/S772BtTv3swjMJoZShar6+1FfHgDwRWPNmR6mFojIaiMVzCPFYCMVzAD4+RNlZlYVPgLgt5MdFJEtIlIrIrWtra0mXtZevAMjig+sYHaOsE9cIvI2gPkTHNqmqr8e/p5tAAYB7JnsPKq6C8AuAKipqZl4tUuX4R0YUfxgBbNzhA1cqro+1HER+QqAzwFYpzrJ8ssxKtQdGAMXkT2sWkmHFczOYbSq8HYA3wVwt6peNadJ7sE7MCJnsXIlHVYwO4fRMa7nAWQAeEtEjojICya0yTW4PA2Rs1g5DsUKZucwWlV4o1kNcaOtGxaPGeMCeAdGZCersyCsYHYG169VaCfegRE5C7Mg8cH1Sz7ZjXdgRM7BLEh8YOAichAReQ7AXQD6AXwA4GFV7bC3Ve7BRXrjAwMXkbO8BeAJVR0UkWcRWEbtuza3yVWYBYl9HOMichBV/b2qDg5/eQCAx872EDkRn7gsZtVkSIoLjwD458kOisgWAFsAoKioKFptMox9goxi4LIQl4SiicTzMmrsE2QGBi4LcUkomkg8L6PGPkFmYOCyEJeEoqkatYzabbG4jBr7BJmBxRkW4mRImoaYXkaNfYLMwMBlIS7KSVOlqjeq6gJVXTn852/tbpOZ2CfIDEwVWoiTIYnGYp8gMzBwWYyTIYnGYp8go5gqJCIiV2HgIiIiV2HgIiIiVxE75jeKSCuA5qhfeGK5AD6yuxEW4M9lrWJVzbO7ESPYp6KCP5f1IupXtgQuJxGRWlWtsbsdZuPPRXaJ1f8j/lzOwVQhERG5CgMXERG5CgPX8OraMYg/F9klVv+P+HM5RNyPcRERkbvwiYuIiFyFgYuIiFyFgQuAiDwnIidE5KiIvC4i2Xa3abpE5HYROSki74vI9+xujxlEZIGIvCMix0WkUUQes7tNFBr7lLO5vU9xjAuAiPwlgP2qOigizwKAqn7X5mZNmYgkAjgF4LMAvADeA3Cfqh6ztWEGiUg+gHxVPSQiGQDqAGx0+88Vy9innM3tfYpPXABU9feqOjj85QEAHjvbY8BqAO+r6mlV7QfwCwD32Nwmw1T1vKoeGv57N4DjALi8uIOxTzmb2/sUA9cmPqJtAAABQElEQVR4jwD4rd2NmKZCAOdGfe2Fi96MkRCREgBVAA7a2xKaAvYpB3Njn4qb/bhE5G0A8yc4tE1Vfz38PdsADALYE822mUgmeC1mcsEikg7gXwB8W1W77G5PvGOfcj+39qm4CVyquj7UcRH5CoDPAVin7h348wJYMOprD4AWm9piKhFJRqCD7VHV1+xuD7FPuZ2b+xSLMxCoGgLwdwBuU9VWu9szXSKShMBA8joAPgQGku9X1UZbG2aQiAiAfwJwWVW/bXd7KDz2KWdze59i4AIgIu8DmAGgbfilA6r6tzY2adpE5E4A/xVAIoCfqerTNjfJMBG5BcC/AagH4B9++fuq+qZ9raJQ2Kecze19ioGLiIhchVWFRETkKgxcRETkKgxcRETkKgxcRETkKgxcRETkKgxcRETkKgxcRETkKv8fqLYXR5slJ/8AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 504x216 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "##########################\n",
    "### 2D Decision Boundary\n",
    "##########################\n",
    "\n",
    "# Predictions are thresholded at 0.5, so (assuming model.forward(x) is\n",
    "# x @ weights + bias) the boundary is x @ w + b = 0 with b = bias - 0.5.\n",
    "# Solving for the second feature gives a line in the 2D feature plane.\n",
    "w, b = model.weights, model.bias - 0.5\n",
    "\n",
    "def boundary_y(x):\n",
    "    return (-(w[0] * x) - b[0]) / w[1]\n",
    "\n",
    "x_min, x_max = -3, 3\n",
    "y_min, y_max = boundary_y(x_min), boundary_y(x_max)\n",
    "\n",
    "fig, ax = plt.subplots(1, 2, sharex=True, figsize=(7, 3))\n",
    "\n",
    "panels = [(ax[0], X_train, y_train, 'Training set'),\n",
    "          (ax[1], X_test, y_test, 'Test set')]\n",
    "for axis, features, labels, title in panels:\n",
    "    axis.plot([x_min, x_max], [y_min, y_max])\n",
    "    axis.scatter(features[labels==0, 0], features[labels==0, 1], label='class 0', marker='o')\n",
    "    axis.scatter(features[labels==1, 0], features[labels==1, 1], label='class 1', marker='s')\n",
    "    axis.set_title(title)\n",
    "    axis.set_xlabel('$x_1$')\n",
    "    axis.set_ylabel('$x_2$')\n",
    "\n",
    "ax[1].legend(loc='upper left')\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}