Applying SUR for ViT
manogna-s committed Jun 15, 2021
1 parent 101d92b commit dcdbf42
Showing 14 changed files with 1,088 additions and 367 deletions.
293 changes: 293 additions & 0 deletions doc/README.md

Large diffs are not rendered by default.

File renamed without changes.
255 changes: 255 additions & 0 deletions doc/dataset_conversion.md
@@ -0,0 +1,255 @@
# Dataset download and conversion

This file contains instructions to download the individual datasets used by
Meta-Dataset, and convert them into a common format (one TFRecord file per
class). See [an overview](../README.md#downloading-and-converting-datasets) for
more context.
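The commands below assume three environment variables pointing at your data locations; a minimal setup sketch (the paths are placeholders, adjust them to your machine):

```bash
export DATASRC=/path/to/raw_data   # downloaded archives live here
export SPLITS=/path/to/splits      # class splits are written here
export RECORDS=/path/to/records    # converted TFRecord files are written here
mkdir -p $DATASRC $SPLITS $RECORDS
```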

## ilsvrc_2012

1. Download `ilsvrc2012_img_train.tar` from the
   [ILSVRC2012 website](http://www.image-net.org/challenges/LSVRC/2012/index)
2. Extract it into `ILSVRC2012_img_train/`, which should contain 1000 files,
named `n????????.tar` (expected time: \~30 minutes)
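    For instance, assuming the archive was downloaded to `$DATASRC`:

```bash
cd $DATASRC
mkdir ILSVRC2012_img_train
tar xf ilsvrc2012_img_train.tar -C ILSVRC2012_img_train
```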
3. Extract each of `ILSVRC2012_img_train/n????????.tar` in its own directory
(expected time: \~30 minutes), for instance:

```bash
for FILE in *.tar;
do
mkdir ${FILE/.tar/};
cd ${FILE/.tar/};
tar xvf ../$FILE;
cd ..;
done
```

4. Download the following two files into `ILSVRC2012_img_train/`:

- http://www.image-net.org/archive/wordnet.is_a.txt
- http://www.image-net.org/archive/words.txt
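
    For instance, assuming `wget` is available:

```bash
cd $DATASRC/ILSVRC2012_img_train
wget http://www.image-net.org/archive/wordnet.is_a.txt
wget http://www.image-net.org/archive/words.txt
```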

5. Launch the conversion script (use `--dataset=ilsvrc_2012_v2` for the
   training-only MetaDataset-v2 version):

```bash
python -m meta_dataset.dataset_conversion.convert_datasets_to_records \
--dataset=ilsvrc_2012 \
--ilsvrc_2012_data_root=$DATASRC/ILSVRC2012_img_train \
--splits_root=$SPLITS \
--records_root=$RECORDS
```

6. Expect the conversion to take 4 to 12 hours, depending on the filesystem's
latency and bandwidth.
7. Find the following outputs in `$RECORDS/ilsvrc_2012/`:
- 1000 tfrecords files named `[0-999].tfrecords`
- `dataset_spec.json` (see [note 1](#notes))
- `num_leaf_images.json`
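
A quick sanity check on the converted output (a sketch; the same check applies to each dataset below, with its respective expected file count):

```bash
ls $RECORDS/ilsvrc_2012/*.tfrecords | wc -l   # expect 1000
```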
## omniglot
1. Download
[`images_background.zip`](https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip)
and
[`images_evaluation.zip`](https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip)
2. Extract them into the same `omniglot/` directory
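   For instance, assuming both zips are in `$DATASRC`:
```bash
unzip $DATASRC/images_background.zip -d $DATASRC/omniglot
unzip $DATASRC/images_evaluation.zip -d $DATASRC/omniglot
```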
3. Launch the conversion script:
```bash
python -m meta_dataset.dataset_conversion.convert_datasets_to_records \
--dataset=omniglot \
--omniglot_data_root=$DATASRC/omniglot \
--splits_root=$SPLITS \
--records_root=$RECORDS
```
4. Expect the conversion to take a few seconds.
5. Find the following outputs in `$RECORDS/omniglot/`:
- 1623 tfrecords files named `[0-1622].tfrecords`
- `dataset_spec.json` (see [note 1](#notes))
## aircraft
1. Download
[`fgvc-aircraft-2013b.tar.gz`](http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz)
2. Extract it into `fgvc-aircraft-2013b`
3. Launch the conversion script:
```bash
python -m meta_dataset.dataset_conversion.convert_datasets_to_records \
--dataset=aircraft \
--aircraft_data_root=$DATASRC/fgvc-aircraft-2013b \
--splits_root=$SPLITS \
--records_root=$RECORDS
```
4. Expect the conversion to take 5 to 10 minutes.
5. Find the following outputs in `$RECORDS/aircraft/`:
- 100 tfrecords files named `[0-99].tfrecords`
- `dataset_spec.json` (see [note 1](#notes))
## cu_birds
1. Download
[`CUB_200_2011.tgz`](http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz)
2. Extract it; this creates `CUB_200_2011/` (and an `attributes.txt` file)
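   For instance, assuming the archive was downloaded to `$DATASRC`:
```bash
cd $DATASRC
tar xzf CUB_200_2011.tgz   # creates CUB_200_2011/ and attributes.txt
```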
3. Launch the conversion script:
```bash
python -m meta_dataset.dataset_conversion.convert_datasets_to_records \
--dataset=cu_birds \
--cu_birds_data_root=$DATASRC/CUB_200_2011 \
--splits_root=$SPLITS \
--records_root=$RECORDS
```
4. Expect the conversion to take around one minute.
5. Find the following outputs in `$RECORDS/cu_birds/`:
- 200 tfrecords files named `[0-199].tfrecords`
- `dataset_spec.json` (see [note 1](#notes))
## dtd
1. Download
[`dtd-r1.0.1.tar.gz`](https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz)
2. Extract it into `dtd/`
3. Launch the conversion script:
```bash
python -m meta_dataset.dataset_conversion.convert_datasets_to_records \
--dataset=dtd \
--dtd_data_root=$DATASRC/dtd \
--splits_root=$SPLITS \
--records_root=$RECORDS
```
4. Expect the conversion to take a few seconds.
5. Find the following outputs in `$RECORDS/dtd/`:
- 47 tfrecords files named `[0-46].tfrecords`
- `dataset_spec.json` (see [note 1](#notes))
## quickdraw
1. Download all 345 `.npy` files hosted on
[Google Cloud](https://console.cloud.google.com/storage/quickdraw_dataset/full/numpy_bitmap)
- You can use
[`gsutil`](https://cloud.google.com/storage/docs/gsutil_install#install)
to download them to `quickdraw/`:
```bash
gsutil -m cp gs://quickdraw_dataset/full/numpy_bitmap/*.npy $DATASRC/quickdraw
```
2. Launch the conversion script:
```bash
python -m meta_dataset.dataset_conversion.convert_datasets_to_records \
--dataset=quickdraw \
--quickdraw_data_root=$DATASRC/quickdraw \
--splits_root=$SPLITS \
--records_root=$RECORDS
```
3. Expect the conversion to take 3 to 4 hours.
4. Find the following outputs in `$RECORDS/quickdraw/`:
- 345 tfrecords files named `[0-344].tfrecords`
- `dataset_spec.json` (see [note 1](#notes))
## fungi
1. Download
[`fungi_train_val.tgz`](https://labs.gbif.org/fgvcx/2018/fungi_train_val.tgz)
and
[`train_val_annotations.tgz`](https://labs.gbif.org/fgvcx/2018/train_val_annotations.tgz)
2. Extract them into the same `fungi/` directory. It should contain one
`images/` directory, as well as `train.json` and `val.json`.
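   For instance, assuming both archives are in `$DATASRC/fungi/`:
```bash
cd $DATASRC/fungi
tar xzf fungi_train_val.tgz
tar xzf train_val_annotations.tgz
```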
3. Launch the conversion script:
```bash
python -m meta_dataset.dataset_conversion.convert_datasets_to_records \
--dataset=fungi \
--fungi_data_root=$DATASRC/fungi \
--splits_root=$SPLITS \
--records_root=$RECORDS
```
4. Expect the conversion to take 5 to 15 minutes.
5. Find the following outputs in `$RECORDS/fungi/`:
- 1394 tfrecords files named `[0-1393].tfrecords`
- `dataset_spec.json` (see [note 1](#notes))
## vgg_flower
1. Download
[`102flowers.tgz`](http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz)
and
[`imagelabels.mat`](http://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat)
2. Extract `102flowers.tgz`; it will create a `jpg/` sub-directory
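   For instance, assuming both files were downloaded to `$DATASRC/vgg_flower/`:
```bash
cd $DATASRC/vgg_flower
tar xzf 102flowers.tgz   # creates jpg/
```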
3. Launch the conversion script:
```bash
python -m meta_dataset.dataset_conversion.convert_datasets_to_records \
--dataset=vgg_flower \
--vgg_flower_data_root=$DATASRC/vgg_flower \
--splits_root=$SPLITS \
--records_root=$RECORDS
```
4. Expect the conversion to take about one minute.
5. Find the following outputs in `$RECORDS/vgg_flower/`:
- 102 tfrecords files named `[0-101].tfrecords`
- `dataset_spec.json` (see [note 1](#notes))
## traffic_sign
1. Download
   [`GTSRB_Final_Training_Images.zip`](https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Training_Images.zip).
   If the link happens to be broken, browse the GTSRB dataset [website](http://benchmark.ini.rub.de) for more information.
2. Extract it into `$DATASRC`; it will create a `GTSRB/` sub-directory
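   For instance, assuming the zip is in the current directory:
```bash
unzip GTSRB_Final_Training_Images.zip -d $DATASRC
```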
3. Launch the conversion script:
```bash
python -m meta_dataset.dataset_conversion.convert_datasets_to_records \
--dataset=traffic_sign \
--traffic_sign_data_root=$DATASRC/GTSRB \
--splits_root=$SPLITS \
--records_root=$RECORDS
```
4. Expect the conversion to take about one minute.
5. Find the following outputs in `$RECORDS/traffic_sign/`:
- 43 tfrecords files named `[0-42].tfrecords`
- `dataset_spec.json` (see [note 1](#notes))
## mscoco
1. Download the 2017 train images and annotations from http://cocodataset.org/:
- You can use
[`gsutil`](https://cloud.google.com/storage/docs/gsutil_install#install)
to download them to `mscoco/`:
```bash
cd $DATASRC/mscoco
mkdir train2017
gsutil -m rsync gs://images.cocodataset.org/train2017 train2017
gsutil -m cp gs://images.cocodataset.org/annotations/annotations_trainval2017.zip .
unzip annotations_trainval2017.zip
```
- Otherwise, you can download
[`train2017.zip`](http://images.cocodataset.org/zips/train2017.zip) and
[`annotations_trainval2017.zip`](http://images.cocodataset.org/annotations/annotations_trainval2017.zip)
and extract them into `mscoco/`.
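     For instance, a manual-download sketch (assuming `wget` and `unzip` are available):
```bash
cd $DATASRC/mscoco
wget http://images.cocodataset.org/zips/train2017.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
unzip train2017.zip
unzip annotations_trainval2017.zip
```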
2. Launch the conversion script:
```bash
python -m meta_dataset.dataset_conversion.convert_datasets_to_records \
--dataset=mscoco \
--mscoco_data_root=$DATASRC/mscoco \
--splits_root=$SPLITS \
--records_root=$RECORDS
```
3. Expect the conversion to take about 4 hours.
4. Find the following outputs in `$RECORDS/mscoco/`:
- 80 tfrecords files named `[0-79].tfrecords`
- `dataset_spec.json` (see [note 1](#notes))
## Notes
1. A [reference version](https://github.com/google-research/meta-dataset/tree/main/meta_dataset/dataset_conversion/dataset_specs)
of each of the `dataset_spec.json` files is part of this repository. You can
compare them with the version generated by the conversion process for
troubleshooting.
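
   For instance, a comparison sketch (the filename on the right is illustrative; check the `dataset_specs/` directory for the exact name):

```bash
# the reference filename below is illustrative and may differ
diff $RECORDS/omniglot/dataset_spec.json \
  meta_dataset/dataset_conversion/dataset_specs/omniglot_dataset_spec.json
```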
88 changes: 88 additions & 0 deletions eval_vit.py
@@ -0,0 +1,88 @@
import torch
import gin
import numpy as np
from tqdm import tqdm
from models.losses import prototype_loss
from cdfsl_dataset.meta_dataset_reader import MetaDatasetEpisodeReader
from train import get_model
import argparse
from tabulate import tabulate
import tensorflow as tf


@gin.configurable()
def eval_metadataset(args, img_size):
    args, model = get_model(args, training=False)
    config_file = f'cdfsl_dataset/configs/meta_dataset_{img_size}x{img_size}.gin'
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Feature extractors keyed by pretraining domain (currently unused below,
    # kept for parity with the SUR multi-extractor setup).
    extractors = dict()
    extractors['imagenet'] = model
    extractors['cifar'] = model

    # Space-separated domain lists.
    trainsets = "omniglot".split(' ')
    valsets = "omniglot".split(' ')
    testsets = "mnist".split(' ')
    print('train domains:', trainsets)
    print('test domains:', testsets)

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    tf.compat.v1.disable_eager_execution()
    # val_loader is built for parity with training; only the test split is used here.
    val_loader = MetaDatasetEpisodeReader('val', trainsets, valsets, testsets, config_file=config_file)
    test_loader = MetaDatasetEpisodeReader('test', trainsets, valsets, testsets, config_file=config_file)

    N_TASKS = 10
    accs_names = ['ViT']
    var_accs = dict()
    with tf.compat.v1.Session(config=config) as session:
        # go over each test domain
        for dataset in testsets:
            print(dataset)
            var_accs[dataset] = {name: [] for name in accs_names}
            for i in tqdm(range(N_TASKS)):
                with torch.no_grad():
                    sample = test_loader.get_test_task(session, dataset)
                    context_features = model.forward_features(sample['context_images'])
                    target_features = model.forward_features(sample['target_images'])
                    context_labels = sample['context_labels'].to(device)
                    target_labels = sample['target_labels'].to(device)
                    _, stats_dict_ViT, _ = prototype_loss(context_features, context_labels,
                                                          target_features, target_labels)
                    vit_acc = stats_dict_ViT['acc']
                    # print(f'Accuracy of ViT test task {i}: {vit_acc}')
                    var_accs[dataset]['ViT'].append(vit_acc)

    # Print results table (mean accuracy with a 95% confidence interval).
    rows = []
    for dataset_name in testsets:
        row = [dataset_name]
        for model_name in accs_names:
            acc = np.array(var_accs[dataset_name][model_name]) * 100
            mean_acc = acc.mean()
            conf = (1.96 * acc.std()) / np.sqrt(len(acc))
            row.append(f"{mean_acc:0.2f} +- {conf:0.2f}")
        rows.append(row)

    table = tabulate(rows, headers=['model \\ data'] + accs_names, floatfmt=".2f")
    print(table)
    print("\n\n")
    return


def main():
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--model_config", type=str, default="vit_configs/cifar_84.gin",
                        help="Where to search for pretrained ViT models.")
    args = parser.parse_args()
    gin.parse_config_file(args.model_config)
    # img_size is expected to be bound by the gin config.
    eval_metadataset(args)


if __name__ == '__main__':
    main()
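
A usage sketch for the new evaluation script, assuming the default config path exists in the repository:

```bash
python eval_vit.py --model_config vit_configs/cifar_84.gin
```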