
Commit c1102a1

add hybrid model(R50-ViT-B_16).

jeonsworld committed Nov 27, 2020
1 parent 878ebc5 commit c1102a1
Showing 4 changed files with 64 additions and 16 deletions.
25 changes: 15 additions & 10 deletions README.md
@@ -12,17 +12,20 @@ Vision Transformer achieve State-of-the-Art in image recognition task with stand

## Usage
### 1. Download Pre-trained model (Google's Official Checkpoint)
-* [Available models](https://console.cloud.google.com/storage/vit_models/): ViT-B_16(**85.8M**), ViT-B_32(**87.5M**), ViT-L_16(**303.4M**), ViT-L_32(**305.5M**), ViT-H_14(**630.8M**)
+* [Available models](https://console.cloud.google.com/storage/vit_models/): ViT-B_16(**85.8M**), R50+ViT-B_16(**97.96M**), ViT-B_32(**87.5M**), ViT-L_16(**303.4M**), ViT-L_32(**305.5M**), ViT-H_14(**630.8M**)
* imagenet21k pre-train models
* ViT-B_16, ViT-B_32, ViT-L_16, ViT-L_32, ViT-H_14
* imagenet21k pre-train + imagenet2012 fine-tuned models
* ViT-B_16-224, ViT-B_16, ViT-B_32, ViT-L_16-224, ViT-L_16, ViT-L_32
+* Hybrid Model([Resnet50](https://github.com/google-research/big_transfer) + Transformer)
+  * R50-ViT-B_16
```
# imagenet21k pre-train
wget https://storage.googleapis.com/vit_models/imagenet21k/{MODEL_NAME}.npz
# imagenet21k pre-train + imagenet2012 fine-tuning
wget https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/{MODEL_NAME}.npz
```

### 2. Train Model
Expand All @@ -46,15 +49,17 @@ To verify that the converted model weight is correct, we simply compare it with
### imagenet-21k
* [**tensorboard**](https://tensorboard.dev/experiment/XvHOVNtMS02KOrmxOOJAEg/#scalars)

-| model | dataset | resolution | acc(official) | acc(this repo) | time |
-|:--------:|:---------:|:----------:|:-------------:|:--------------:|:-------:|
-| ViT-B_16 | CIFAR-10 | 224x224 | - | 0.9908 | 3h 13m |
-| ViT-B_16 | CIFAR-10 | 384x384 | 0.9903 | 0.9906 | 12h 25m |
-| ViT_B_16 | CIFAR-100 | 224x224 | - | 0.923 | 3h 9m |
-| ViT_B_16 | CIFAR-100 | 384x384 | 0.9264 | 0.9228 | 12h 31m |
-| ViT_L_32 | CIFAR-10 | 224x224 | - | 0.9903 | 2h 11m |
-| ViT_L_32 | CIFAR-100 | 224x224 | - | 0.9273 | 2h 9m |
-| ViT_H_14 | CIFAR-100 | 224x224 | - | WIP | |
+| model        | dataset   | resolution | acc(official) | acc(this repo) | time    |
+|:------------:|:---------:|:----------:|:-------------:|:--------------:|:-------:|
+| ViT-B_16     | CIFAR-10  | 224x224    | -             | 0.9908         | 3h 13m  |
+| ViT-B_16     | CIFAR-10  | 384x384    | 0.9903        | 0.9906         | 12h 25m |
+| ViT_B_16     | CIFAR-100 | 224x224    | -             | 0.923          | 3h 9m   |
+| ViT_B_16     | CIFAR-100 | 384x384    | 0.9264        | 0.9228         | 12h 31m |
+| R50-ViT-B_16 | CIFAR-10  | 384x384    | 0.99          | WIP            |         |
+| R50-ViT-B_16 | CIFAR-100 | 384x384    | 0.9231        | WIP            |         |
+| ViT_L_32     | CIFAR-10  | 224x224    | -             | 0.9903         | 2h 11m  |
+| ViT_L_32     | CIFAR-100 | 224x224    | -             | 0.9273         | 2h 9m   |
+| ViT_H_14     | CIFAR-100 | 224x224    | -             | WIP            |         |


### imagenet-21k + imagenet2012
11 changes: 11 additions & 0 deletions models/configs.py
@@ -47,6 +47,17 @@ def get_b16_config():
    return config


+def get_r50_b16_config():
+    """Returns the Resnet50 + ViT-B/16 configuration."""
+    config = get_b16_config()
+    del config.patches.size
+    config.patches.grid = (14, 14)
+    config.resnet = ml_collections.ConfigDict()
+    config.resnet.num_layers = (3, 4, 9)
+    config.resnet.width_factor = 1
+    return config
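A quick sanity check on what this grid implies, as a sketch that mirrors how `Embeddings` in models/modeling.py (below) consumes `config.patches.grid`, using illustrative values:

```python
# Sketch: how Embeddings consumes config.patches.grid (illustrative values).
img_size = (224, 224)
grid_size = (14, 14)  # config.patches.grid
# With the backbone's 16x total downsampling, the feature map is 14x14,
# so the patch-embedding conv degenerates to a 1x1 kernel over ResNet features.
patch_size = (img_size[0] // 16 // grid_size[0],
              img_size[1] // 16 // grid_size[1])
n_patches = grid_size[0] * grid_size[1]
print(patch_size, n_patches)  # (1, 1) 196 -- same token count as ViT-B/16
```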


def get_b32_config():
    """Returns the ViT-B/32 configuration."""
    config = get_b16_config()
39 changes: 34 additions & 5 deletions models/modeling.py
@@ -19,6 +19,8 @@

import models.configs as configs

+from .modeling_resnet import ResNetV2


logger = logging.getLogger(__name__)

@@ -33,9 +35,9 @@
MLP_NORM = "LayerNorm_2"


-def np2th(weights):
+def np2th(weights, conv=False):
    """Possibly convert HWIO to OIHW."""
-    if weights.ndim == 4:
+    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)
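For context: TensorFlow checkpoints store conv kernels as HWIO (height, width, in, out), while PyTorch's `Conv2d.weight` is OIHW. The explicit `conv` flag replaces the old `weights.ndim == 4` heuristic, which would also have transposed any non-conv 4-D tensor in the checkpoint (presumably the motivation for this change). A minimal illustration:

```python
import numpy as np
import torch

def np2th(weights, conv=False):
    """Possibly convert HWIO to OIHW."""
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)

hwio = np.zeros((7, 7, 3, 64), dtype=np.float32)  # TF layout: H, W, I, O
print(np2th(hwio, conv=True).shape)               # torch.Size([64, 3, 7, 7]), i.e. OIHW
```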

@@ -124,10 +126,23 @@ class Embeddings(nn.Module):
"""
def __init__(self, config, img_size, in_channels=3):
super(Embeddings, self).__init__()
self.hybrid = None
img_size = _pair(img_size)
patch_size = _pair(config.patches["size"])
n_patches = (img_size[0]//patch_size[0]) * (img_size[1]//patch_size[1])

if config.patches.get("grid") is not None:
grid_size = config.patches["grid"]
patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
n_patches = grid_size[0] * grid_size[1]
self.hybrid = True
else:
patch_size = _pair(config.patches["size"])
self.hybrid = False
n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])

if self.hybrid:
self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers,
width_factor=config.resnet.width_factor)
in_channels = self.hybrid_model.width * 16
self.patch_embeddings = Conv2d(in_channels=in_channels,
out_channels=config.hidden_size,
kernel_size=patch_size,
@@ -141,6 +156,8 @@ def forward(self, x):
        B = x.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1)

+        if self.hybrid:
+            x = self.hybrid_model(x)
        x = self.patch_embeddings(x)
        x = x.flatten(2)
        x = x.transpose(-1, -2)
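To make the hybrid path concrete, a hedged shape trace of this forward pass at 224×224 with the R50-ViT-B_16 config (the 1024-channel width follows from `self.hybrid_model.width * 16` with width_factor 1, assuming a base width of 64):

```python
# Hypothetical shape trace for Embeddings.forward in hybrid mode:
# x                               (B, 3, 224, 224)   input image
# x = self.hybrid_model(x)     -> (B, 1024, 14, 14)  ResNetV2 feature map
# x = self.patch_embeddings(x) -> (B, 768, 14, 14)   1x1 conv to hidden_size
# x = x.flatten(2)             -> (B, 768, 196)
# x = x.transpose(-1, -2)      -> (B, 196, 768)      one token per grid cell
```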
@@ -272,7 +289,7 @@ def load_from(self, weights):
            self.head.weight.copy_(np2th(weights["head/kernel"]).t())
            self.head.bias.copy_(np2th(weights["head/bias"]).t())

-        self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"]))
+        self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
        self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))
        self.transformer.embeddings.cls_token.copy_(np2th(weights["cls"]))
        self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
@@ -307,12 +324,24 @@ def load_from(self, weights):
            for uname, unit in block.named_children():
                unit.load_from(weights, n_block=uname)

+        if self.transformer.embeddings.hybrid:
+            self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(weights["conv_root/kernel"], conv=True))
+            gn_weight = np2th(weights["gn_root/scale"]).view(-1)
+            gn_bias = np2th(weights["gn_root/bias"]).view(-1)
+            self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
+            self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)
+
+            for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
+                for uname, unit in block.named_children():
+                    unit.load_from(weights, n_block=bname, n_unit=uname)
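One note on the `.view(-1)` calls above: `nn.GroupNorm` expects 1-D weight and bias tensors, while the official npz checkpoints appear to store these parameters with extra singleton dimensions (the exact stored shape is an assumption here; the flatten makes the copy robust either way). A minimal sketch:

```python
import numpy as np
import torch

# Illustrative only: a GN scale stored with broadcast dims, e.g. (1, 1, 1, 64).
gn_scale = np.ones((1, 1, 1, 64), dtype=np.float32)
gn_weight = torch.from_numpy(gn_scale).view(-1)
print(gn_weight.shape)  # torch.Size([64]) -- what nn.GroupNorm.weight expects
```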


CONFIGS = {
    'ViT-B_16': configs.get_b16_config(),
    'ViT-B_32': configs.get_b32_config(),
    'ViT-L_16': configs.get_l16_config(),
    'ViT-L_32': configs.get_l32_config(),
    'ViT-H_14': configs.get_h14_config(),
+   'R50-ViT-B_16': configs.get_r50_b16_config(),
    'testing': configs.get_testing(),
}
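With the new CONFIGS entry, the hybrid model is constructed and loaded like any other variant. A minimal sketch (the constructor arguments mirror how train.py calls VisionTransformer; treat the signature and the checkpoint path as assumptions):

```python
import numpy as np
from models.modeling import VisionTransformer, CONFIGS

config = CONFIGS["R50-ViT-B_16"]
model = VisionTransformer(config, img_size=224, zero_head=True, num_classes=10)
model.load_from(np.load("checkpoint/R50-ViT-B_16.npz"))  # hypothetical local path
```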
5 changes: 4 additions & 1 deletion train.py
@@ -69,6 +69,7 @@ def setup(args):
logger.info("{}".format(config))
logger.info("Training parameters %s", args)
logger.info("Total Parameter: \t%2.1fM" % num_params)
print(num_params)
return args, model


@@ -245,7 +246,9 @@ def main():
help="Name of this run. Used for monitoring.")
parser.add_argument("--dataset", choices=["cifar10", "cifar100"], default="cifar10",
help="Which downstream task.")
parser.add_argument("--model_type", type=str, default="ViT-B_16",
parser.add_argument("--model_type", choices=["ViT-B_16", "ViT-B_32", "ViT-L_16",
"ViT-L_32", "ViT-H_14", "R50-ViT-B_16"],
default="ViT-B_16",
help="Which variant to use.")
parser.add_argument("--pretrained_dir", type=str, default="checkpoint/ViT-B_16.npz",
help="Where to search for pretrained ViT models.")
