Commit

4k working
sud0301 committed Aug 20, 2019
1 parent 62aaaaf commit 9d5ce3b
Showing 2 changed files with 69 additions and 40 deletions.
107 changes: 67 additions & 40 deletions main.py
@@ -20,9 +20,6 @@
from utils import progress_bar
from cutout import Cutout

import config
import model

#from data_loader import iCIFAR10
#from resnet import resnet18
from wrn import wrn
@@ -31,9 +28,10 @@
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.03, type=float, help='learning rate')
parser.add_argument('--lr-warm-up', action='store_true', help='increase lr slowly')
parser.add_argument('--warm-up-steps', default=20000, type=int, help='number of iterations for warmup')

parser.add_argument('--batch-size-lab', default=32, type=int, help='training batch size')
parser.add_argument('--batch-size-unlab', default=160, type=int, help='training batch size')
parser.add_argument('--batch-size-lab', default=64, type=int, help='training batch size')
parser.add_argument('--batch-size-unlab', default=320, type=int, help='training batch size')
parser.add_argument('--num-steps', default=100000, type=int, help='number of iterations')

parser.add_argument('--partial-data', default=0.5, type=float, help='partial data')
@@ -45,6 +43,7 @@
parser.add_argument('--cutout-size', default=16, type=float, help='size of the cutout window')

parser.add_argument('--autoaugment', action='store_true', help='use autoaugment augmentation')
parser.add_argument('--GCN', action='store_true', help='resume from checkpoint')
args = parser.parse_args()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -64,35 +63,32 @@ def __call__(self, inp):
# Data
print('==> Preparing data..')
transform_ori = transforms.Compose([
#transforms.RandomCrop(32, padding=4),
#transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
#transforms.Normalize((0.49139968, 0.48215841, 0.44653091), (0.24703223, 0.24348513, 0.26158784)),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
transforms.ToPILImage(),
transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_aug = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(), CIFAR10Policy(),
CIFAR10Policy(),
transforms.ToTensor(),
#transforms.Normalize((0.49139968, 0.48215841, 0.44653091), (0.24703223, 0.24348513, 0.26158784)),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
Cutout(n_holes=args.n_holes, length=args.cutout_size),
transforms.ToPILImage(),
transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

#transform_aug = transform_ori
#transform_train_aug.transforms.append()

#if args.cutout:
# transform_aug.transforms.append(Cutout(n_holes=args.n_holes, length=args.cutout_size))

#if args.autoaugment:
# transform_aug.transforms.append(CIFAR10Policy())

transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.49139968, 0.48215841, 0.44653091), (0.24703223, 0.24348513, 0.26158784)),
#transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
#transforms.Normalize((0.49139968, 0.48215841, 0.44653091), (0.24703223, 0.24348513, 0.26158784)),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_train = TransformTwice(transform_ori, transform_aug)
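
Note: TransformTwice itself is not part of this diff (only its __call__ signature shows up in the hunk headers). As context, a minimal sketch of what such a two-view wrapper typically looks like, with all names assumed from the usage above, is:

class TransformTwice:
    """Apply two transform pipelines to the same input and return the
    (original-view, augmented-view) pair consumed by the training loop."""
    def __init__(self, transform_ori, transform_aug):
        self.transform_ori = transform_ori
        self.transform_aug = transform_aug

    def __call__(self, inp):
        # one lightly processed view and one strongly augmented view per sample
        return self.transform_ori(inp), self.transform_aug(inp)
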
@@ -124,22 +120,19 @@ def __call__(self, inp):
labels = np.array([trainset[i][1] for i in train_ids], dtype=np.int64)
for i in range(10):
mask[np.where(labels == i)[0][: int(4000 / 10)]] = True
# labeled_indices, unlabeled_indices = indices[mask], indices[~ mask]
labeled_indices, unlabeled_indices = train_ids[mask], train_ids
labeled_indices, unlabeled_indices = train_ids[mask], train_ids[~ mask]
#labeled_indices, unlabeled_indices = train_ids[mask], train_ids

train_sampler_lab = data.sampler.SubsetRandomSampler(labeled_indices)
train_sampler_unlab = data.sampler.SubsetRandomSampler(unlabeled_indices)

trainloader_lab = data.DataLoader(trainset, batch_size=args.batch_size_lab, sampler=train_sampler_lab, num_workers=8, drop_last=True)
trainloader_unlab = data.DataLoader(trainset, batch_size=args.batch_size_unlab, sampler=train_sampler_unlab, num_workers=8, pin_memory=True)

trainloader_val = data.DataLoader(labelset, batch_size=100, sampler=train_sampler_lab, num_workers=8, drop_last=False)
#trainloader_val = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=False, num_workers=2)
trainloader_lab = data.DataLoader(trainset, batch_size=args.batch_size_lab, sampler=train_sampler_lab, num_workers=16, drop_last=True)
trainloader_unlab = data.DataLoader(trainset, batch_size=args.batch_size_unlab, sampler=train_sampler_unlab, num_workers=16, pin_memory=True)

#testloader = data.DataLoader(trainset, batch_size=args.batch_size, sampler=test_sampler, num_workers=3, pin_memory=True)
trainloader_val = data.DataLoader(labelset, batch_size=100, sampler=train_sampler_lab, num_workers=16, drop_last=False)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=16)

#trainloader_lab_iter = iter(trainloader_lab)
#trainloader_unlab_iter = iter(trainloader_unlab)
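
The masking logic above keeps the first 4000/10 = 400 images of each class as the labeled subset, and the changed line switches the unlabeled pool from all of train_ids to its complement train_ids[~mask]. A small self-contained sketch of the same split, with synthetic labels standing in for the CIFAR-10 targets, is:

import numpy as np

labels = np.random.randint(0, 10, size=50000)      # stand-in for the CIFAR-10 training targets
train_ids = np.arange(len(labels))
mask = np.zeros(len(labels), dtype=bool)
for c in range(10):
    mask[np.where(labels == c)[0][:400]] = True     # 400 labeled images per class, 4000 in total
labeled_indices, unlabeled_indices = train_ids[mask], train_ids[~mask]
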
@@ -150,6 +143,15 @@ def __call__(self, inp):
print('==> Building model..')
net = wrn().cuda()

'''
def weights_init(m):
if isinstance(m, nn.Conv2d):
#torch.nn.init.xavier_uniform_(m.weight.data)
#weights_shape = []
print ('m size: ', m.size())
n = m.size(0)*m.size(1)*m.size(3)
torch.nn.init.normal_(m.weight.data, mean=0.0, std=np.sqrt(2.0/n))
'''
if device == 'cuda':
net = torch.nn.DataParallel(net)
cudnn.benchmark = True
@@ -168,6 +170,17 @@ def __call__(self, inp):
#optimizer = optim.Adam(net.parameters(), lr=args.lr, betas= (0.9, 0.999))
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_steps, eta_min=0.0001)


def global_contrast_normalize(X, scale=55., min_divisor=1e-8):
X = X.view(X.size(0), -1)
X = X - X.mean(dim=1, keepdim=True)

normalizers = torch.sqrt( torch.pow( X, 2).sum(dim=1, keepdim=True)) / scale
normalizers[normalizers < min_divisor] = 1.
X /= normalizers

return X.view(X.size(0),3,32,32)

def set_optimizer_lr(optimizer, lr):
for param_group in optimizer.param_groups:
param_group['lr'] = lr
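
The new global_contrast_normalize subtracts each image's mean and rescales the flattened pixel vector to an L2 norm of roughly scale (55 by default), guarding against near-zero norms with min_divisor. A quick sanity check of that behaviour, using a dummy batch shaped like the CIFAR-10 tensors here, might be:

import torch

x = torch.randn(8, 3, 32, 32)              # dummy batch
x_gcn = global_contrast_normalize(x)       # same shape, per-image contrast normalized
print(x_gcn.view(8, -1).norm(dim=1))       # each norm should come out close to 55
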
@@ -184,16 +197,16 @@ def train(epoch, trainloader_lab, trainloader_unlab, scheduler, optimizer):

trainloader_lab_iter = iter(trainloader_lab)
trainloader_unlab_iter = iter(trainloader_unlab)


#net.apply(weights_init)

for i_iter in range(args.num_steps):
net.train()
scheduler.step()
optimizer.zero_grad()

if args.lr_warm_up:
if i_iter < 10000:
warmup_lr = i_iter/10000* args.lr
if i_iter < args.warm_up_steps:
warmup_lr = i_iter/args.warm_up_steps* args.lr
optimizer = set_optimizer_lr(optimizer, warmup_lr)

if i_iter%1000==0:
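
With --lr-warm-up set, the block above now ramps the learning rate linearly from 0 to args.lr over the first args.warm-up-steps iterations (20000 by default after this commit) instead of the previously hard-coded 10000, after which the cosine-annealing schedule takes over. A hypothetical helper illustrating the same formula:

def warmup_lr_at(step, base_lr=0.03, warm_up_steps=20000):
    # linear ramp applied while step < warm_up_steps; not part of the commit, illustration only
    return (step / warm_up_steps) * base_lr
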
@@ -206,8 +219,11 @@ def train(epoch, trainloader_lab, trainloader_unlab, scheduler, optimizer):
trainloader_lab_iter = iter(trainloader_lab)
batch_lab = next(trainloader_lab_iter)

(_, inputs_lab), targets_lab = batch_lab
(inputs_lab, _), targets_lab = batch_lab
inputs_lab, targets_lab = inputs_lab.to(device), targets_lab.to(device)

if args.GCN:
inputs_lab = global_contrast_normalize( inputs_lab )

outputs_lab = net(inputs_lab)
loss_lab = criterion(outputs_lab, targets_lab)
@@ -221,13 +237,21 @@ def train(epoch, trainloader_lab, trainloader_unlab, scheduler, optimizer):

(inputs_unlab, inputs_unlab_aug), _ = batch_unlab
inputs_unlab, inputs_unlab_aug = inputs_unlab.cuda(), inputs_unlab_aug.cuda()

if args.GCN:
inputs_unlab = global_contrast_normalize( inputs_unlab )
inputs_unlab_aug = global_contrast_normalize( inputs_unlab_aug )

outputs_unlab = net(inputs_unlab)
outputs_unlab_aug = net(inputs_unlab_aug)
#print (targets)
#loss_unlab = nn.KLDivLoss()(F.log_softmax(outputs_unlab), F.softmax(outputs_unlab_aug))
loss_unlab = nn.KLDivLoss()(F.log_softmax(outputs_unlab_aug, dim=1), F.softmax(outputs_unlab, dim=1))

#print ('outputs unlab size(): ', outputs_unlab.size())
#loss_unlab = nn.KLDivLoss()(F.log_softmax(outputs_unlab_aug, dim=1), F.softmax(outputs_unlab, dim=1).detach(), reduction='batchmean')

loss_unlab = torch.nn.functional.kl_div(
torch.nn.functional.log_softmax(outputs_unlab_aug, dim=1),
torch.nn.functional.softmax(outputs_unlab, dim=1).detach(), reduction='batchmean')

#loss_kldiv = F.kl_div(F.log_softmax(outputs_unlab_aug, dim=1), F.softmax(outputs_unlab, dim=1), reduction='none') # loss for unsupervised
#loss_kldiv = torch.sum(loss_kldiv, dim=1)
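
The rewritten unlabeled-data term is a consistency loss: the KL divergence between the prediction on the strongly augmented view and the detached prediction on the original view, with batchmean reduction, so gradients flow only through the augmented branch. In isolation the same computation (dummy logits assumed) is:

import torch
import torch.nn.functional as F

logits_ori = torch.randn(4, 10)              # predictions on the original views
logits_aug = torch.randn(4, 10)              # predictions on the augmented views

loss_unlab = F.kl_div(
    F.log_softmax(logits_aug, dim=1),
    F.softmax(logits_ori, dim=1).detach(),   # target distribution, stop-gradient
    reduction='batchmean')
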
@@ -238,6 +262,7 @@ def train(epoch, trainloader_lab, trainloader_unlab, scheduler, optimizer):

loss.backward()
optimizer.step()
scheduler.step()

train_loss += loss.item()
train_loss_lab += loss_lab.item()
Expand All @@ -246,7 +271,7 @@ def train(epoch, trainloader_lab, trainloader_unlab, scheduler, optimizer):
#progress_bar(i_iter, args.num_steps, 'Loss: %.6f | Loss_lab: %.6f'
#% (loss.item(), loss_lab.item()))
progress_bar(i_iter, args.num_steps, 'Loss: %.6f | Loss_lab: %.6f | Loss_unlab: %.6f'
% (loss.item(), loss_lab.item(), loss_unlab.item()))
% (train_loss/1000.0, train_loss_lab/1000.0, train_loss_unlab/1000.0))

if i_iter%1000==0:
train_loss /= 1000
@@ -268,7 +293,7 @@ def val():
correct = 0
total = 0
U_all = []
fp = open('results_with_val','a')
fp = open('results_with_val.txt','a')
with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(trainloader_val):
inputs, targets = inputs.to(device), targets.to(device)
@@ -297,10 +322,12 @@ def test(epoch, i_iter, loss, loss_lab, loss_unlab):
correct = 0
total = 0
U_all = []
fp = open('results_semi_w_cutout_in_AA.txt','a')
fp = open('results_semi_64_320_100k_w_flip_and_crop_20k_warm_up_sep_masks_wo_GCN_last_norm_again.txt','a')
with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(testloader):
inputs, targets = inputs.to(device), targets.to(device)
#if args.GCN:
#inputs = global_contrast_normalize( inputs )
outputs = net(inputs)

probs = F.softmax(outputs, dim=1)
2 changes: 2 additions & 0 deletions wrn.py
@@ -73,6 +73,8 @@ def __init__(self, depth=28, num_classes=10, widen_factor=2, dropRate=0.0):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
n = 1.0/(num_classes)**(0.5)
m.weight.data.uniform_(-1*n, n)
m.bias.data.zero_()

def forward(self, x):
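
The two added lines in wrn.py initialize the final Linear classifier's weights uniformly in [-1/sqrt(num_classes), 1/sqrt(num_classes)] while leaving the bias at zero. Taken out of the module, the same initialization (feature width assumed) looks like:

import torch.nn as nn

num_classes = 10
fc = nn.Linear(128, num_classes)             # width assumed for a WRN-28-2 feature head
bound = 1.0 / num_classes ** 0.5
fc.weight.data.uniform_(-bound, bound)
fc.bias.data.zero_()
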
