from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torchsummary
from torchsummary import summary

def display_my_images(train_loader, classes):
	images, labels = next(iter(train_loader))
	fig=plt.figure(figsize=(20,8))
	for i in range(20):
		ax=fig.add_subplot(2,10, i+1)
		img=np.squeeze(images[i].numpy())
		img=img/2 +0.5
		img=np.transpose(img, (1, 2, 0))
		ax.imshow(img)
		ax.set_title(str(classes[labels[i].item()]))