Skip to content

Commit

Permalink
replace RandomResizedCrop
Browse files Browse the repository at this point in the history
  • Loading branch information
jeonsworld committed Nov 4, 2020
1 parent 1e24870 commit b651319
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@


def get_loader(args):
img_preprocess = transforms.Compose([
transform_train = transforms.Compose([
transforms.RandomResizedCrop((args.img_size, args.img_size), scale=(0.05, 1.0)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
transform_test = transforms.Compose([
transforms.Resize((args.img_size, args.img_size)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
Expand All @@ -18,21 +23,21 @@ def get_loader(args):
trainset = datasets.CIFAR10(root="./data",
train=True,
download=True,
transform=img_preprocess)
transform=transform_train)
testset = datasets.CIFAR10(root="./data",
train=False,
download=True,
transform=img_preprocess) if args.local_rank in [-1, 0] else None
transform=transform_test) if args.local_rank in [-1, 0] else None

else:
trainset = datasets.CIFAR100(root="./data",
train=True,
download=True,
transform=img_preprocess)
transform=transform_train)
testset = datasets.CIFAR100(root="./data",
train=False,
download=True,
transform=img_preprocess) if args.local_rank in [-1, 0] else None
transform=transform_test) if args.local_rank in [-1, 0] else None

train_sampler = RandomSampler(trainset) if args.local_rank == -1 else DistributedSampler(trainset)
test_sampler = SequentialSampler(testset)
Expand Down

0 comments on commit b651319

Please sign in to comment.