# coding: utf-8


import torch 
import torch.nn as nn 
from torch.utils.data import DataLoader 
from torchvision.datasets import MNIST 
from torchvision import transforms
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.handlers import Checkpoint, DiskSaver
from ignite.contrib.handlers import TensorboardLogger, global_step_from_engine

# # Chapter 13: Going Deeper -- the Mechanics of PyTorch (Part 3/3)

# ## Higher-level PyTorch APIs: a short introduction to PyTorch-Ignite 

# ### Setting up the PyTorch model



 
 
 
image_path = './' 
torch.manual_seed(1) 
 
transform = transforms.Compose([ 
    transforms.ToTensor() 
]) 
 
 
mnist_train_dataset = MNIST( 
    root=image_path,  
    train=True,
    transform=transform,  
    download=True
) 
 
mnist_val_dataset = MNIST( 
    root=image_path,  
    train=False,  
    transform=transform,  
    download=False 
) 
 
batch_size = 64
train_loader = DataLoader( 
    mnist_train_dataset, batch_size, shuffle=True 
) 
 
val_loader = DataLoader( 
    mnist_val_dataset, batch_size, shuffle=False 
) 
 
 
def get_model(image_shape=(1, 28, 28), hidden_units=(32, 16)): 
    input_size = image_shape[0] * image_shape[1] * image_shape[2] 
    all_layers = [nn.Flatten()]
    for hidden_unit in hidden_units: 
        layer = nn.Linear(input_size, hidden_unit) 
        all_layers.append(layer) 
        all_layers.append(nn.ReLU()) 
        input_size = hidden_unit 
 
    # the final Linear layer emits raw logits; nn.CrossEntropyLoss applies
    # log-softmax internally, so no Softmax layer is appended here
    all_layers.append(nn.Linear(hidden_units[-1], 10))
    model = nn.Sequential(*all_layers)
    return model 
 
 
device = "cuda" if torch.cuda.is_available() else "cpu"
 
model = get_model().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
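
# As a quick sanity check, we can pass a dummy batch through the network
# to confirm that it produces one logit per class:
dummy_batch = torch.zeros((batch_size, 1, 28, 28), device=device)
print(model(dummy_batch).shape)  # expected: torch.Size([64, 10])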


# ### Setting up training and validation engines with PyTorch-Ignite



 
 
trainer = create_supervised_trainer(
    model, optimizer, loss_fn, device=device
)
 
val_metrics = {
    "accuracy": Accuracy(),
    "loss": Loss(loss_fn)
}
 
evaluator = create_supervised_evaluator(
    model, metrics=val_metrics, device=device
)
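
# Under the hood, create_supervised_trainer builds an Engine around a
# per-batch update function. A simplified sketch of that update step is
# shown below (the actual implementation also supports output transforms,
# gradient accumulation, and automatic mixed precision); an Engine built
# from it would behave like the trainer above:
#
# from ignite.engine import Engine
#
# def update_step(engine, batch):
#     model.train()
#     optimizer.zero_grad()
#     x, y = batch[0].to(device), batch[1].to(device)
#     y_pred = model(x)
#     loss = loss_fn(y_pred, y)
#     loss.backward()
#     optimizer.step()
#     return loss.item()
#
# trainer = Engine(update_step)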


# ### Creating event handlers for logging and validation



# How many batches to wait before logging training status
log_interval = 100
 
@trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
def log_training_loss(engine):
    e = engine.state.epoch
    max_e = engine.state.max_epochs
    i = engine.state.iteration
    batch_loss = engine.state.output
    print(f"Epoch[{e}/{max_e}], Iter[{i}] Loss: {batch_loss:.2f}")




@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
    eval_state = evaluator.run(val_loader)
    metrics = eval_state.metrics
    e = engine.state.epoch
    max_e = engine.state.max_epochs
    acc = metrics["accuracy"]
    avg_loss = metrics["loss"]
    print(f"Validation Results - Epoch[{e}/{max_e}] "
          f"Avg Accuracy: {acc:.2f} Avg Loss: {avg_loss:.2f}")


# ### Setting up training checkpoints and saving the best model



 
# The checkpoint will store the state of the following objects:
to_save = {"model": model, "optimizer": optimizer, "trainer": trainer}
 
# We will save checkpoints to the local disk
output_path = "./output"
save_handler = DiskSaver(dirname=output_path, require_empty=False)
 
# Set up the handler:
checkpoint_handler = Checkpoint(
    to_save, save_handler, filename_prefix="training")

# Attach the handler to the trainer
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler)
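
# If a run is interrupted, the saved objects can be restored with
# Checkpoint.load_objects before calling trainer.run again. A minimal
# sketch (the filename below is hypothetical; it depends on the
# iteration at which the checkpoint was written):
#
# checkpoint = torch.load("./output/training_checkpoint_938.pt")
# Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)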




# Keep only the single best model (n_saved=1), ranked by validation accuracy
best_model_handler = Checkpoint(
    {"model": model},
    save_handler,
    filename_prefix="best",
    n_saved=1,
    score_name="accuracy",
    score_function=Checkpoint.get_default_score_fn("accuracy"),
)
 
evaluator.add_event_handler(Events.COMPLETED, best_model_handler)
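
# Note that the evaluator only runs because log_validation_results calls
# evaluator.run(val_loader) once per training epoch; each such run fires
# Events.COMPLETED on the evaluator, which triggers the best-model check.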


# ### Setting up TensorBoard as an experiment tracking system



 
 
tb_logger = TensorboardLogger(log_dir=output_path)
 
# Attach handler to log the trainer's batch loss every 100 iterations
tb_logger.attach_output_handler(
    trainer,
    event_name=Events.ITERATION_COMPLETED(every=100),
    tag="training",
    output_transform=lambda loss: {"batch_loss": loss},
)
 
# Attach handler to log the evaluator's metrics after every validation run,
# using the trainer's epoch as the global step
tb_logger.attach_output_handler(
    evaluator,
    event_name=Events.EPOCH_COMPLETED,
    tag="validation",
    metric_names="all",
    global_step_transform=global_step_from_engine(trainer),
)
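
# The logged curves can then be inspected by pointing TensorBoard at the
# output directory, e.g. from a shell:
#
#     tensorboard --logdir ./output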


# ### Executing the PyTorch-Ignite model training code



trainer.run(train_loader, max_epochs=5)

# Close the logger so that any buffered TensorBoard events are flushed to disk
tb_logger.close()
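
# After training, the weights of the best-scoring model can be restored for
# inference. A minimal sketch, assuming that last_checkpoint returns the
# filename relative to the saver's directory (with a single "model" entry
# in to_save, the file stores the state_dict directly):
import os

best_path = os.path.join(output_path, best_model_handler.last_checkpoint)
model.load_state_dict(torch.load(best_path))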

