{ "cells": [ { "cell_type": "markdown", "id": "b65f2a05", "metadata": {}, "source": [ "# Chapter 13: Going Deeper -- the Mechanics of PyTorch (Part 3/3)" ] }, { "cell_type": "markdown", "id": "ddec26f4", "metadata": {}, "source": [ "## Higher-level PyTorch APIs: a short introduction to PyTorch-Ignite " ] }, { "cell_type": "markdown", "id": "7722db5f", "metadata": {}, "source": [ "### Setting up the PyTorch model" ] }, { "cell_type": "code", "execution_count": 1, "id": "220a5ea0", "metadata": {}, "outputs": [], "source": [ "import torch \n", "import torch.nn as nn \n", "from torch.utils.data import DataLoader \n", " \n", "from torchvision.datasets import MNIST \n", "from torchvision import transforms\n", " \n", " \n", "image_path = './' \n", "torch.manual_seed(1) \n", " \n", "transform = transforms.Compose([ \n", " transforms.ToTensor() \n", "]) \n", " \n", " \n", "mnist_train_dataset = MNIST( \n", " root=image_path, \n", " train=True,\n", " transform=transform, \n", " download=True\n", ") \n", " \n", "mnist_val_dataset = MNIST( \n", " root=image_path, \n", " train=False, \n", " transform=transform, \n", " download=False \n", ") \n", " \n", "batch_size = 64\n", "train_loader = DataLoader( \n", " mnist_train_dataset, batch_size, shuffle=True \n", ") \n", " \n", "val_loader = DataLoader( \n", " mnist_val_dataset, batch_size, shuffle=False \n", ") \n", " \n", " \n", "def get_model(image_shape=(1, 28, 28), hidden_units=(32, 16)): \n", " input_size = image_shape[0] * image_shape[1] * image_shape[2] \n", " all_layers = [nn.Flatten()]\n", " for hidden_unit in hidden_units: \n", " layer = nn.Linear(input_size, hidden_unit) \n", " all_layers.append(layer) \n", " all_layers.append(nn.ReLU()) \n", " input_size = hidden_unit \n", " \n", " all_layers.append(nn.Linear(hidden_units[-1], 10)) \n", " all_layers.append(nn.Softmax(dim=1)) \n", " model = nn.Sequential(*all_layers)\n", " return model \n", " \n", " \n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", " \n", "model = get_model().to(device)\n", "loss_fn = nn.CrossEntropyLoss()\n", "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)" ] }, { "cell_type": "markdown", "id": "99dcabd0", "metadata": {}, "source": [ "### Setting up training and validation engines with PyTorch-Ignite" ] }, { "cell_type": "code", "execution_count": 2, "id": "c972bfa2", "metadata": {}, "outputs": [], "source": [ "from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator\n", "from ignite.metrics import Accuracy, Loss\n", " \n", " \n", "trainer = create_supervised_trainer(\n", " model, optimizer, loss_fn, device=device\n", ")\n", " \n", "val_metrics = {\n", " \"accuracy\": Accuracy(),\n", " \"loss\": Loss(loss_fn)\n", "}\n", " \n", "evaluator = create_supervised_evaluator(\n", " model, metrics=val_metrics, device=device\n", ")\n" ] }, { "cell_type": "markdown", "id": "aa17c608", "metadata": {}, "source": [ "### Creating event handlers for logging and validation" ] }, { "cell_type": "code", "execution_count": 3, "id": "0edafc78", "metadata": {}, "outputs": [], "source": [ "# How many batches to wait before logging training status\n", "log_interval = 100\n", " \n", "@trainer.on(Events.ITERATION_COMPLETED(every=log_interval))\n", "def log_training_loss():\n", " e = trainer.state.epoch\n", " max_e = trainer.state.max_epochs\n", " i = trainer.state.iteration\n", " batch_loss = trainer.state.output\n", " print(f\"Epoch[{e}/{max_e}], Iter[{i}] Loss: {batch_loss:.2f}\")\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "1944b444", "metadata": {}, "outputs": [], "source": [ "@trainer.on(Events.EPOCH_COMPLETED)\n", "def log_validation_results():\n", " eval_state = evaluator.run(val_loader)\n", " metrics = eval_state.metrics\n", " e = trainer.state.epoch\n", " max_e = trainer.state.max_epochs\n", " acc = metrics['accuracy']\n", " avg_loss = metrics['loss']\n", " print(f\"Validation Results - Epoch[{e}/{max_e}] Avg Accuracy: {acc:.2f} Avg Loss: {avg_loss:.2f}\")" ] }, { "cell_type": "markdown", "id": "11b8cdfa", "metadata": {}, "source": [ "### Setting up training checkpoints and saving the best model" ] }, { "cell_type": "code", "execution_count": 5, "id": "e451c03b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<ignite.engine.events.RemovableEventHandle at 0x1060993d0>" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from ignite.handlers import Checkpoint, DiskSaver\n", " \n", "# We will save in the checkpoint the following:\n", "to_save = {\"model\": model, \"optimizer\": optimizer, \"trainer\": trainer}\n", " \n", "# We will save checkpoints to the local disk\n", "output_path = \"./output\"\n", "save_handler = DiskSaver(dirname=output_path, require_empty=False)\n", " \n", "# Set up the handler:\n", "checkpoint_handler = Checkpoint(\n", " to_save, save_handler, filename_prefix=\"training\")\n", "\n", "# Attach the handler to the trainer\n", "trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler)" ] }, { "cell_type": "code", "execution_count": 6, "id": "e3c8def7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<ignite.engine.events.RemovableEventHandle at 0x106088430>" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Store best model by validation accuracy\n", "best_model_handler = Checkpoint(\n", " {\"model\": model},\n", " save_handler,\n", " filename_prefix=\"best\",\n", " n_saved=1,\n", " score_name=\"accuracy\",\n", " score_function=Checkpoint.get_default_score_fn(\"accuracy\"),\n", ")\n", " \n", "evaluator.add_event_handler(Events.COMPLETED, best_model_handler)\n" ] }, { "cell_type": "markdown", "id": "7092b069", "metadata": {}, "source": [ "### Setting up TensorBoard as an experiment tracking system" ] }, { "cell_type": "code", "execution_count": 7, "id": "9dc0368b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<ignite.engine.events.RemovableEventHandle at 0x16605ff70>" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from ignite.contrib.handlers import TensorboardLogger, global_step_from_engine\n", " \n", " \n", "tb_logger = TensorboardLogger(log_dir=output_path)\n", " \n", "# Attach handler to plot trainer's loss every 100 iterations\n", "tb_logger.attach_output_handler(\n", " trainer,\n", " event_name=Events.ITERATION_COMPLETED(every=100),\n", " tag=\"training\",\n", " output_transform=lambda loss: {\"batch_loss\": loss},\n", ")\n", " \n", "# Attach handler for plotting both evaluators' metrics after every epoch completes\n", "tb_logger.attach_output_handler(\n", " evaluator,\n", " event_name=Events.EPOCH_COMPLETED,\n", " tag=\"validation\",\n", " metric_names=\"all\",\n", " global_step_transform=global_step_from_engine(trainer),\n", ")" ] }, { "cell_type": "markdown", "id": "ad26fb6b", "metadata": {}, "source": [ "### Executing the PyTorch-Ignite model training code" ] }, { "cell_type": "code", "execution_count": 8, "id": "7f4d38cf", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch[1/5], Iter[100] Loss: 1.87\n", "Epoch[1/5], Iter[200] Loss: 1.82\n", "Epoch[1/5], Iter[300] Loss: 1.67\n", "Epoch[1/5], Iter[400] Loss: 1.55\n", "Epoch[1/5], Iter[500] Loss: 1.65\n", "Epoch[1/5], Iter[600] Loss: 1.59\n", "Epoch[1/5], Iter[700] Loss: 1.59\n", "Epoch[1/5], Iter[800] Loss: 1.56\n", "Epoch[1/5], Iter[900] Loss: 1.63\n", "Validation Results - Epoch[1/5] Avg Accuracy: 0.91 Avg Loss: 1.56\n", "Epoch[2/5], Iter[1000] Loss: 1.61\n", "Epoch[2/5], Iter[1100] Loss: 1.56\n", "Epoch[2/5], Iter[1200] Loss: 1.54\n", "Epoch[2/5], Iter[1300] Loss: 1.54\n", "Epoch[2/5], Iter[1400] Loss: 1.51\n", "Epoch[2/5], Iter[1500] Loss: 1.53\n", "Epoch[2/5], Iter[1600] Loss: 1.50\n", "Epoch[2/5], Iter[1700] Loss: 1.50\n", "Epoch[2/5], Iter[1800] Loss: 1.52\n", "Validation Results - Epoch[2/5] Avg Accuracy: 0.92 Avg Loss: 1.54\n", "Epoch[3/5], Iter[1900] Loss: 1.61\n", "Epoch[3/5], Iter[2000] Loss: 1.60\n", "Epoch[3/5], Iter[2100] Loss: 1.54\n", "Epoch[3/5], Iter[2200] Loss: 1.51\n", "Epoch[3/5], Iter[2300] Loss: 1.48\n", "Epoch[3/5], Iter[2400] Loss: 1.56\n", "Epoch[3/5], Iter[2500] Loss: 1.57\n", "Epoch[3/5], Iter[2600] Loss: 1.52\n", "Epoch[3/5], Iter[2700] Loss: 1.54\n", "Epoch[3/5], Iter[2800] Loss: 1.54\n", "Validation Results - Epoch[3/5] Avg Accuracy: 0.93 Avg Loss: 1.53\n", "Epoch[4/5], Iter[2900] Loss: 1.53\n", "Epoch[4/5], Iter[3000] Loss: 1.49\n", "Epoch[4/5], Iter[3100] Loss: 1.51\n", "Epoch[4/5], Iter[3200] Loss: 1.51\n", "Epoch[4/5], Iter[3300] Loss: 1.54\n", "Epoch[4/5], Iter[3400] Loss: 1.50\n", "Epoch[4/5], Iter[3500] Loss: 1.58\n", "Epoch[4/5], Iter[3600] Loss: 1.59\n", "Epoch[4/5], Iter[3700] Loss: 1.50\n", "Validation Results - Epoch[4/5] Avg Accuracy: 0.94 Avg Loss: 1.53\n", "Epoch[5/5], Iter[3800] Loss: 1.52\n", "Epoch[5/5], Iter[3900] Loss: 1.60\n", "Epoch[5/5], Iter[4000] Loss: 1.50\n", "Epoch[5/5], Iter[4100] Loss: 1.50\n", "Epoch[5/5], Iter[4200] Loss: 1.51\n", "Epoch[5/5], Iter[4300] Loss: 1.57\n", "Epoch[5/5], Iter[4400] Loss: 1.56\n", "Epoch[5/5], Iter[4500] Loss: 1.55\n", "Epoch[5/5], Iter[4600] Loss: 1.50\n", "Validation Results - Epoch[5/5] Avg Accuracy: 0.94 Avg Loss: 1.52\n" ] }, { "data": { "text/plain": [ "State:\n", "\titeration: 4690\n", "\tepoch: 5\n", "\tepoch_length: 938\n", "\tmax_epochs: 5\n", "\toutput: 1.5042390823364258\n", "\tbatch: <class 'list'>\n", "\tmetrics: <class 'dict'>\n", "\tdataloader: <class 'torch.utils.data.dataloader.DataLoader'>\n", "\tseed: <class 'NoneType'>\n", "\ttimes: <class 'dict'>" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.run(train_loader, max_epochs=5)" ] }, { "cell_type": "markdown", "id": "523acf34", "metadata": {}, "source": [ "---\n", "\n", "Readers may ignore the next cell." ] }, { "cell_type": "code", "execution_count": 10, "id": "29befadd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[NbConvertApp] Converting notebook ch13_part3.ipynb to script\n", "[NbConvertApp] Writing 4864 bytes to ch13_part3.py\n" ] } ], "source": [ "! python ../.convert_notebook_to_script.py --input ch13_part3.ipynb --output ch13_part3.py" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.2" } }, "nbformat": 4, "nbformat_minor": 5 }