uda imagenet first working
sud0301 committed Sep 20, 2019
1 parent 89061fe commit 3eddb49
Showing 2 changed files with 31 additions and 24 deletions.
main.py: 2 changes (1 addition & 1 deletion)

@@ -39,7 +39,7 @@

 parse.add_argument('--num-labeled', default=1000, type=int, help='number of labeled_samples')
 
-parse.add_argument('--batch-size-lab', default=64, type=int, help='training batch size')
+parse.add_argument('--batch-size-lab', default=32, type=int, help='training batch size')
 parse.add_argument('--batch-size-unlab', default=320, type=int, help='training batch size')
 parse.add_argument('--num-steps', default=100000, type=int, help='number of iterations')
 parse.add_argument('--lr-warm-up', action='store_true', help='increase lr slowly')
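With this change each training step pairs a labeled batch of 32 with an unlabeled batch of 320 (the --batch-size-unlab default), a 1:10 ratio. A minimal sketch of how two loaders with different batch sizes can be consumed together, cycling the smaller labeled loader; the helper name and structure here are illustrative, not taken from the repo:

    import itertools

    # Hypothetical stand-ins for trainloader_lab / trainloader_unlab.
    # The labeled set is much smaller, so it is cycled while the step
    # counter drives the iteration, as in an iteration-based trainer.
    # Note: itertools.cycle replays the first epoch's batch order; real
    # code often re-creates the iterator each epoch to reshuffle.
    def paired_batches(loader_lab, loader_unlab, num_steps):
        lab_iter = itertools.cycle(loader_lab)      # batch size 32
        unlab_iter = iter(loader_unlab)             # batch size 320
        for step in range(num_steps):
            try:
                unlab_batch = next(unlab_iter)
            except StopIteration:                   # restart after each pass
                unlab_iter = iter(loader_unlab)
                unlab_batch = next(unlab_iter)
            yield next(lab_iter), unlab_batch

In the repo, trainloader_lab is built with drop_last=True, so every labeled batch has exactly batch_size_lab samples.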
main_imagenet.py: 53 changes (30 additions & 23 deletions)

@@ -203,29 +203,36 @@ def main():
     np.random.shuffle(train_ids)
     #pickle.dump(train_ids, open(os.path.join(args.checkpoint, 'train_id_' + str(args.seed) + '.pkl'), 'wb'))
 
-    if args.new_splits:
-        train_ids = train_ids[:4000]
-        mask = np.zeros(train_ids.shape[0], dtype=np.bool)
-        labels = np.array([trainset[i][1] for i in train_ids], dtype=np.int64)
-        num_labeled = int(args.percent_labeled*len(train_ids)/100.0)
-        print ('Num labeled: ', num_labeled)
-        for i in range(args.num_classes):
-            mask[np.where(labels == i)[0][: int(num_labeled / args.num_classes)]] = True
-        labeled_indices, unlabeled_indices = train_ids[mask], train_ids[~ mask]
-
-        pickle.dump(labeled_indices, open(os.path.join(args.checkpoint, 'labeled_idxs_' + str(args.percent_labeled) + '_' + str(args.seed) + '.pkl'), 'wb'))
-        pickle.dump(unlabeled_indices, open(os.path.join(args.checkpoint, 'unlabeled_idxs_' + str(args.percent_labeled) + '_' + str(args.seed) + '.pkl'), 'wb'))
-    else:
-        labeled_indices = pickle.load(open(os.path.join(args.checkpoint, 'labeled_idxs_' + str(args.percent_labeled) + '_' + str(args.seed) + '.pkl'), 'rb'))
-        unlabeled_indices = pickle.load(open(os.path.join(args.checkpoint, 'unlabeled_idxs_' + str(args.percent_labeled) + '_' + str(args.seed) + '.pkl'), 'rb'))
-
-    train_sampler_lab = data.sampler.SubsetRandomSampler(labeled_indices)
-    train_sampler_unlab = data.sampler.SubsetRandomSampler(unlabeled_indices)
-
-    trainloader_lab = data.DataLoader(trainset, batch_size=args.batch_size_lab, sampler=train_sampler_lab, num_workers=16, drop_last=True)
-    trainloader_unlab = data.DataLoader(trainset, batch_size=args.batch_size_unlab, sampler=train_sampler_unlab, num_workers=16, pin_memory=True)
+    if args.lab_only:
+        train_sampler_lab = data.sampler.SubsetRandomSampler(train_ids)
+        train_sampler_unlab = data.sampler.SubsetRandomSampler(train_ids)
+
+        trainloader_lab = data.DataLoader(trainset, batch_size=args.batch_size_lab, sampler=train_sampler_lab, num_workers=16, drop_last=True)
+        trainloader_unlab = data.DataLoader(trainset, batch_size=args.batch_size_unlab, sampler=train_sampler_unlab, num_workers=16, pin_memory=True)
+
+        print ('Labeled indices: ', len(labeled_indices), ' Unlabeled indices: ', len(unlabeled_indices))
+    else:
+        if args.new_splits:
+            mask = np.zeros(train_ids.shape[0], dtype=np.bool)
+            labels = np.array([trainset[i][1] for i in train_ids], dtype=np.int64)
+            num_labeled = int(args.percent_labeled*len(train_ids)/100.0)
+            print ('Num labeled: ', num_labeled)
+            for i in range(args.num_classes):
+                mask[np.where(labels == i)[0][: int(num_labeled / args.num_classes)]] = True
+            labeled_indices, unlabeled_indices = train_ids[mask], train_ids[~ mask]
+
+            pickle.dump(labeled_indices, open(os.path.join(args.checkpoint, 'labeled_idxs_' + str(args.percent_labeled) + '_' + str(args.seed) + '.pkl'), 'wb'))
+            pickle.dump(unlabeled_indices, open(os.path.join(args.checkpoint, 'unlabeled_idxs_' + str(args.percent_labeled) + '_' + str(args.seed) + '.pkl'), 'wb'))
+        else:
+            labeled_indices = pickle.load(open(os.path.join(args.checkpoint, 'labeled_idxs_' + str(args.percent_labeled) + '_' + str(args.seed) + '.pkl'), 'rb'))
+            unlabeled_indices = pickle.load(open(os.path.join(args.checkpoint, 'unlabeled_idxs_' + str(args.percent_labeled) + '_' + str(args.seed) + '.pkl'), 'rb'))
+
+        train_sampler_lab = data.sampler.SubsetRandomSampler(labeled_indices)
+        train_sampler_unlab = data.sampler.SubsetRandomSampler(unlabeled_indices)
+        print ('Labeled indices: ', len(labeled_indices), ' Unlabeled indices: ', len(unlabeled_indices))
+
+        trainloader_lab = data.DataLoader(trainset, batch_size=args.batch_size_lab, sampler=train_sampler_lab, num_workers=16, drop_last=True)
+        trainloader_unlab = data.DataLoader(trainset, batch_size=args.batch_size_unlab, sampler=train_sampler_unlab, num_workers=16, pin_memory=True)
 
     testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=16)
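For reference, the new_splits branch performs a stratified split: after shuffling, the first num_labeled/num_classes positions of each class are marked labeled and the rest stay unlabeled. A self-contained sketch of that selection with toy data (the sizes and the random label array are illustrative, not from the repo):

    import numpy as np

    # Toy stand-ins: 10 classes, 1000 shuffled training ids with known labels.
    num_classes = 10
    train_ids = np.random.permutation(1000)
    labels = np.random.randint(0, num_classes, size=1000)  # labels[i] goes with train_ids[i]
    percent_labeled = 10.0

    num_labeled = int(percent_labeled * len(train_ids) / 100.0)
    mask = np.zeros(train_ids.shape[0], dtype=bool)  # np.bool is deprecated; plain bool works
    for c in range(num_classes):
        # mark the first num_labeled/num_classes positions of each class
        mask[np.where(labels == c)[0][: num_labeled // num_classes]] = True

    labeled_indices, unlabeled_indices = train_ids[mask], train_ids[~mask]
    print(len(labeled_indices), len(unlabeled_indices))  # ~100 / ~900

Because the mask is positional, labeled_indices stays aligned with the labels computed from the same shuffled order, which is why the repo builds labels from trainset[i][1] over train_ids before masking.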

@@ -350,8 +357,8 @@ def train(trainloader_lab, trainloader_unlab, testloader, net, scheduler, optimi
         loss_unlab = torch.nn.functional.kl_div(
             torch.nn.functional.log_softmax(outputs_unlab_aug, dim=1),
             torch.nn.functional.softmax(outputs_unlab, dim=1).detach(), reduction='batchmean')
-        train_loss_unlab.update(loss_unlab.item())
 
+        train_loss_unlab.update(loss_unlab.item())
         loss = loss_lab + loss_unlab
         train_loss.update(loss.item())
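The term above is the UDA consistency loss: the KL divergence between the model's detached softmax prediction on the clean unlabeled batch and its log-softmax prediction on the augmented view, so gradients flow only through the augmented branch. A minimal sketch with random logits standing in for the two forward passes (tensor shapes are illustrative):

    import torch
    import torch.nn.functional as F

    # Stand-ins for two forward passes over the same unlabeled batch.
    outputs_unlab = torch.randn(320, 1000)      # logits on clean images
    outputs_unlab_aug = torch.randn(320, 1000)  # logits on augmented images

    # KL(p_clean || p_aug): the target distribution comes from the clean
    # view and is detached, so it acts as a fixed pseudo-target.
    loss_unlab = F.kl_div(
        F.log_softmax(outputs_unlab_aug, dim=1),
        F.softmax(outputs_unlab, dim=1).detach(),
        reduction='batchmean')
    print(loss_unlab.item())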

@@ -370,7 +377,7 @@ def train(trainloader_lab, trainloader_unlab, testloader, net, scheduler, optimi
         progress_bar(i_iter, args.num_steps, 'Loss: %.6f | Loss_lab: %.6f | Loss_unlab: %.6f'
             % (train_loss.avg, train_loss_lab.avg, train_loss_unlab.avg))
 
-        if i_iter%1000==0:
+        if i_iter%5000==0 and i_iter>0:
             test_loss, test_acc = test(net, testloader, criterion, optimizer, i_iter)
             logger.append([state['lr'], train_loss.avg, test_loss, test_acc])
