{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b65f2a05",
   "metadata": {},
   "source": [
    "# Chapter 13: Going Deeper -- the Mechanics of PyTorch (Part 3/3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ddec26f4",
   "metadata": {},
   "source": [
    "## Higher-level PyTorch APIs: a short introduction to PyTorch-Ignite "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7722db5f",
   "metadata": {},
   "source": [
    "### Setting up the PyTorch model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "220a5ea0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "import torch.nn as nn \n",
    "from torch.utils.data import DataLoader \n",
    " \n",
    "from torchvision.datasets import MNIST \n",
    "from torchvision import transforms\n",
    " \n",
    " \n",
    "image_path = './' \n",
    "torch.manual_seed(1) \n",
    " \n",
    "transform = transforms.Compose([ \n",
    "    transforms.ToTensor() \n",
    "]) \n",
    " \n",
    " \n",
    "mnist_train_dataset = MNIST( \n",
    "    root=image_path,  \n",
    "    train=True,\n",
    "    transform=transform,  \n",
    "    download=True\n",
    ") \n",
    " \n",
    "mnist_val_dataset = MNIST( \n",
    "    root=image_path,  \n",
    "    train=False,  \n",
    "    transform=transform,  \n",
    "    download=False \n",
    ") \n",
    " \n",
    "batch_size = 64\n",
    "train_loader = DataLoader( \n",
    "    mnist_train_dataset, batch_size, shuffle=True \n",
    ") \n",
    " \n",
    "val_loader = DataLoader( \n",
    "    mnist_val_dataset, batch_size, shuffle=False \n",
    ") \n",
    " \n",
    " \n",
    "def get_model(image_shape=(1, 28, 28), hidden_units=(32, 16)): \n",
    "    input_size = image_shape[0] * image_shape[1] * image_shape[2] \n",
    "    all_layers = [nn.Flatten()]\n",
    "    for hidden_unit in hidden_units: \n",
    "        layer = nn.Linear(input_size, hidden_unit) \n",
    "        all_layers.append(layer) \n",
    "        all_layers.append(nn.ReLU()) \n",
    "        input_size = hidden_unit \n",
    " \n",
    "    all_layers.append(nn.Linear(hidden_units[-1], 10)) \n",
    "    all_layers.append(nn.Softmax(dim=1)) \n",
    "    model = nn.Sequential(*all_layers)\n",
    "    return model \n",
    " \n",
    " \n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    " \n",
    "model = get_model().to(device)\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99dcabd0",
   "metadata": {},
   "source": [
    "### Setting up training and validation engines with PyTorch-Ignite"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c972bfa2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator\n",
    "from ignite.metrics import Accuracy, Loss\n",
    " \n",
    " \n",
    "trainer = create_supervised_trainer(\n",
    "    model, optimizer, loss_fn, device=device\n",
    ")\n",
    " \n",
    "val_metrics = {\n",
    "    \"accuracy\": Accuracy(),\n",
    "    \"loss\": Loss(loss_fn)\n",
    "}\n",
    " \n",
    "evaluator = create_supervised_evaluator(\n",
    "    model, metrics=val_metrics, device=device\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa17c608",
   "metadata": {},
   "source": [
    "### Creating event handlers for logging and validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0edafc78",
   "metadata": {},
   "outputs": [],
   "source": [
    "# How many batches to wait before logging training status\n",
    "log_interval = 100\n",
    " \n",
    "@trainer.on(Events.ITERATION_COMPLETED(every=log_interval))\n",
    "def log_training_loss():\n",
    "    e = trainer.state.epoch\n",
    "    max_e = trainer.state.max_epochs\n",
    "    i = trainer.state.iteration\n",
    "    batch_loss = trainer.state.output\n",
    "    print(f\"Epoch[{e}/{max_e}], Iter[{i}] Loss: {batch_loss:.2f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1944b444",
   "metadata": {},
   "outputs": [],
   "source": [
    "@trainer.on(Events.EPOCH_COMPLETED)\n",
    "def log_validation_results():\n",
    "    eval_state = evaluator.run(val_loader)\n",
    "    metrics = eval_state.metrics\n",
    "    e = trainer.state.epoch\n",
    "    max_e = trainer.state.max_epochs\n",
    "    acc = metrics['accuracy']\n",
    "    avg_loss = metrics['loss']\n",
    "    print(f\"Validation Results - Epoch[{e}/{max_e}] Avg Accuracy: {acc:.2f} Avg Loss: {avg_loss:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11b8cdfa",
   "metadata": {},
   "source": [
    "### Setting up training checkpoints and saving the best model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e451c03b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<ignite.engine.events.RemovableEventHandle at 0x1060993d0>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from ignite.handlers import Checkpoint, DiskSaver\n",
    " \n",
    "# We will save in the checkpoint the following:\n",
    "to_save = {\"model\": model, \"optimizer\": optimizer, \"trainer\": trainer}\n",
    " \n",
    "# We will save checkpoints to the local disk\n",
    "output_path = \"./output\"\n",
    "save_handler = DiskSaver(dirname=output_path, require_empty=False)\n",
    " \n",
    "# Set up the handler:\n",
    "checkpoint_handler = Checkpoint(\n",
    "    to_save, save_handler, filename_prefix=\"training\")\n",
    "\n",
    "# Attach the handler to the trainer\n",
    "trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e3c8def7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<ignite.engine.events.RemovableEventHandle at 0x106088430>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Store best model by validation accuracy\n",
    "best_model_handler = Checkpoint(\n",
    "    {\"model\": model},\n",
    "    save_handler,\n",
    "    filename_prefix=\"best\",\n",
    "    n_saved=1,\n",
    "    score_name=\"accuracy\",\n",
    "    score_function=Checkpoint.get_default_score_fn(\"accuracy\"),\n",
    ")\n",
    " \n",
    "evaluator.add_event_handler(Events.COMPLETED, best_model_handler)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7092b069",
   "metadata": {},
   "source": [
    "### Setting up TensorBoard as an experiment tracking system"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9dc0368b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<ignite.engine.events.RemovableEventHandle at 0x16605ff70>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from ignite.contrib.handlers import TensorboardLogger, global_step_from_engine\n",
    " \n",
    " \n",
    "tb_logger = TensorboardLogger(log_dir=output_path)\n",
    " \n",
    "# Attach handler to plot trainer's loss every 100 iterations\n",
    "tb_logger.attach_output_handler(\n",
    "    trainer,\n",
    "    event_name=Events.ITERATION_COMPLETED(every=100),\n",
    "    tag=\"training\",\n",
    "    output_transform=lambda loss: {\"batch_loss\": loss},\n",
    ")\n",
    " \n",
    "# Attach handler for plotting both evaluators' metrics after every epoch completes\n",
    "tb_logger.attach_output_handler(\n",
    "    evaluator,\n",
    "    event_name=Events.EPOCH_COMPLETED,\n",
    "    tag=\"validation\",\n",
    "    metric_names=\"all\",\n",
    "    global_step_transform=global_step_from_engine(trainer),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad26fb6b",
   "metadata": {},
   "source": [
    "### Executing the PyTorch-Ignite model training code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7f4d38cf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch[1/5], Iter[100] Loss: 1.87\n",
      "Epoch[1/5], Iter[200] Loss: 1.82\n",
      "Epoch[1/5], Iter[300] Loss: 1.67\n",
      "Epoch[1/5], Iter[400] Loss: 1.55\n",
      "Epoch[1/5], Iter[500] Loss: 1.65\n",
      "Epoch[1/5], Iter[600] Loss: 1.59\n",
      "Epoch[1/5], Iter[700] Loss: 1.59\n",
      "Epoch[1/5], Iter[800] Loss: 1.56\n",
      "Epoch[1/5], Iter[900] Loss: 1.63\n",
      "Validation Results - Epoch[1/5] Avg Accuracy: 0.91 Avg Loss: 1.56\n",
      "Epoch[2/5], Iter[1000] Loss: 1.61\n",
      "Epoch[2/5], Iter[1100] Loss: 1.56\n",
      "Epoch[2/5], Iter[1200] Loss: 1.54\n",
      "Epoch[2/5], Iter[1300] Loss: 1.54\n",
      "Epoch[2/5], Iter[1400] Loss: 1.51\n",
      "Epoch[2/5], Iter[1500] Loss: 1.53\n",
      "Epoch[2/5], Iter[1600] Loss: 1.50\n",
      "Epoch[2/5], Iter[1700] Loss: 1.50\n",
      "Epoch[2/5], Iter[1800] Loss: 1.52\n",
      "Validation Results - Epoch[2/5] Avg Accuracy: 0.92 Avg Loss: 1.54\n",
      "Epoch[3/5], Iter[1900] Loss: 1.61\n",
      "Epoch[3/5], Iter[2000] Loss: 1.60\n",
      "Epoch[3/5], Iter[2100] Loss: 1.54\n",
      "Epoch[3/5], Iter[2200] Loss: 1.51\n",
      "Epoch[3/5], Iter[2300] Loss: 1.48\n",
      "Epoch[3/5], Iter[2400] Loss: 1.56\n",
      "Epoch[3/5], Iter[2500] Loss: 1.57\n",
      "Epoch[3/5], Iter[2600] Loss: 1.52\n",
      "Epoch[3/5], Iter[2700] Loss: 1.54\n",
      "Epoch[3/5], Iter[2800] Loss: 1.54\n",
      "Validation Results - Epoch[3/5] Avg Accuracy: 0.93 Avg Loss: 1.53\n",
      "Epoch[4/5], Iter[2900] Loss: 1.53\n",
      "Epoch[4/5], Iter[3000] Loss: 1.49\n",
      "Epoch[4/5], Iter[3100] Loss: 1.51\n",
      "Epoch[4/5], Iter[3200] Loss: 1.51\n",
      "Epoch[4/5], Iter[3300] Loss: 1.54\n",
      "Epoch[4/5], Iter[3400] Loss: 1.50\n",
      "Epoch[4/5], Iter[3500] Loss: 1.58\n",
      "Epoch[4/5], Iter[3600] Loss: 1.59\n",
      "Epoch[4/5], Iter[3700] Loss: 1.50\n",
      "Validation Results - Epoch[4/5] Avg Accuracy: 0.94 Avg Loss: 1.53\n",
      "Epoch[5/5], Iter[3800] Loss: 1.52\n",
      "Epoch[5/5], Iter[3900] Loss: 1.60\n",
      "Epoch[5/5], Iter[4000] Loss: 1.50\n",
      "Epoch[5/5], Iter[4100] Loss: 1.50\n",
      "Epoch[5/5], Iter[4200] Loss: 1.51\n",
      "Epoch[5/5], Iter[4300] Loss: 1.57\n",
      "Epoch[5/5], Iter[4400] Loss: 1.56\n",
      "Epoch[5/5], Iter[4500] Loss: 1.55\n",
      "Epoch[5/5], Iter[4600] Loss: 1.50\n",
      "Validation Results - Epoch[5/5] Avg Accuracy: 0.94 Avg Loss: 1.52\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "State:\n",
       "\titeration: 4690\n",
       "\tepoch: 5\n",
       "\tepoch_length: 938\n",
       "\tmax_epochs: 5\n",
       "\toutput: 1.5042390823364258\n",
       "\tbatch: <class 'list'>\n",
       "\tmetrics: <class 'dict'>\n",
       "\tdataloader: <class 'torch.utils.data.dataloader.DataLoader'>\n",
       "\tseed: <class 'NoneType'>\n",
       "\ttimes: <class 'dict'>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.run(train_loader, max_epochs=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "523acf34",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "Readers may ignore the next cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "29befadd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[NbConvertApp] Converting notebook ch13_part3.ipynb to script\n",
      "[NbConvertApp] Writing 4864 bytes to ch13_part3.py\n"
     ]
    }
   ],
   "source": [
    "! python ../.convert_notebook_to_script.py --input ch13_part3.ipynb --output ch13_part3.py"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}