Commit
cifar5 test added
sud0301 committed Sep 14, 2019
1 parent 58537f8 commit c97a96e
Showing 8 changed files with 879 additions and 35 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -2,6 +2,8 @@ checkpoint/
 data/
 results/
 models/
+scripts/
+
 results*.txt
 *.pkl
 __pycache__
1 change: 1 addition & 0 deletions augment/__init__.py
@@ -0,0 +1 @@
+
File renamed without changes.
File renamed without changes.
94 changes: 59 additions & 35 deletions main.py
@@ -17,48 +17,53 @@
 import numpy as np
 
 #from models.senet import *
-from utils import progress_bar
-from cutout import Cutout
+#from utils import progress_bar
+from augment.cutout import Cutout
 
 #from resnet import resnet18
 from wrn import wrn
 #from autoaugment_extra_only_color import CIFAR10Policy
-from autoaugment_extra import CIFAR10Policy
+from augment.autoaugment_extra import CIFAR10Policy
 
+DATASET = 'CIFAR10'
 SEED = 0
 SPLIT_ID = None
 
-parser = argparse.ArgumentParser(description='PyTorch SSL CIFAR10 UDA Training')
-parser.add_argument('--lr', default=0.03, type=float, help='learning rate')
-parser.add_argument('--softmax-temp', default=-1, type=float, help='softmax temperature controlling')
-parser.add_argument('--confidence-mask', default=-1, type=float, help='Confidence value for masking')
+parse = argparse.ArgumentParser(description='PyTorch SSL CIFAR10 UDA Training')
+parse.add_argument('--dataset', type=str, default=DATASET, help='dataset')
+parse.add_argument('--num-classes', default=10, type=int, help='number of classes')
 
-parser.add_argument('--num-labeled', default=1000, type=int, help='number of labeled_samples')
+parse.add_argument('--lr', default=0.03, type=float, help='learning rate')
+parse.add_argument('--softmax-temp', default=-1, type=float, help='softmax temperature controlling')
+parse.add_argument('--confidence-mask', default=-1, type=float, help='Confidence value for masking')
 
-parser.add_argument('--batch-size-lab', default=64, type=int, help='training batch size')
-parser.add_argument('--batch-size-unlab', default=320, type=int, help='training batch size')
-parser.add_argument('--num-steps', default=100000, type=int, help='number of iterations')
-parser.add_argument('--lr-warm-up', action='store_true', help='increase lr slowly')
-parser.add_argument('--warm-up-steps', default=20000, type=int, help='number of iterations for warmup')
-parser.add_argument('--num-cycles', default=4, type=int, help='number of sgdr cycles')
+parse.add_argument('--num-labeled', default=1000, type=int, help='number of labeled_samples')
 
-parser.add_argument("--split-id", type=str, default=SPLIT_ID, help="restore partial id list")
-parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
-parser.add_argument('--verbose', action='store_true', help='show progress bar')
-parser.add_argument('--seed', default=SEED, type=int, help='seed index')
+parse.add_argument('--batch-size-lab', default=64, type=int, help='training batch size')
+parse.add_argument('--batch-size-unlab', default=320, type=int, help='training batch size')
+parse.add_argument('--num-steps', default=100000, type=int, help='number of iterations')
+parse.add_argument('--lr-warm-up', action='store_true', help='increase lr slowly')
+parse.add_argument('--warm-up-steps', default=20000, type=int, help='number of iterations for warmup')
+parse.add_argument('--num-cycles', default=1, type=int, help='number of sgdr cycles')
+
+parse.add_argument('--split-id', type=str, default=SPLIT_ID, help='restore partial id list')
+parse.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
+parse.add_argument('--verbose', action='store_true', help='show progress bar')
+parse.add_argument('--seed', default=SEED, type=int, help='seed index')
 
 # Supervised or Semi-supervised
-parser.add_argument('--lab-only', action='store_true', help='if using only labeled samples')
+parse.add_argument('--lab-only', action='store_true', help='if using only labeled samples')
 
 # Augmenatations
-parser.add_argument('--cutout', action='store_true', help='use cutout augmentation')
-parser.add_argument('--n-holes', default=1, type=float, help='number of holes for cutout')
-parser.add_argument('--cutout-size', default=16, type=float, help='size of the cutout window')
-parser.add_argument('--autoaugment', action='store_true', help='use autoaugment augmentation')
+parse.add_argument('--cutout', action='store_true', help='use cutout augmentation')
+parse.add_argument('--n-holes', default=1, type=float, help='number of holes for cutout')
+parse.add_argument('--cutout-size', default=16, type=float, help='size of the cutout window')
+parse.add_argument('--autoaugment', action='store_true', help='use autoaugment augmentation')
 
-args = parser.parse_args()
+args = parse.parse_args()
 
-CHECKPOINT_DIR = './results/labels_' + str(args.num_labeled) + '_batch_lab_' + str(args.batch_size_lab) + '_batch_unlab_' + str(args.batch_size_unlab) + '_steps_' + str(args.num_steps) +'_warmup_' + str(args.warm_up_steps) + '_softmax_temp_' + str(args.softmax_temp) + '_conf_mask_' + str(args.confidence_mask) + '_SEED_' + str(args.seed)
+#CHECKPOINT_DIR = './results/dataset_' + str(args.dataset) + '_labels_' + str(args.num_labeled) + '_batch_lab_' + str(args.batch_size_lab) + '_batch_unlab_' + str(args.batch_size_unlab) + '_steps_' + str(args.num_steps) +'_warmup_' + str(args.warm_up_steps) + '_softmax_temp_' + str(args.softmax_temp) + '_conf_mask_' + str(args.confidence_mask) + '_SEED_' + str(args.seed)
+CHECKPOINT_DIR = './results/dataset_' + str(args.dataset) + '_labels_' + str(args.num_labeled) + '_batch_lab_' + str(args.batch_size_lab) + '_batch_unlab_' + str(args.batch_size_unlab) + '_steps_' + str(args.num_steps) +'_warmup_' + str(args.warm_up_steps) + '_softmax_temp_' + str(args.softmax_temp) + '_conf_mask_' + str(args.confidence_mask) + '_SEED_' + str(args.seed)
 
 if not os.path.exists(CHECKPOINT_DIR):
     os.makedirs(CHECKPOINT_DIR)
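
The hunk above renames the parser object to `parse` and adds the new `--dataset` and `--num-classes` flags, which feed straight into the checkpoint path. A minimal sketch of that flow (only a subset of the real flags is reproduced, and `ckpt` with its truncated path pattern is illustration, not the commit's code):

import argparse

# Sketch: only the flags relevant to the new checkpoint naming.
parse = argparse.ArgumentParser(description='PyTorch SSL CIFAR10 UDA Training')
parse.add_argument('--dataset', type=str, default='CIFAR10', help='dataset')
parse.add_argument('--num-classes', default=10, type=int, help='number of classes')
parse.add_argument('--num-labeled', default=1000, type=int, help='number of labeled_samples')
parse.add_argument('--seed', default=0, type=int, help='seed index')
args = parse.parse_args(['--dataset', 'CIFAR100', '--num-classes', '100'])

# Truncated version of the CHECKPOINT_DIR pattern above.
ckpt = './results/dataset_' + str(args.dataset) + '_labels_' + str(args.num_labeled) + '_SEED_' + str(args.seed)
print(ckpt)  # ./results/dataset_CIFAR100_labels_1000_SEED_0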
@@ -105,13 +110,19 @@ def __call__(self, inp):
 ])
 
 transform_train = TransformTwice(transform_ori, transform_aug)
-
-trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
-labelset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_test)
+if args.dataset == 'CIFAR10':
+    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
+    labelset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_test)
+    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
+if args.dataset == 'CIFAR100':
+    trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
+    labelset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_test)
+    testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
 #trainset_aug = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
 #trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=8)
 
 train_dataset_size = len(trainset)
+test_dataset_size = len(testset)
 
 #if args.partial_id is not None:
 #train_ids = pickle.load(open(args.split_id, 'rb'))
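
The two bare `if` statements above leave `trainset` undefined for any `--dataset` value other than CIFAR10 or CIFAR100. One table-driven alternative (a sketch, not part of this commit; `build_sets` is a hypothetical helper):

import torchvision

# Map flag values to dataset constructors; raises KeyError on unknown names
# instead of failing later with a NameError.
DATASETS = {'CIFAR10': torchvision.datasets.CIFAR10,
            'CIFAR100': torchvision.datasets.CIFAR100}

def build_sets(name, transform_train, transform_test, root='./data'):
    ds = DATASETS[name]
    trainset = ds(root=root, train=True, download=True, transform=transform_train)
    labelset = ds(root=root, train=True, download=True, transform=transform_test)
    testset = ds(root=root, train=False, download=True, transform=transform_test)
    return trainset, labelset, testset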
@@ -127,32 +138,45 @@ def __call__(self, inp):
     print('loading train ids from {}'.format(args.split_id))
 else:
     train_ids = np.arange(train_dataset_size)
+    test_ids = np.arange(test_dataset_size)
     np.random.shuffle(train_ids)
     pickle.dump(train_ids, open(os.path.join(CHECKPOINT_DIR, 'train_id_' + str(args.seed) + '.pkl'), 'wb'))
 
 mask = np.zeros(train_ids.shape[0], dtype=np.bool)
 labels = np.array([trainset[i][1] for i in train_ids], dtype=np.int64)
-for i in range(10):
-    mask[np.where(labels == i)[0][: int(args.num_labeled / 10)]] = True
+'''
+for i in range(args.num_classes):
+    mask[np.where(labels == i)[0][: int(args.num_labeled / args.num_classes)]] = True
 labeled_indices, unlabeled_indices = train_ids[mask], train_ids[~ mask]
 #labeled_indices, unlabeled_indices = train_ids[mask], train_ids
+'''
+mask_ = np.zeros(train_ids.shape[0], dtype=np.bool)
+mask_test = np.zeros(test_ids.shape[0], dtype=np.bool)
+labels_test = np.array([testset[i][1] for i in test_ids], dtype=np.int64)
+for i in range(10):
+    mask[np.where(labels == i)[0][: int(args.num_labeled / args.num_classes)]] = True
+    mask_[np.where(labels == i)[0][int(args.num_labeled / args.num_classes): ]] = True
+    mask_test[np.where(labels_test == i)[0]] = True
+labeled_indices, unlabeled_indices = train_ids[mask], train_ids[mask_]
+test_indices = test_ids[mask_test]
 
 train_sampler_lab = data.sampler.SubsetRandomSampler(labeled_indices)
 train_sampler_unlab = data.sampler.SubsetRandomSampler(unlabeled_indices)
+test_sampler = data.sampler.SubsetRandomSampler(test_indices)
 
 trainloader_lab = data.DataLoader(trainset, batch_size=args.batch_size_lab, sampler=train_sampler_lab, num_workers=16, drop_last=True)
 trainloader_unlab = data.DataLoader(trainset, batch_size=args.batch_size_unlab, sampler=train_sampler_unlab, num_workers=16, pin_memory=True)
 
 trainloader_val = data.DataLoader(labelset, batch_size=100, sampler=train_sampler_lab, num_workers=16, drop_last=False)
 
-testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
-testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=16)
+testloader = data.DataLoader(testset, batch_size=100, sampler=test_sampler, num_workers=16)
+#testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=16)
 
 classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
 
 # Model
 print('==> Building model..')
-net = wrn().cuda()
+net = wrn(num_classes=args.num_classes).cuda()
 
 if device == 'cuda':
     net = torch.nn.DataParallel(net)
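
The new split code builds three boolean masks: `mask` marks the first `num_labeled / num_classes` shuffled examples of each class as labeled, `mask_` marks the remainder of each class as unlabeled, and `mask_test` covers every test example. Note that the active loop still iterates `range(10)` while dividing by `args.num_classes`, so only the ten-class path is self-consistent as committed. A standalone sketch of the same split on synthetic labels (the sizes and RNG here are illustrative stand-ins):

import numpy as np

num_classes, num_labeled, train_size = 10, 1000, 50000
rng = np.random.RandomState(0)
train_ids = rng.permutation(train_size)                # shuffled dataset indices
labels = rng.randint(0, num_classes, size=train_size)  # stand-in for trainset labels

mask = np.zeros(train_size, dtype=bool)    # labeled picks
mask_ = np.zeros(train_size, dtype=bool)   # unlabeled remainder
per_class = num_labeled // num_classes
for i in range(num_classes):
    idx = np.where(labels == i)[0]
    mask[idx[:per_class]] = True           # first per_class positions of class i
    mask_[idx[per_class:]] = True          # everything else of class i

labeled_indices, unlabeled_indices = train_ids[mask], train_ids[mask_]
assert len(labeled_indices) == num_labeled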
@@ -251,7 +275,7 @@ def train(cycle, trainloader_lab, trainloader_unlab, scheduler, optimizer):
             largest_prob, _ = unlab_prob.max(1)
             mask = (largest_prob>args.confidence_mask).float().detach()
             loss_unlab = loss_unlab*mask
-
+
             loss_unlab = torch.mean(loss_unlab)
 
         else:
@@ -290,7 +314,7 @@ def train(cycle, trainloader_lab, trainloader_unlab, scheduler, optimizer):
         train_loss_unlab /= 1000
 
         test(cycle, i_iter, train_loss, train_loss_lab, train_loss_unlab)
-        val()
+        #val()
 
         train_loss = 0
         train_loss_lab = 0
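
The last two hunks adjust the confidence-masking branch of the unlabeled loss and comment out the per-checkpoint `val()` call. For reference, a self-contained sketch of that masking step, assuming `unlab_prob` holds softmax outputs and `loss_unlab` a per-example consistency loss (the tensors below are random stand-ins):

import torch

confidence_mask = 0.8                                    # stand-in for args.confidence_mask
unlab_prob = torch.softmax(torch.randn(320, 10), dim=1)  # fake batch of predictions
loss_unlab = torch.rand(320)                             # fake per-example unlabeled loss

# Zero out examples whose top predicted probability is below the threshold.
largest_prob, _ = unlab_prob.max(1)
mask = (largest_prob > confidence_mask).float().detach()
loss_unlab = loss_unlab * mask

# Mean over the full batch, so masked-out examples still count in the denominator.
loss_unlab = torch.mean(loss_unlab)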
(Diffs for the remaining three changed files were not loaded.)
