import torch torch.backends.cudnn.benchmark=True import torch.nn as nn from torch.autograd import Variable import torch.optim as optim import torch.nn.functional as F import torch.nn.init as init import torchvision.transforms as transforms import argparse import os from PIL import Image import scipy.misc import random import copy import math import numpy as np from data.data_loader import cifar10, cifar100, ExemplarDataset from lib.util import moment_update, TransformTwice, weight_norm, mixup_data, mixup_criterion, LabelSmoothingCrossEntropy from lib.augment.cutout import Cutout from lib.augment.autoaugment_extra import CIFAR10Policy from models import * compute_means=True exemplar_means_= [] avg_acc = [] def parse_option(): parser = argparse.ArgumentParser('argument for training') # training hyperparameters parser.add_argument('--batch-size', type=int, default=100, help='batch_size') parser.add_argument('--num-workers', type=int, default=8, help='num of workers to use') parser.add_argument('--epochs', type=int, default=120, help='number of training epochs') parser.add_argument('--epochs-sd', type=int, default=70, help='number of training epochs for self-distillation') parser.add_argument('--val-freq', type=int, default=10, help='validation frequency') # incremental learning parser.add_argument('--new-classes', type=int, default=10, help='number of classes in new task') parser.add_argument('--start-classes', type=int, default=50, help='number of classes in old task') parser.add_argument('--K', type=int, default=2000, help='2000 exemplars for CIFAR-100') # optimization parser.add_argument('--lr', type=float, default=0.1, help='learning rate') parser.add_argument('--lr-min', type=float, default=0.0001, help='lower end of cosine decay') parser.add_argument('--lr-sd', type=float, default=0.1, help='learning rate for self-distillation') parser.add_argument('--lr-ft', type=float, default=0.01, help='learning rate for task-2 onwards') parser.add_argument('--weight-decay', type=float, default=5e-4, help='weight decay') parser.add_argument('--momentum', type=float, default=0.9, help='momentum for SGD') parser.add_argument('--cosine', action='store_true', help='use cosine learning rate') # root folders parser.add_argument('--data-root', type=str, default='./data', help='root directory of dataset') parser.add_argument('--output-root', type=str, default='./output', help='root directory for output') # save and load parser.add_argument('--exp-name', type=str, default='kd', help='experiment name') parser.add_argument('--resume', action='store_true', help='use class moco') parser.add_argument('--resume-path', type=str, default='./checkpoint_0.pth',) parser.add_argument('--save', action='store_true', help='to save checkpoint') # loss function parser.add_argument('--pow', type=float, default=0.66, help='hyperparameter of adaptive weight') parser.add_argument('--lamda', type=float, default=5, help='weighting of classification and distillation') parser.add_argument('--lamda-sd', type=float, default=10, help='weighting of classification and distillation') parser.add_argument('--const-lamda', action='store_true', help='use constant lamda value, default: adaptive weighting') parser.add_argument('--w-cls', type=float, default=1.0, help='weightage of new classification loss') # kd loss parser.add_argument('--kd', action='store_true', help='use kd loss') parser.add_argument('--w-kd', type=float, default=1.0, help='weightage of knowledge distillation loss') parser.add_argument('--T', type=float, default=2, help='temperature scaling for KD') parser.add_argument('--T-sd', type=float, default=2, help='temperature scaling for KD') # self-distillation parser.add_argument('--num-sd', type=int, default=0, help='number of self-distillation generations') parser.add_argument('--sd-factor', type=float, default=5.0, help='weighting between classification and distillation') # mixup parser.add_argument('--mixup', action='store_true', help='use mixup augmentation') parser.add_argument('--mixup-alpha', type=float, default=0.1, help='mixup alpha value') # label smoothing parser.add_argument('--label-smoothing', action='store_true', help='use label smoothing') parser.add_argument('--smoothing-alpha', type=float, default=0.1, help='label smoothing alpha value') # heave augmentation (Auto Augment) parser.add_argument('--aug', action='store_true', help='use heavy augmentation') parser.add_argument('--tsne', action='store_true', help='plot tsne after each incremental step') args = parser.parse_args() return args def train(model, old_model, epoch, lr, tempature, lamda, train_loader, use_sd, checkPoint): tolerance_cnt = 0 step = 0 best_acc = 0 T = args.T model.cuda() old_model.cuda() criterion_ce = nn.CrossEntropyLoss(ignore_index=-1) criterion_ce_smooth = LabelSmoothingCrossEntropy() # for label smoothing # reduce learning rate after first epoch (LowLR) if len(test_classes) // CLASS_NUM_IN_BATCH > 1: lr = args.lr_ft optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=args.weight_decay) if len(test_classes) // CLASS_NUM_IN_BATCH ==1 and use_sd ==True: if args.cosine: scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epoch, eta_min=0.001) else: scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60], gamma=0.1) else: if args.cosine: scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epoch, eta_min=args.lr_min) else: scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 90], gamma=0.1) if len(test_classes) // CLASS_NUM_IN_BATCH > 1: exemplar_set = ExemplarDataset(exemplar_sets, transform=transform_ori) exemplar_loader = torch.utils.data.DataLoader(exemplar_set, batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=True) exemplar_loader_iter = iter(exemplar_loader) old_model.eval() num_old_classes = old_model.fc.out_features for epoch_index in range(1, epoch+1): dist_loss = 0.0 sum_loss = 0 sum_dist_loss = 0 sum_cls_new_loss = 0 sum_cls_old_loss = 0 sum_cls_loss = 0 model.train() old_model.eval() old_model.freeze_weight() for param_group in optimizer.param_groups: print('learning rate: {:.4f}'. format(param_group['lr'])) for batch_idx, (x, x1, target) in enumerate(train_loader): optimizer.zero_grad() # Classification Loss: New task x, target = x.cuda(), target.cuda() targets = target-len(test_classes)+CLASS_NUM_IN_BATCH # use mixup for task-1 if args.mixup: inputs, targets_a, targets_b, lam = mixup_data(x, targets, args.mixup_alpha) inputs, targets_a, targets_b = map(Variable, (inputs, targets_a, targets_b)) logits = model(inputs) outputs = logits[:,-CLASS_NUM_IN_BATCH:] cls_loss_new = mixup_criterion(criterion_ce, outputs, targets_a, targets_b, lam) # use label smoothing for task-1 elif args.label_smoothing: logits = model(x) cls_loss_new = criterion_ce_smooth(logits[:,-CLASS_NUM_IN_BATCH:], targets, args.smoothing_alpha) else: logits = model(x) cls_loss_new = criterion_ce(logits[:,-CLASS_NUM_IN_BATCH:], targets) loss = args.w_cls*cls_loss_new sum_cls_new_loss += cls_loss_new.item() # use fixed lamda value or adaptive weighting if args.const_lamda: factor = args.lamda elif use_sd: factor = args.lamda_sd else: factor = ((len(test_classes)/CLASS_NUM_IN_BATCH)**(args.pow))*args.lamda # while using self-distillation if len(test_classes) // CLASS_NUM_IN_BATCH == 1 and use_sd: if args.kd: with torch.no_grad(): dist_target = old_model(x) logits_dist = logits T_sd = args.T_sd dist_loss = nn.KLDivLoss()(F.log_softmax(logits_dist/T_sd, dim=1), F.softmax(dist_target/T_sd, dim=1)) * (T_sd*T_sd) # best model sum_dist_loss += dist_loss.item() loss += factor*args.w_kd*dist_loss # Distillation : task-2 onwards if len(test_classes) // CLASS_NUM_IN_BATCH > 1: if args.kd: with torch.no_grad(): dist_target = old_model(x) logits_dist = logits[:, :-CLASS_NUM_IN_BATCH] T = args.T dist_loss_new = nn.KLDivLoss()(F.log_softmax(logits_dist/T, dim=1), F.softmax(dist_target/T, dim=1)) * (T*T) try: batch_ex = next(exemplar_loader_iter) except: exemplar_loader_iter = iter(exemplar_loader) batch_ex = next(exemplar_loader_iter) # Classification loss: exemplar classes loss x_old, target_old = batch_ex x_old , target_old = x_old.cuda(), target_old.cuda() logits_old = model(x_old) old_classes = len(test_classes) - CLASS_NUM_IN_BATCH cls_loss_old = criterion_ce(logits_old, target_old) loss += cls_loss_old sum_cls_old_loss += cls_loss_old.item() if args.kd: # KD exemplar with torch.no_grad(): dist_target_old = old_model(x_old) logits_dist_old = logits_old[:, :-CLASS_NUM_IN_BATCH] dist_loss_old = nn.KLDivLoss()(F.log_softmax(logits_dist_old/T, dim=1), F.softmax(dist_target_old/T, dim=1)) * (T*T) # best model dist_loss = dist_loss_old + dist_loss_new sum_dist_loss += dist_loss.item() loss += factor*args.w_kd*dist_loss sum_loss += loss.item() loss.backward() optimizer.step() step += 1 if (batch_idx + 1) % checkPoint == 0 or (batch_idx + 1) == len(trainLoader): print('==>>> epoch: {}, batch index: {}, step: {}, train loss: {:.3f}, dist_loss: {:3f}, cls_new_loss: {:.3f}, cls_old_loss: {:.3f}'. format(epoch_index, batch_idx + 1, step, sum_loss/(batch_idx+1), sum_dist_loss/(batch_idx+1), sum_cls_new_loss/(batch_idx+1), sum_cls_old_loss/(batch_idx+1))) scheduler.step() def evaluate_net(model, transform, train_classes, test_classes): model.eval() train_set = cifar100(root=args.data_root, train=False, classes=train_classes, download=False, transform=transform) train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=False, num_workers=4) total = 0.0 correct = 0.0 compute_means = True for j, (_, images, labels) in enumerate(train_loader): _, preds = torch.max(torch.softmax(model(images.cuda()), dim=1), dim=1, keepdim=False) labels = [y.item() for y in labels] np.asarray(labels) total += preds.size(0) correct += (preds.cpu().numpy() == labels).sum() # Train Accuracy print ('correct: ', correct, 'total: ', total) print ('Train Accuracy : %.2f ,' % (100.0 * correct / total)) test_set = cifar100(root=args.data_root, train=False, classes=test_classes, download=True, transform=transform) test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=4) total = 0.0 correct = 0.0 for j, (_, images, labels) in enumerate(test_loader): out = torch.softmax(model(images.cuda()), dim=1) _, preds = torch.max(out, dim=1, keepdim=False) labels = [y.item() for y in labels] np.asarray(labels) total += preds.size(0) correct += (preds.cpu().numpy() == labels).sum() # Test Accuracy test_acc = 100.0*correct/total print ('correct: ', correct, 'total: ', total) print ('Test Accuracy : %.2f' % test_acc) return test_acc def icarl_reduce_exemplar_sets(m): for y, P_y in enumerate(exemplar_sets): exemplar_sets[y] = P_y[:m] #Construct an exemplar set for image set def icarl_construct_exemplar_set(model, images, m, transform): model.eval() # Compute and cache features for each example features = [] with torch.no_grad(): for img in images: x = Variable(transform(Image.fromarray(img))).cuda() x=x.unsqueeze(0) feat = model.forward(x, rd=True).data.cpu().numpy() feat = feat / np.linalg.norm(feat) # Normalize features.append(feat[0]) features = np.array(features) class_mean = np.mean(features, axis=0) class_mean = class_mean / np.linalg.norm(class_mean) # Normalize exemplar_set = [] exemplar_features = [] # list of Variables of shape (feature_size,) exemplar_dist = [] for k in range(int(m)): S = np.sum(exemplar_features, axis=0) phi = features mu = class_mean mu_p = 1.0/(k+1) * (phi + S) mu_p = mu_p / np.linalg.norm(mu_p) dist = np.sqrt(np.sum((mu - mu_p) ** 2, axis=1)) idx = np.random.randint(0, features.shape[0]) exemplar_dist.append(dist[idx]) exemplar_set.append(images[idx]) exemplar_features.append(features[idx]) features[idx, :] = 0.0 # random exemplar selection exemplar_dist = np.array(exemplar_dist) exemplar_set = np.array(exemplar_set) ind = exemplar_dist.argsort() exemplar_set = exemplar_set[ind] exemplar_sets.append(np.array(exemplar_set)) print ('exemplar set shape: ', len(exemplar_set)) if __name__ == '__main__': args = parse_option() print (args) if not os.path.exists(os.path.join(args.output_root, "checkpoints/cifar/")): os.makedirs(os.path.join(args.output_root, "checkpoints/cifar/")) # parameters TOTAL_CLASS_NUM = 100 CLASS_NUM_IN_BATCH = args.start_classes TOTAL_CLASS_BATCH_NUM = TOTAL_CLASS_NUM // CLASS_NUM_IN_BATCH T = args.T K = args.K exemplar_sets = [] exemplar_means = [] compute_means = True normalize = transforms.Normalize((0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023)) # Heavy-augmentation transform_aug = transforms.Compose([ transforms.ToTensor(), Cutout(n_holes=1, length=16), transforms.ToPILImage(), transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4, padding_mode='reflect'), CIFAR10Policy(), transforms.ToTensor(), normalize, ]) # default augmentation transform_ori = transforms.Compose([ transforms.ToTensor(), transforms.ToPILImage(), transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize, ]) # test-time augmentation transform_test = transforms.Compose([ transforms.ToTensor(), normalize, ]) transform_train = TransformTwice(transform_ori, transform_ori) class_index = [i for i in range(0, TOTAL_CLASS_NUM)] np.random.seed(1993) np.random.shuffle(class_index) net = resnet32_cifar(num_classes=CLASS_NUM_IN_BATCH).cuda() model_parameters = filter(lambda p: p.requires_grad, net.parameters()) params = sum([np.prod(p.size()) for p in model_parameters]) print ('number of trainable parameters: ', params) old_net = copy.deepcopy(net) old_net.cuda() cls_list = [0] + [a for a in range(args.start_classes, 100, args.new_classes)] for i in cls_list: if i == args.start_classes: CLASS_NUM_IN_BATCH = args.new_classes print("==> Current Class: ", class_index[i:i+CLASS_NUM_IN_BATCH]) print('==> Building model..') if i == args.start_classes: net.change_output_dim(new_dim=i+CLASS_NUM_IN_BATCH) if i > args.start_classes: net.change_output_dim(new_dim=i+CLASS_NUM_IN_BATCH, second_iter=True) print("current net output dim:", net.get_output_dim()) # while using heavy augmentation if args.aug: if i==0: transform_train = TransformTwice(transform_aug, transform_aug) print ('.............augmentation.............') else: transform_train = TransformTwice(transform_ori, transform_ori) train_set = cifar100(root=args.data_root, train=True, classes=class_index[i:i+CLASS_NUM_IN_BATCH], download=True, transform=transform_train) trainLoader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4) train_classes = class_index[i:i+CLASS_NUM_IN_BATCH] test_classes = class_index[:i+CLASS_NUM_IN_BATCH] print (train_classes) print (test_classes) m = K // (i+CLASS_NUM_IN_BATCH) if i!=0: icarl_reduce_exemplar_sets(m) for y in range(i, i+CLASS_NUM_IN_BATCH): print ("Constructing exemplar set for class-%d..." %(class_index[y])) images = train_set.get_image_class(y) icarl_construct_exemplar_set(net, images, m, transform_test) print ("Done") # train and save model if args.resume and i==0: net.load_state_dict(torch.load(args.resume_path)) net.train() else: net.train() train(model=net, old_model=old_net, epoch=args.epochs, lr=args.lr, tempature=T, lamda=args.lamda, train_loader=trainLoader, use_sd=False, checkPoint=50) # print weight norm: task:2 onwards if i!=0: weight_norm(net) old_net = copy.deepcopy(net) old_net.cuda() # Do self-distillation if i == 0 and not args.resume: for sd in range(args.num_sd): train(model=net, old_model=old_net, epoch=args.epochs_sd, lr=args.lr_sd, tempature=T, lamda=args.lamda,train_loader=trainLoader, use_sd=True, checkPoint=50) old_net = copy.deepcopy(net) old_net.cuda() if args.save: save_path = os.path.join(args.output_root, "checkpoints/cifar/", args.exp_name) if not os.path.exists(save_path): os.makedirs(save_path) torch.save(net.state_dict(), os.path.join(save_path, 'checkpoint_' + str(i+CLASS_NUM_IN_BATCH) + '.pth')) # Evaluation on training and testing set transform_val = TransformTwice(transform_test, transform_test) test_acc = evaluate_net(model=net, transform=transform_val, train_classes=class_index[i:i+CLASS_NUM_IN_BATCH], test_classes=class_index[:i+CLASS_NUM_IN_BATCH]) avg_acc.append(test_acc) print (avg_acc) print ('Avg accuracy: ', sum(avg_acc)/len(avg_acc))