Skip to content

Commit

Permalink
Update annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
xmlyqing00 committed Nov 8, 2020
1 parent 6daddab commit 8db9e4b
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 55 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ It is designed for semi-supervised video object segmentation (VOS) task.

![](assets/pipeline.png)

**Paper corrections:** Our feature map generated by the encoders has 1024 channels and 1/8 of the original image size.
**Paper corrections:** Our feature map generated by the encoders has 1024 channels and **1/16** of the original image size.


## 1. Requirements
Expand Down Expand Up @@ -71,7 +71,7 @@ By default, the segmentation results will be saved in `./output`.

### Pre-training on Static Images

1. Download the following the datasets. You don't have to all, COCO is the largest one.
1. Download the following datasets (COCO is the largest one). You don't have to download all of them; our pretraining code skips datasets that don't exist by default.
2. Run `unify_pretrain_dataset.py` to convert them into a uniform format (followed DAVIS).
```bash
python3 unify_pretrain_dataset.py --name NAME --src /path/to/dataset/dir/ --dst /path/to/output
Expand Down
6 changes: 3 additions & 3 deletions dataset/PreTrain_DS.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ def __init__(self, root, output_size, dataset_file='./assets/pretrain.txt', clip
dataset_list.append(dataset_name)
self.img_list += img_list
self.mask_list += mask_list
print(f'\t{dataset_name}: {len(img_list)} imgs.')
else:
print(
f'PreTrain dataset {dataset_name} has {len(img_list)} imgs and {len(mask_list)} annots. Not match! Skip.')
print(f'\tPreTrain dataset {dataset_name} has {len(img_list)} imgs and {len(mask_list)} annots. Not match! Skip.')
else:
print(f'PreTrain dataset {dataset_name} doesn\'t exist. Skip.')
print(f'\tPreTrain dataset {dataset_name} doesn\'t exist. Skip.')

print(myutils.gct(), f'{len(self.img_list)} imgs are used for PreTrain. They are from {dataset_list}.')

Expand Down
2 changes: 1 addition & 1 deletion model/AFB_URR.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def segment(self, frame, fb_global):

if self.training:
uncertainty = myutils.calc_uncertainty(NF.softmax(score, dim=1))
uncertainty = uncertainty.view(bs, -1).norm(p=2, dim=1) / math.sqrt(frame.shape[-2] * frame.shape[-1])
uncertainty = uncertainty.view(bs, -1).norm(p=2, dim=1) / math.sqrt(frame.shape[-2] * frame.shape[-1]) # [B,1,H,W]
uncertainty = uncertainty.mean()
else:
uncertainty = None
Expand Down
55 changes: 6 additions & 49 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,58 +48,12 @@ def get_args():
return parser.parse_args()


def run_pretrain(model, dataloader, criterion, optimizer):
    """Run one pre-training epoch on static-image pseudo-video clips.

    Args:
        model: AFB-URR model exposing ``memorize`` and ``segment``.
        dataloader: yields ``(frames, masks, obj_n, info)`` with batch size 1.
        criterion: loss comparing segmentation scores to integer labels.
        optimizer: optimizer updating ``model``'s parameters.

    Returns:
        float: average loss over all processed samples.
    """
    import os  # local import: file-level import block is outside this view

    stats = myutils.AvgMeter()

    progress_bar = tqdm(dataloader, desc='Pre Train')
    for iter_idx, sample in enumerate(progress_bar):
        frames, masks, obj_n, info = sample

        # A clip containing only the background object carries no training
        # signal for the matching loss — skip it.
        if obj_n.item() == 1:
            continue

        # Drop the dataloader batch dimension (batch size is fixed at 1).
        frames, masks = frames[0].to(device), masks[0].to(device)

        # Memorize the first frame/mask pair, then segment the remaining frames.
        k4, v4 = model.memorize(frames[0:1], masks[0:1])
        scores = model.segment(frames[1:], k4, v4)
        label = torch.argmax(masks[1:], dim=1).long()

        optimizer.zero_grad()
        loss = criterion(scores, label)
        loss.backward()
        optimizer.step()

        stats.update(loss.item())
        progress_bar.set_postfix(loss=('%.6f' % stats.avg))
        # NOTE: no explicit progress_bar.update(1) here — iterating a tqdm
        # object already advances it; calling update(1) as well double-counts.

        # Snapshot mid-epoch so a long pre-training run can be resumed.
        if iter_idx in (40000, 80000):
            checkpoint = {
                'epoch': iter_idx,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss': stats.avg,
                'seed': -1,
            }

            os.makedirs('tmp', exist_ok=True)  # don't crash if tmp/ is missing
            cp_path = f'tmp/cp_{iter_idx}.pth'
            torch.save(checkpoint, cp_path)
            print('Save to', cp_path)

    progress_bar.close()

    return stats.avg


def run_maintrain(model, dataloader, criterion, optimizer):
def train_model(model, dataloader, criterion, optimizer, desc):

stats = myutils.AvgMeter()
uncertainty_stats = myutils.AvgMeter()

progress_bar = tqdm(dataloader, desc='Main Train')
progress_bar = tqdm(dataloader, desc=desc)
for iter_idx, sample in enumerate(progress_bar):
frames, masks, obj_n, info = sample

Expand Down Expand Up @@ -140,10 +94,13 @@ def main():

if args.level == 0:
dataset = PreTrain_DS(args.dataset, output_size=400, clip_n=args.clip_n, max_obj_n=args.obj_n)
desc = 'Pre Train'
elif args.level == 1:
dataset = DAVIS_Train_DS(args.dataset, output_size=400, clip_n=args.clip_n, max_obj_n=args.obj_n)
desc = 'Train DAVIS17'
elif args.level == 2:
dataset = YouTube_Train_DS(args.dataset, output_size=400, clip_n=args.clip_n, max_obj_n=args.obj_n)
desc = 'Train YV18'
else:
raise ValueError(f'{args.level} is unknown.')

Expand Down Expand Up @@ -203,7 +160,7 @@ def main():
print('')
print(myutils.gct(), f'Epoch: {epoch} lr: {lr}')

loss = run_maintrain(model, dataloader, criterion, optimizer)
loss = train_model(model, dataloader, criterion, optimizer, desc)
if args.log:

checkpoint = {
Expand Down

0 comments on commit 8db9e4b

Please sign in to comment.