Skip to content

Commit

Permalink
Update Pretrain codes and add dataset.txt file.
Browse files Browse the repository at this point in the history
  • Loading branch information
xmlyqing00 committed Nov 5, 2020
1 parent 412e788 commit c1e54e3
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 8 deletions.
Binary file modified assets/pipeline.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
22 changes: 14 additions & 8 deletions dataset/PreTrain_DS.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

class PreTrain_DS(data.Dataset):

def __init__(self, root, output_size, dataset_file='dataset.txt', clip_n=3, max_obj_n=11):
def __init__(self, root, output_size, dataset_file='dataset/pretrain.txt', clip_n=3, max_obj_n=11):
self.root = root
self.clip_n = clip_n
self.output_size = output_size
Expand All @@ -21,23 +21,29 @@ def __init__(self, root, output_size, dataset_file='dataset.txt', clip_n=3, max_
self.img_list = list()
self.mask_list = list()

dataset_path = os.path.join(root, dataset_file)
dataset_list = list()
with open(os.path.join(dataset_path), 'r') as lines:
with open(os.path.join(dataset_file), 'r') as lines:
for line in lines:
dataset_name = line.strip()
dataset_list.append(dataset_name)

img_dir = os.path.join(root, 'JPEGImages', dataset_name)
mask_dir = os.path.join(root, 'Annotations', dataset_name)

img_list = sorted(glob(os.path.join(img_dir, '*.jpg'))) + sorted(glob(os.path.join(img_dir, '*.png')))
mask_list = sorted(glob(os.path.join(mask_dir, '*.png')))

assert len(img_list) == len(mask_list)

self.img_list += img_list
self.mask_list += mask_list
if len(img_list) > 0:
if len(img_list) == len(mask_list):
dataset_list.append(dataset_name)
self.img_list += img_list
self.mask_list += mask_list
else:
print(
f'PreTrain dataset {dataset_name} has {len(img_list)} imgs and {len(mask_list)} annots. Not match! Skip.')
else:
print(f'PreTrain dataset {dataset_name} doesn\'t exist. Skip.')

print(myutils.gct(), f'{len(self.img_list)} imgs are used for PreTrain. They are from {dataset_list}.')

self.random_horizontal_flip = mytrans.RandomHorizontalFlip(0.3)
self.color_jitter = TF.ColorJitter(0.1, 0.1, 0.1, 0.03)
Expand Down
5 changes: 5 additions & 0 deletions dataset/pretrain.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
COCO
ECSSD
MSRA10K
PASCAL-S
PASCALVOC2012

0 comments on commit c1e54e3

Please sign in to comment.