From 55b3d908c7a9ede50a4961f14ed2148a968ed86c Mon Sep 17 00:00:00 2001 From: Luke Date: Thu, 19 Dec 2019 07:34:43 +0000 Subject: [PATCH 1/9] Updated original tf files --- .../original_tf/efficientnet_builder.py | 154 +++++- .../original_tf/efficientnet_model.py | 485 ++++++++++++++---- .../original_tf/preprocessing.py | 61 ++- .../convert_tf_to_pt/original_tf/utils.py | 271 ++++++++-- 4 files changed, 801 insertions(+), 170 deletions(-) diff --git a/tf_to_pytorch/convert_tf_to_pt/original_tf/efficientnet_builder.py b/tf_to_pytorch/convert_tf_to_pt/original_tf/efficientnet_builder.py index 1b80bbe..ff384b1 100644 --- a/tf_to_pytorch/convert_tf_to_pt/original_tf/efficientnet_builder.py +++ b/tf_to_pytorch/convert_tf_to_pt/original_tf/efficientnet_builder.py @@ -18,11 +18,18 @@ from __future__ import division from __future__ import print_function +import functools import os import re -import tensorflow as tf +from absl import logging +import numpy as np +import six +import tensorflow.compat.v1 as tf import efficientnet_model +import utils +MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255] +STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255] def efficientnet_params(model_name): @@ -37,6 +44,8 @@ def efficientnet_params(model_name): 'efficientnet-b5': (1.6, 2.2, 456, 0.4), 'efficientnet-b6': (1.8, 2.6, 528, 0.5), 'efficientnet-b7': (2.0, 3.1, 600, 0.5), + 'efficientnet-b8': (2.2, 3.6, 672, 0.5), + 'efficientnet-l2': (4.3, 5.3, 800, 0.5), } return params_dict[model_name] @@ -46,7 +55,10 @@ class BlockDecoder(object): def _decode_block_string(self, block_string): """Gets a block through a string notation of arguments.""" - assert isinstance(block_string, str) + if six.PY2: + assert isinstance(block_string, (str, unicode)) + else: + assert isinstance(block_string, str) ops = block_string.split('_') options = {} for op in ops: @@ -66,7 +78,12 @@ def _decode_block_string(self, block_string): expand_ratio=int(options['e']), id_skip=('noskip' not in block_string), se_ratio=float(options['se']) if 'se' in options else None, - strides=[int(options['s'][0]), int(options['s'][1])]) + strides=[int(options['s'][0]), + int(options['s'][1])], + conv_type=int(options['c']) if 'c' in options else 0, + fused_conv=int(options['f']) if 'f' in options else 0, + super_pixel=int(options['p']) if 'p' in options else 0, + condconv=('cc' in block_string)) def _encode_block_string(self, block): """Encodes a block to a string.""" @@ -76,12 +93,17 @@ def _encode_block_string(self, block): 's%d%d' % (block.strides[0], block.strides[1]), 'e%s' % block.expand_ratio, 'i%d' % block.input_filters, - 'o%d' % block.output_filters + 'o%d' % block.output_filters, + 'c%d' % block.conv_type, + 'f%d' % block.fused_conv, + 'p%d' % block.super_pixel, ] if block.se_ratio > 0 and block.se_ratio <= 1: args.append('se%s' % block.se_ratio) - if block.id_skip is False: + if block.id_skip is False: # pylint: disable=g-bool-id-comparison args.append('noskip') + if block.condconv: + args.append('cc') return '_'.join(args) def decode(self, string_list): @@ -113,30 +135,70 @@ def encode(self, blocks_args): return block_strings +def swish(features, use_native=True, use_hard=False): + """Computes the Swish activation function. + + We provide three alternnatives: + - Native tf.nn.swish, use less memory during training than composable swish. + - Quantization friendly hard swish. + - A composable swish, equivalant to tf.nn.swish, but more general for + finetuning and TF-Hub. + + Args: + features: A `Tensor` representing preactivation values. 
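For the PyTorch side of this conversion, the same three swish variants sketch directly onto torch (a sketch only, assuming PyTorch >= 1.0; the names here are illustrative and not part of this repo's API):

import torch
import torch.nn.functional as F

class _SwishAutograd(torch.autograd.Function):
    # Memory-lean swish: recompute sigmoid in backward instead of keeping
    # the intermediate graph, mirroring tf.nn.swish's custom gradient.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * torch.sigmoid(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        sig = torch.sigmoid(x)
        # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
        return grad_out * sig * (1 + x * (1 - sig))

def swish(x, use_custom_grad=True, use_hard=False):
    if use_custom_grad and use_hard:
        raise ValueError('Cannot specify both use_custom_grad and use_hard.')
    if use_hard:
        return x * F.relu6(x + 3.0) * (1.0 / 6.0)  # quantization-friendly
    if use_custom_grad:
        return _SwishAutograd.apply(x)
    return x * torch.sigmoid(x)  # composable form, default autograd
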
+ use_native: Whether to use the native swish from tf.nn that uses a custom + gradient to reduce memory usage, or to use customized swish that uses + default TensorFlow gradient computation. + use_hard: Whether to use quantization-friendly hard swish. + + Returns: + The activation value. + """ + if use_native and use_hard: + raise ValueError('Cannot specify both use_native and use_hard.') + + if use_native: + return tf.nn.swish(features) + + if use_hard: + return features * tf.nn.relu6(features + np.float32(3)) * (1. / 6.) + + features = tf.convert_to_tensor(features, name='features') + return features * tf.nn.sigmoid(features) + + +_DEFAULT_BLOCKS_ARGS = [ + 'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25', + 'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25', + 'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25', + 'r1_k3_s11_e6_i192_o320_se0.25', +] + + def efficientnet(width_coefficient=None, depth_coefficient=None, dropout_rate=0.2, - drop_connect_rate=0.2): + survival_prob=0.8): """Creates a efficientnet model.""" - blocks_args = [ - 'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25', - 'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25', - 'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25', - 'r1_k3_s11_e6_i192_o320_se0.25', - ] global_params = efficientnet_model.GlobalParams( + blocks_args=_DEFAULT_BLOCKS_ARGS, batch_norm_momentum=0.99, batch_norm_epsilon=1e-3, dropout_rate=dropout_rate, - drop_connect_rate=drop_connect_rate, + survival_prob=survival_prob, data_format='channels_last', num_classes=1000, width_coefficient=width_coefficient, depth_coefficient=depth_coefficient, depth_divisor=8, - min_depth=None) - decoder = BlockDecoder() - return decoder.decode(blocks_args), global_params + min_depth=None, + relu_fn=tf.nn.swish, + # The default is TPU-specific batch norm. + # The alternative is tf.layers.BatchNormalization. + batch_norm=utils.TpuBatchNormalization, # TPU-specific requirement. + use_se=True, + clip_projection_output=False) + return global_params def get_model_params(model_name, override_params): @@ -144,7 +206,7 @@ def get_model_params(model_name, override_params): if model_name.startswith('efficientnet'): width_coefficient, depth_coefficient, _, dropout_rate = ( efficientnet_params(model_name)) - blocks_args, global_params = efficientnet( + global_params = efficientnet( width_coefficient, depth_coefficient, dropout_rate) else: raise NotImplementedError('model name is not pre-defined: %s' % model_name) @@ -154,8 +216,10 @@ def get_model_params(model_name, override_params): # in global_params. global_params = global_params._replace(**override_params) - tf.logging.info('global_params= %s', global_params) - tf.logging.info('blocks_args= %s', blocks_args) + decoder = BlockDecoder() + blocks_args = decoder.decode(global_params.blocks_args) + + logging.info('global_params= %s', global_params) return blocks_args, global_params @@ -163,7 +227,10 @@ def build_model(images, model_name, training, override_params=None, - model_dir=None): + model_dir=None, + fine_tuning=False, + features_only=False, + pooled_features_only=False): """A helper functiion to creates a model and returns predicted logits. Args: @@ -173,6 +240,11 @@ def build_model(images, override_params: A dictionary of params for overriding. Fields must exist in efficientnet_model.GlobalParams. model_dir: string, optional model dir for saving configs. + fine_tuning: boolean, whether the model is used for finetuning. 
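The strings in _DEFAULT_BLOCKS_ARGS are what BlockDecoder parses. A stand-alone sketch of the core decoding rule (ignoring the condconv/fused/super-pixel fields this patch adds) makes the notation concrete:

import re

def decode_block_string(s):
    # Same key/value split as BlockDecoder._decode_block_string above.
    options = {}
    for op in s.split('_'):
        splits = re.split(r'(\d.*)', op)
        if len(splits) >= 2:
            key, value = splits[:2]
            options[key] = value
    return dict(
        num_repeat=int(options['r']),
        kernel_size=int(options['k']),
        strides=[int(options['s'][0]), int(options['s'][1])],
        expand_ratio=int(options['e']),
        input_filters=int(options['i']),
        output_filters=int(options['o']),
        se_ratio=float(options['se']) if 'se' in options else None)

print(decode_block_string('r2_k5_s22_e6_i24_o40_se0.25'))
# {'num_repeat': 2, 'kernel_size': 5, 'strides': [2, 2], 'expand_ratio': 6,
#  'input_filters': 24, 'output_filters': 40, 'se_ratio': 0.25}
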
+ features_only: build the base feature network only (excluding final + 1x1 conv layer, global pooling, dropout and fc head). + pooled_features_only: build the base network for features extraction (after + 1x1 conv layer and global pooling, but before dropout and fc head). Returns: logits: the logits tensor of classes. @@ -183,23 +255,45 @@ def build_model(images, When override_params has invalid fields, raises ValueError. """ assert isinstance(images, tf.Tensor) + assert not (features_only and pooled_features_only) + + # For backward compatibility. + if override_params and override_params.get('drop_connect_rate', None): + override_params['survival_prob'] = 1 - override_params['drop_connect_rate'] + + if not training or fine_tuning: + if not override_params: + override_params = {} + override_params['batch_norm'] = utils.BatchNormalization + if fine_tuning: + override_params['relu_fn'] = functools.partial(swish, use_native=False) blocks_args, global_params = get_model_params(model_name, override_params) if model_dir: param_file = os.path.join(model_dir, 'model_params.txt') if not tf.gfile.Exists(param_file): + if not tf.gfile.Exists(model_dir): + tf.gfile.MakeDirs(model_dir) with tf.gfile.GFile(param_file, 'w') as f: - tf.logging.info('writing to %s' % param_file) + logging.info('writing to %s', param_file) f.write('model_name= %s\n\n' % model_name) f.write('global_params= %s\n\n' % str(global_params)) f.write('blocks_args= %s\n\n' % str(blocks_args)) with tf.variable_scope(model_name): model = efficientnet_model.Model(blocks_args, global_params) - logits = model(images, training=training) - - logits = tf.identity(logits, 'logits') - return logits, model.endpoints + outputs = model( + images, + training=training, + features_only=features_only, + pooled_features_only=pooled_features_only) + if features_only: + outputs = tf.identity(outputs, 'features') + elif pooled_features_only: + outputs = tf.identity(outputs, 'pooled_features') + else: + outputs = tf.identity(outputs, 'logits') + return outputs, model.endpoints def build_model_base(images, model_name, training, override_params=None): @@ -207,10 +301,10 @@ def build_model_base(images, model_name, training, override_params=None): Args: images: input images tensor. - model_name: string, the model name of a pre-defined MnasNet. + model_name: string, the predefined model name. training: boolean, whether the model is constructed for training. override_params: A dictionary of params for overriding. Fields must exist in - mnasnet_model.GlobalParams. + efficientnet_model.GlobalParams. Returns: features: global pool features. @@ -221,11 +315,15 @@ def build_model_base(images, model_name, training, override_params=None): When override_params has invalid fields, raises ValueError. """ assert isinstance(images, tf.Tensor) + # For backward compatibility. 
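A minimal graph-mode usage sketch of the updated entry point (batch size is illustrative; b0 expects 224x224 inputs per efficientnet_params, and TF 1.x behavior is assumed):

import tensorflow.compat.v1 as tf
import efficientnet_builder

tf.disable_v2_behavior()
images = tf.placeholder(tf.float32, [8, 224, 224, 3])
logits, endpoints = efficientnet_builder.build_model(
    images, model_name='efficientnet-b0', training=False)
# endpoints['reduction_1'] ... endpoints['reduction_5'] expose the
# intermediate feature maps; pass features_only=True to skip the head.
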
+ if override_params and override_params.get('drop_connect_rate', None): + override_params['survival_prob'] = 1 - override_params['drop_connect_rate'] + blocks_args, global_params = get_model_params(model_name, override_params) with tf.variable_scope(model_name): model = efficientnet_model.Model(blocks_args, global_params) features = model(images, training=training, features_only=True) - features = tf.identity(features, 'global_pool') + features = tf.identity(features, 'features') return features, model.endpoints diff --git a/tf_to_pytorch/convert_tf_to_pt/original_tf/efficientnet_model.py b/tf_to_pytorch/convert_tf_to_pt/original_tf/efficientnet_model.py index 2b312d3..6bc827e 100644 --- a/tf_to_pytorch/convert_tf_to_pt/original_tf/efficientnet_model.py +++ b/tf_to_pytorch/convert_tf_to_pt/original_tf/efficientnet_model.py @@ -24,30 +24,31 @@ from __future__ import print_function import collections +import functools import math + +from absl import logging import numpy as np import six -from six.moves import xrange # pylint: disable=redefined-builtin -import tensorflow as tf +from six.moves import xrange +import tensorflow.compat.v1 as tf -#from efficientnet_pytorch import utils -from original_tf import utils +import utils +# from condconv import condconv_layers GlobalParams = collections.namedtuple('GlobalParams', [ 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', 'data_format', - 'num_classes', 'width_coefficient', 'depth_coefficient', - 'depth_divisor', 'min_depth', 'drop_connect_rate', + 'num_classes', 'width_coefficient', 'depth_coefficient', 'depth_divisor', + 'min_depth', 'survival_prob', 'relu_fn', 'batch_norm', 'use_se', + 'local_pooling', 'condconv_num_experts', 'clip_projection_output', + 'blocks_args' ]) GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) -# batchnorm = tf.layers.BatchNormalization -batchnorm = utils.TpuBatchNormalization # TPU-specific requirement. -relu_fn = tf.nn.swish - - BlockArgs = collections.namedtuple('BlockArgs', [ 'kernel_size', 'num_repeat', 'input_filters', 'output_filters', - 'expand_ratio', 'id_skip', 'strides', 'se_ratio' + 'expand_ratio', 'id_skip', 'strides', 'se_ratio', 'conv_type', 'fused_conv', + 'super_pixel', 'condconv' ]) # defaults will be a public argument for namedtuple in Python 3.7 # https://docs.python.org/3/library/collections.html#collections.namedtuple @@ -60,7 +61,7 @@ def conv_kernel_initializer(shape, dtype=None, partition_info=None): The main difference with tf.variance_scaling_initializer is that tf.variance_scaling_initializer uses a truncated normal with an uncorrected standard deviation, whereas here we use a normal distribution. Similarly, - tf.contrib.layers.variance_scaling_initializer uses a truncated normal with + tf.initializers.variance_scaling uses a truncated normal with a corrected standard deviation. Args: @@ -99,6 +100,41 @@ def dense_kernel_initializer(shape, dtype=None, partition_info=None): return tf.random_uniform(shape, -init_range, init_range, dtype=dtype) +def superpixel_kernel_initializer(shape, dtype='float32', partition_info=None): + """Initializes superpixel kernels. + + This is inspired by space-to-depth transformation that is mathematically + equivalent before and after the transformation. But we do the space-to-depth + via a convolution. 
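A quick NumPy check of the claim that this kernel makes the 2x2/stride-2 convolution an exact (channel-permuted) space-to-depth transform; the patch-based conv helper below is illustrative:

import numpy as np

def superpixel_kernel(depth):
    # Same one-hot construction as superpixel_kernel_initializer below.
    f = np.zeros([2, 2, depth, 4 * depth], dtype=np.float32)
    for i in range(2):
        for j in range(2):
            for k in range(depth):
                f[i, j, k, 4 * k + 2 * i + j] = 1.0
    return f

def conv2x2_stride2(x, kernel):
    # NHWC 2x2/stride-2 convolution via patch extraction + einsum.
    n, h, w, c = x.shape
    p = x.reshape(n, h // 2, 2, w // 2, 2, c).transpose(0, 1, 3, 2, 4, 5)
    return np.einsum('nhwijc,ijco->nhwo', p, kernel)

x = np.random.rand(1, 4, 4, 3).astype(np.float32)
y = conv2x2_stride2(x, superpixel_kernel(3))
# Output channel 4*k + 2*i + j picks out input channel k at offset (i, j):
assert np.allclose(y[0, 0, 0, 4 * 1 + 2 * 0 + 1], x[0, 0, 1, 1])
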
Moreover, we make the layer trainable instead of direct + transform, we can initialization it this way so that the model can learn not + to do anything but keep it mathematically equivalent, when improving + performance. + + + Args: + shape: shape of variable + dtype: dtype of variable + partition_info: unused + + Returns: + an initialization for the variable + """ + del partition_info + # use input depth to make superpixel kernel. + depth = shape[-2] + filters = np.zeros([2, 2, depth, 4 * depth], dtype=dtype) + i = np.arange(2) + j = np.arange(2) + k = np.arange(depth) + mesh = np.array(np.meshgrid(i, j, k)).T.reshape(-1, 3).T + filters[ + mesh[0], + mesh[1], + mesh[2], + 4 * mesh[2] + 2 * mesh[0] + mesh[1]] = 1 + return filters + + def round_filters(filters, global_params): """Round number of filters based on depth multiplier.""" orig_f = filters @@ -114,7 +150,7 @@ def round_filters(filters, global_params): # Make sure that round down does not go down by more than 10%. if new_filters < 0.9 * filters: new_filters += divisor - tf.logging.info('round_filter input={} output={}'.format(orig_f, new_filters)) + logging.info('round_filter input=%s output=%s', orig_f, new_filters) return int(new_filters) @@ -126,12 +162,10 @@ def round_repeats(repeats, global_params): return int(math.ceil(multiplier * repeats)) -class MBConvBlock(object): - """A class of MBConv: Mobile Inveretd Residual Bottleneck. +class MBConvBlock(tf.keras.layers.Layer): + """A class of MBConv: Mobile Inverted Residual Bottleneck. Attributes: - has_se: boolean. Whether the block contains a Squeeze and Excitation layer - inside. endpoints: dict. A list of internal tensors. """ @@ -142,20 +176,38 @@ def __init__(self, block_args, global_params): block_args: BlockArgs, arguments to create a Block. global_params: GlobalParams, a set of global parameters. """ + super(MBConvBlock, self).__init__() self._block_args = block_args self._batch_norm_momentum = global_params.batch_norm_momentum self._batch_norm_epsilon = global_params.batch_norm_epsilon - if global_params.data_format == 'channels_first': + self._batch_norm = global_params.batch_norm + self._condconv_num_experts = global_params.condconv_num_experts + self._data_format = global_params.data_format + if self._data_format == 'channels_first': self._channel_axis = 1 self._spatial_dims = [2, 3] else: self._channel_axis = -1 self._spatial_dims = [1, 2] - self.has_se = (self._block_args.se_ratio is not None) and ( - self._block_args.se_ratio > 0) and (self._block_args.se_ratio <= 1) + + self._relu_fn = global_params.relu_fn or tf.nn.swish + self._has_se = ( + global_params.use_se and self._block_args.se_ratio is not None and + 0 < self._block_args.se_ratio <= 1) + + self._clip_projection_output = global_params.clip_projection_output self.endpoints = None + self.conv_cls = tf.layers.Conv2D + self.depthwise_conv_cls = utils.DepthwiseConv2D + if self._block_args.condconv: + self.conv_cls = functools.partial( + condconv_layers.CondConv2D, num_experts=self._condconv_num_experts) + self.depthwise_conv_cls = functools.partial( + condconv_layers.DepthwiseCondConv2D, + num_experts=self._condconv_num_experts) + # Builds the block accordings to arguments. 
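The filter-rounding rule is worth a worked example. A stand-alone restatement of round_filters (dropping GlobalParams; divisor 8 and no min_depth, as in the defaults):

def round_filters(filters, width_coefficient, divisor=8):
    filters *= width_coefficient
    new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor)
    if new_filters < 0.9 * filters:  # never round down by more than 10%
        new_filters += divisor
    return int(new_filters)

assert round_filters(32, 1.0) == 32   # b0 keeps the stem width
assert round_filters(32, 2.0) == 64   # b7 (width 2.0) doubles it
assert round_filters(40, 1.1) == 48   # 44 rounds to the nearest multiple of 8
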
self._build() @@ -164,35 +216,70 @@ def block_args(self): def _build(self): """Builds block according to the arguments.""" - filters = self._block_args.input_filters * self._block_args.expand_ratio - if self._block_args.expand_ratio != 1: - # Expansion phase: - self._expand_conv = tf.layers.Conv2D( - filters, - kernel_size=[1, 1], - strides=[1, 1], + if self._block_args.super_pixel == 1: + self._superpixel = tf.layers.Conv2D( + self._block_args.input_filters, + kernel_size=[2, 2], + strides=[2, 2], kernel_initializer=conv_kernel_initializer, padding='same', + data_format=self._data_format, use_bias=False) - self._bn0 = batchnorm( + self._bnsp = self._batch_norm( axis=self._channel_axis, momentum=self._batch_norm_momentum, epsilon=self._batch_norm_epsilon) + if self._block_args.condconv: + # Add the example-dependent routing function + self._avg_pooling = tf.keras.layers.GlobalAveragePooling2D( + data_format=self._data_format) + self._routing_fn = tf.layers.Dense( + self._condconv_num_experts, activation=tf.nn.sigmoid) + + filters = self._block_args.input_filters * self._block_args.expand_ratio kernel_size = self._block_args.kernel_size - # Depth-wise convolution phase: - self._depthwise_conv = utils.DepthwiseConv2D( - [kernel_size, kernel_size], + + # Fused expansion phase. Called if using fused convolutions. + self._fused_conv = self.conv_cls( + filters=filters, + kernel_size=[kernel_size, kernel_size], + strides=self._block_args.strides, + kernel_initializer=conv_kernel_initializer, + padding='same', + data_format=self._data_format, + use_bias=False) + + # Expansion phase. Called if not using fused convolutions and expansion + # phase is necessary. + self._expand_conv = self.conv_cls( + filters=filters, + kernel_size=[1, 1], + strides=[1, 1], + kernel_initializer=conv_kernel_initializer, + padding='same', + data_format=self._data_format, + use_bias=False) + self._bn0 = self._batch_norm( + axis=self._channel_axis, + momentum=self._batch_norm_momentum, + epsilon=self._batch_norm_epsilon) + + # Depth-wise convolution phase. Called if not using fused convolutions. + self._depthwise_conv = self.depthwise_conv_cls( + kernel_size=[kernel_size, kernel_size], strides=self._block_args.strides, depthwise_initializer=conv_kernel_initializer, padding='same', + data_format=self._data_format, use_bias=False) - self._bn1 = batchnorm( + + self._bn1 = self._batch_norm( axis=self._channel_axis, momentum=self._batch_norm_momentum, epsilon=self._batch_norm_epsilon) - if self.has_se: + if self._has_se: num_reduced_filters = max( 1, int(self._block_args.input_filters * self._block_args.se_ratio)) # Squeeze and Excitation layer. @@ -202,6 +289,7 @@ def _build(self): strides=[1, 1], kernel_initializer=conv_kernel_initializer, padding='same', + data_format=self._data_format, use_bias=True) self._se_expand = tf.layers.Conv2D( filters, @@ -209,18 +297,20 @@ def _build(self): strides=[1, 1], kernel_initializer=conv_kernel_initializer, padding='same', + data_format=self._data_format, use_bias=True) - # Output phase: + # Output phase. 
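The fused path above replaces the 1x1 expansion plus k x k depthwise pair with a single full k x k convolution, trading parameters and FLOPs for better accelerator utilization. A rough parameter count for one stage (i24, k5, e6; SE, projection, and BN omitted since they are identical in both paths):

cin, k, e = 24, 5, 6
fused = k * k * cin * (e * cin)                # one full 5x5 conv to 144 ch
regular = cin * (e * cin) + k * k * (e * cin)  # 1x1 expand + 5x5 depthwise
print(fused, regular)  # 86400 vs. 7056
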
filters = self._block_args.output_filters - self._project_conv = tf.layers.Conv2D( - filters, + self._project_conv = self.conv_cls( + filters=filters, kernel_size=[1, 1], strides=[1, 1], kernel_initializer=conv_kernel_initializer, padding='same', + data_format=self._data_format, use_bias=False) - self._bn2 = batchnorm( + self._bn2 = self._batch_norm( axis=self._channel_axis, momentum=self._batch_norm_momentum, epsilon=self._batch_norm_epsilon) @@ -235,48 +325,162 @@ def _call_se(self, input_tensor): A output tensor, which should have the same shape as input. """ se_tensor = tf.reduce_mean(input_tensor, self._spatial_dims, keepdims=True) - se_tensor = self._se_expand(relu_fn(self._se_reduce(se_tensor))) - tf.logging.info('Built Squeeze and Excitation with tensor shape: %s' % - (se_tensor.shape)) + se_tensor = self._se_expand(self._relu_fn(self._se_reduce(se_tensor))) + logging.info('Built Squeeze and Excitation with tensor shape: %s', + (se_tensor.shape)) return tf.sigmoid(se_tensor) * input_tensor - def call(self, inputs, training=True, drop_connect_rate=None): + def call(self, inputs, training=True, survival_prob=None): """Implementation of call(). Args: inputs: the inputs tensor. training: boolean, whether the model is constructed for training. - drop_connect_rate: float, between 0 to 1, drop connect rate. + survival_prob: float, between 0 to 1, drop connect rate. Returns: A output tensor. """ - tf.logging.info('Block input: %s shape: %s' % (inputs.name, inputs.shape)) - if self._block_args.expand_ratio != 1: - x = relu_fn(self._bn0(self._expand_conv(inputs), training=training)) + logging.info('Block input: %s shape: %s', inputs.name, inputs.shape) + logging.info('Block input depth: %s output depth: %s', + self._block_args.input_filters, + self._block_args.output_filters) + + x = inputs + + fused_conv_fn = self._fused_conv + expand_conv_fn = self._expand_conv + depthwise_conv_fn = self._depthwise_conv + project_conv_fn = self._project_conv + + if self._block_args.condconv: + pooled_inputs = self._avg_pooling(inputs) + routing_weights = self._routing_fn(pooled_inputs) + # Capture routing weights as additional input to CondConv layers + fused_conv_fn = functools.partial( + self._fused_conv, routing_weights=routing_weights) + expand_conv_fn = functools.partial( + self._expand_conv, routing_weights=routing_weights) + depthwise_conv_fn = functools.partial( + self._depthwise_conv, routing_weights=routing_weights) + project_conv_fn = functools.partial( + self._project_conv, routing_weights=routing_weights) + + # creates conv 2x2 kernel + if self._block_args.super_pixel == 1: + with tf.variable_scope('super_pixel'): + x = self._relu_fn( + self._bnsp(self._superpixel(x), training=training)) + logging.info( + 'Block start with SuperPixel: %s shape: %s', x.name, x.shape) + + if self._block_args.fused_conv: + # If use fused mbconv, skip expansion and use regular conv. + x = self._relu_fn(self._bn1(fused_conv_fn(x), training=training)) + logging.info('Conv2D: %s shape: %s', x.name, x.shape) else: - x = inputs - tf.logging.info('Expand: %s shape: %s' % (x.name, x.shape)) + # Otherwise, first apply expansion and then apply depthwise conv. 
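For reference, the squeeze-and-excitation step in _call_se translates directly to PyTorch. A minimal NCHW sketch (assuming PyTorch >= 1.7 for F.silu; the class name is illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SqueezeExcite(nn.Module):
    def __init__(self, channels, reduced):
        # reduced = max(1, int(input_filters * se_ratio)) in the code above.
        super().__init__()
        self.reduce = nn.Conv2d(channels, reduced, kernel_size=1)
        self.expand = nn.Conv2d(reduced, channels, kernel_size=1)

    def forward(self, x):
        s = x.mean(dim=(2, 3), keepdim=True)     # squeeze: global avg pool
        s = self.expand(F.silu(self.reduce(s)))  # swish between the 1x1 convs
        return torch.sigmoid(s) * x              # excite: channel-wise gate
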
+ if self._block_args.expand_ratio != 1: + x = self._relu_fn(self._bn0(expand_conv_fn(x), training=training)) + logging.info('Expand: %s shape: %s', x.name, x.shape) - x = relu_fn(self._bn1(self._depthwise_conv(x), training=training)) - tf.logging.info('DWConv: %s shape: %s' % (x.name, x.shape)) + x = self._relu_fn(self._bn1(depthwise_conv_fn(x), training=training)) + logging.info('DWConv: %s shape: %s', x.name, x.shape) - if self.has_se: + if self._has_se: with tf.variable_scope('se'): x = self._call_se(x) self.endpoints = {'expansion_output': x} - x = self._bn2(self._project_conv(x), training=training) + x = self._bn2(project_conv_fn(x), training=training) + # Add identity so that quantization-aware training can insert quantization + # ops correctly. + x = tf.identity(x) + if self._clip_projection_output: + x = tf.clip_by_value(x, -6, 6) + if self._block_args.id_skip: + if all( + s == 1 for s in self._block_args.strides + ) and self._block_args.input_filters == self._block_args.output_filters: + # Apply only if skip connection presents. + if survival_prob: + x = utils.drop_connect(x, training, survival_prob) + x = tf.add(x, inputs) + logging.info('Project: %s shape: %s', x.name, x.shape) + return x + + +class MBConvBlockWithoutDepthwise(MBConvBlock): + """MBConv-like block without depthwise convolution and squeeze-and-excite.""" + + def _build(self): + """Builds block according to the arguments.""" + filters = self._block_args.input_filters * self._block_args.expand_ratio + if self._block_args.expand_ratio != 1: + # Expansion phase: + self._expand_conv = tf.layers.Conv2D( + filters, + kernel_size=[3, 3], + strides=[1, 1], + kernel_initializer=conv_kernel_initializer, + padding='same', + use_bias=False) + self._bn0 = self._batch_norm( + axis=self._channel_axis, + momentum=self._batch_norm_momentum, + epsilon=self._batch_norm_epsilon) + + # Output phase: + filters = self._block_args.output_filters + self._project_conv = tf.layers.Conv2D( + filters, + kernel_size=[1, 1], + strides=self._block_args.strides, + kernel_initializer=conv_kernel_initializer, + padding='same', + use_bias=False) + self._bn1 = self._batch_norm( + axis=self._channel_axis, + momentum=self._batch_norm_momentum, + epsilon=self._batch_norm_epsilon) + + def call(self, inputs, training=True, survival_prob=None): + """Implementation of call(). + + Args: + inputs: the inputs tensor. + training: boolean, whether the model is constructed for training. + survival_prob: float, between 0 to 1, drop connect rate. + + Returns: + A output tensor. + """ + logging.info('Block input: %s shape: %s', inputs.name, inputs.shape) + if self._block_args.expand_ratio != 1: + x = self._relu_fn(self._bn0(self._expand_conv(inputs), training=training)) + else: + x = inputs + logging.info('Expand: %s shape: %s', x.name, x.shape) + + self.endpoints = {'expansion_output': x} + + x = self._bn1(self._project_conv(x), training=training) + # Add identity so that quantization-aware training can insert quantization + # ops correctly. + x = tf.identity(x) + if self._clip_projection_output: + x = tf.clip_by_value(x, -6, 6) + if self._block_args.id_skip: if all( s == 1 for s in self._block_args.strides ) and self._block_args.input_filters == self._block_args.output_filters: - # only apply drop_connect if skip presents. - if drop_connect_rate: - x = utils.drop_connect(x, training, drop_connect_rate) + # Apply only if skip connection presents. 
+ if survival_prob: + x = utils.drop_connect(x, training, survival_prob) x = tf.add(x, inputs) - tf.logging.info('Project: %s shape: %s' % (x.name, x.shape)) + logging.info('Project: %s shape: %s', x.name, x.shape) return x @@ -301,39 +505,28 @@ def __init__(self, blocks_args=None, global_params=None): raise ValueError('blocks_args should be a list.') self._global_params = global_params self._blocks_args = blocks_args + self._relu_fn = global_params.relu_fn or tf.nn.swish + self._batch_norm = global_params.batch_norm + self.endpoints = None + self._build() + def _get_conv_block(self, conv_type): + conv_block_map = {0: MBConvBlock, 1: MBConvBlockWithoutDepthwise} + return conv_block_map[conv_type] + def _build(self): """Builds a model.""" self._blocks = [] - # Builds blocks. - for block_args in self._blocks_args: - assert block_args.num_repeat > 0 - # Update block input and output filters based on depth multiplier. - block_args = block_args._replace( - input_filters=round_filters(block_args.input_filters, - self._global_params), - output_filters=round_filters(block_args.output_filters, - self._global_params), - num_repeat=round_repeats(block_args.num_repeat, self._global_params)) - - # The first block needs to take care of stride and filter size increase. - self._blocks.append(MBConvBlock(block_args, self._global_params)) - if block_args.num_repeat > 1: - # pylint: disable=protected-access - block_args = block_args._replace( - input_filters=block_args.output_filters, strides=[1, 1]) - # pylint: enable=protected-access - for _ in xrange(block_args.num_repeat - 1): - self._blocks.append(MBConvBlock(block_args, self._global_params)) - batch_norm_momentum = self._global_params.batch_norm_momentum batch_norm_epsilon = self._global_params.batch_norm_epsilon if self._global_params.data_format == 'channels_first': channel_axis = 1 + self._spatial_dims = [2, 3] else: channel_axis = -1 + self._spatial_dims = [1, 2] # Stem part. self._conv_stem = tf.layers.Conv2D( @@ -342,12 +535,62 @@ def _build(self): strides=[2, 2], kernel_initializer=conv_kernel_initializer, padding='same', + data_format=self._global_params.data_format, use_bias=False) - self._bn0 = batchnorm( + self._bn0 = self._batch_norm( axis=channel_axis, momentum=batch_norm_momentum, epsilon=batch_norm_epsilon) + # Builds blocks. + for block_args in self._blocks_args: + assert block_args.num_repeat > 0 + assert block_args.super_pixel in [0, 1, 2] + # Update block input and output filters based on depth multiplier. + input_filters = round_filters(block_args.input_filters, + self._global_params) + output_filters = round_filters(block_args.output_filters, + self._global_params) + kernel_size = block_args.kernel_size + block_args = block_args._replace( + input_filters=input_filters, + output_filters=output_filters, + num_repeat=round_repeats(block_args.num_repeat, self._global_params)) + + # The first block needs to take care of stride and filter size increase. + conv_block = self._get_conv_block(block_args.conv_type) + if not block_args.super_pixel: # no super_pixel at all + self._blocks.append(conv_block(block_args, self._global_params)) + else: + # if superpixel, adjust filters, kernels, and strides. 
+ depth_factor = int(4 / block_args.strides[0] / block_args.strides[1]) + block_args = block_args._replace( + input_filters=block_args.input_filters * depth_factor, + output_filters=block_args.output_filters * depth_factor, + kernel_size=((block_args.kernel_size + 1) // 2 if depth_factor > 1 + else block_args.kernel_size)) + # if the first block has stride-2 and super_pixel trandformation + if (block_args.strides[0] == 2 and block_args.strides[1] == 2): + block_args = block_args._replace(strides=[1, 1]) + self._blocks.append(conv_block(block_args, self._global_params)) + block_args = block_args._replace( # sp stops at stride-2 + super_pixel=0, + input_filters=input_filters, + output_filters=output_filters, + kernel_size=kernel_size) + elif block_args.super_pixel == 1: + self._blocks.append(conv_block(block_args, self._global_params)) + block_args = block_args._replace(super_pixel=2) + else: + self._blocks.append(conv_block(block_args, self._global_params)) + if block_args.num_repeat > 1: # rest of blocks with the same block_arg + # pylint: disable=protected-access + block_args = block_args._replace( + input_filters=block_args.output_filters, strides=[1, 1]) + # pylint: enable=protected-access + for _ in xrange(block_args.num_repeat - 1): + self._blocks.append(conv_block(block_args, self._global_params)) + # Head part. self._conv_head = tf.layers.Conv2D( filters=round_filters(1280, self._global_params), @@ -356,57 +599,75 @@ def _build(self): kernel_initializer=conv_kernel_initializer, padding='same', use_bias=False) - self._bn1 = batchnorm( + self._bn1 = self._batch_norm( axis=channel_axis, momentum=batch_norm_momentum, epsilon=batch_norm_epsilon) self._avg_pooling = tf.keras.layers.GlobalAveragePooling2D( data_format=self._global_params.data_format) - self._fc = tf.layers.Dense( - self._global_params.num_classes, - kernel_initializer=dense_kernel_initializer) + if self._global_params.num_classes: + self._fc = tf.layers.Dense( + self._global_params.num_classes, + kernel_initializer=dense_kernel_initializer) + else: + self._fc = None if self._global_params.dropout_rate > 0: self._dropout = tf.keras.layers.Dropout(self._global_params.dropout_rate) else: self._dropout = None - def call(self, inputs, training=True, features_only=None): + def call(self, + inputs, + training=True, + features_only=None, + pooled_features_only=False): """Implementation of call(). Args: inputs: input tensors. training: boolean, whether the model is constructed for training. features_only: build the base feature network only. + pooled_features_only: build the base network for features extraction + (after 1x1 conv layer and global pooling, but before dropout and fc + head). Returns: output tensors. """ outputs = None self.endpoints = {} + reduction_idx = 0 # Calls Stem layers with tf.variable_scope('stem'): - outputs = relu_fn( + outputs = self._relu_fn( self._bn0(self._conv_stem(inputs), training=training)) - tf.logging.info('Built stem layers with output shape: %s' % outputs.shape) + logging.info('Built stem layers with output shape: %s', outputs.shape) self.endpoints['stem'] = outputs # Calls blocks. - reduction_idx = 0 for idx, block in enumerate(self._blocks): - is_reduction = False - if ((idx == len(self._blocks) - 1) or - self._blocks[idx + 1].block_args().strides[0] > 1): + is_reduction = False # reduction flag for blocks after the stem layer + # If the first block has super-pixel (space-to-depth) layer, then stem is + # the first reduction point. 
+ if (block.block_args().super_pixel == 1 and idx == 0): + reduction_idx += 1 + self.endpoints['reduction_%s' % reduction_idx] = outputs + + elif ((idx == len(self._blocks) - 1) or + self._blocks[idx + 1].block_args().strides[0] > 1): is_reduction = True reduction_idx += 1 with tf.variable_scope('blocks_%s' % idx): - drop_rate = self._global_params.drop_connect_rate - if drop_rate: - drop_rate *= float(idx) / len(self._blocks) - tf.logging.info('block_%s drop_connect_rate: %s' % (idx, drop_rate)) - outputs = block.call(outputs, training=training) + survival_prob = self._global_params.survival_prob + if survival_prob: + drop_rate = 1.0 - survival_prob + survival_prob = 1.0 - drop_rate * float(idx) / len(self._blocks) + logging.info('block_%s survival_prob: %s', idx, survival_prob) + outputs = block.call( + outputs, training=training, survival_prob=survival_prob) self.endpoints['block_%s' % idx] = outputs if is_reduction: self.endpoints['reduction_%s' % reduction_idx] = outputs @@ -415,16 +676,38 @@ def call(self, inputs, training=True, features_only=None): self.endpoints['block_%s/%s' % (idx, k)] = v if is_reduction: self.endpoints['reduction_%s/%s' % (reduction_idx, k)] = v - self.endpoints['global_pool'] = outputs + self.endpoints['features'] = outputs if not features_only: # Calls final layers and returns logits. with tf.variable_scope('head'): - outputs = relu_fn( + outputs = self._relu_fn( self._bn1(self._conv_head(outputs), training=training)) - outputs = self._avg_pooling(outputs) - if self._dropout: - outputs = self._dropout(outputs, training=training) - outputs = self._fc(outputs) - self.endpoints['head'] = outputs + self.endpoints['head_1x1'] = outputs + + if self._global_params.local_pooling: + shape = outputs.get_shape().as_list() + kernel_size = [ + 1, shape[self._spatial_dims[0]], shape[self._spatial_dims[1]], 1] + outputs = tf.nn.avg_pool( + outputs, ksize=kernel_size, strides=[1, 1, 1, 1], padding='VALID') + self.endpoints['pooled_features'] = outputs + if not pooled_features_only: + if self._dropout: + outputs = self._dropout(outputs, training=training) + self.endpoints['global_pool'] = outputs + if self._fc: + outputs = tf.squeeze(outputs, self._spatial_dims) + outputs = self._fc(outputs) + self.endpoints['head'] = outputs + else: + outputs = self._avg_pooling(outputs) + self.endpoints['pooled_features'] = outputs + if not pooled_features_only: + if self._dropout: + outputs = self._dropout(outputs, training=training) + self.endpoints['global_pool'] = outputs + if self._fc: + outputs = self._fc(outputs) + self.endpoints['head'] = outputs return outputs diff --git a/tf_to_pytorch/convert_tf_to_pt/original_tf/preprocessing.py b/tf_to_pytorch/convert_tf_to_pt/original_tf/preprocessing.py index c19006a..e7af8ab 100644 --- a/tf_to_pytorch/convert_tf_to_pt/original_tf/preprocessing.py +++ b/tf_to_pytorch/convert_tf_to_pt/original_tf/preprocessing.py @@ -17,7 +17,10 @@ from __future__ import division from __future__ import print_function -import tensorflow as tf +from absl import logging + +import tensorflow.compat.v1 as tf + IMAGE_SIZE = 224 CROP_PADDING = 32 @@ -122,7 +125,6 @@ def _decode_and_center_crop(image_bytes, image_size): padded_center_crop_size, padded_center_crop_size]) image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) image = tf.image.resize_bicubic([image], [image_size, image_size])[0] - return image @@ -132,13 +134,24 @@ def _flip(image): return image -def preprocess_for_train(image_bytes, use_bfloat16, image_size=IMAGE_SIZE): +def 
preprocess_for_train(image_bytes, use_bfloat16, image_size=IMAGE_SIZE, + augment_name=None, + randaug_num_layers=None, randaug_magnitude=None): """Preprocesses the given image for evaluation. Args: image_bytes: `Tensor` representing an image binary of arbitrary size. use_bfloat16: `bool` for whether to use bfloat16. image_size: image size. + augment_name: `string` that is the name of the augmentation method + to apply to the image. `autoaugment` if AutoAugment is to be used or + `randaugment` if RandAugment is to be used. If the value is `None` no + augmentation method will be applied applied. See autoaugment.py for more + details. + randaug_num_layers: 'int', if RandAug is used, what should the number of + layers be. See autoaugment.py for detailed description. + randaug_magnitude: 'int', if RandAug is used, what should the magnitude + be. See autoaugment.py for detailed description. Returns: A preprocessed image `Tensor`. @@ -146,8 +159,32 @@ def preprocess_for_train(image_bytes, use_bfloat16, image_size=IMAGE_SIZE): image = _decode_and_random_crop(image_bytes, image_size) image = _flip(image) image = tf.reshape(image, [image_size, image_size, 3]) + image = tf.image.convert_image_dtype( image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32) + + if augment_name: + try: + import autoaugment # pylint: disable=g-import-not-at-top + except ImportError as e: + logging.exception('Autoaugment is not supported in TF 2.x.') + raise e + + logging.info('Apply AutoAugment policy %s', augment_name) + input_image_type = image.dtype + image = tf.clip_by_value(image, 0.0, 255.0) + image = tf.cast(image, dtype=tf.uint8) + + if augment_name == 'autoaugment': + logging.info('Apply AutoAugment policy %s', augment_name) + image = autoaugment.distort_image_with_autoaugment(image, 'v0') + elif augment_name == 'randaugment': + image = autoaugment.distort_image_with_randaugment( + image, randaug_num_layers, randaug_magnitude) + else: + raise ValueError('Invalid value for augment_name: %s' % (augment_name)) + + image = tf.cast(image, dtype=input_image_type) return image @@ -172,7 +209,10 @@ def preprocess_for_eval(image_bytes, use_bfloat16, image_size=IMAGE_SIZE): def preprocess_image(image_bytes, is_training=False, use_bfloat16=False, - image_size=IMAGE_SIZE): + image_size=IMAGE_SIZE, + augment_name=None, + randaug_num_layers=None, + randaug_magnitude=None): """Preprocesses the given image. Args: @@ -180,11 +220,22 @@ def preprocess_image(image_bytes, is_training: `bool` for whether the preprocessing is for training. use_bfloat16: `bool` for whether to use bfloat16. image_size: image size. + augment_name: `string` that is the name of the augmentation method + to apply to the image. `autoaugment` if AutoAugment is to be used or + `randaugment` if RandAugment is to be used. If the value is `None` no + augmentation method will be applied applied. See autoaugment.py for more + details. + randaug_num_layers: 'int', if RandAug is used, what should the number of + layers be. See autoaugment.py for detailed description. + randaug_magnitude: 'int', if RandAug is used, what should the magnitude + be. See autoaugment.py for detailed description. Returns: A preprocessed image `Tensor` with value range of [0, 255]. 
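A usage sketch of the new augmentation plumbing (the file name and RandAugment settings are illustrative; autoaugment.py must be importable and TF 1.x behavior is assumed):

import tensorflow.compat.v1 as tf
import preprocessing

image_bytes = tf.read_file('example.jpg')
image = preprocessing.preprocess_image(
    image_bytes, is_training=True, use_bfloat16=False, image_size=224,
    augment_name='randaugment', randaug_num_layers=2, randaug_magnitude=28)
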
""" if is_training: - return preprocess_for_train(image_bytes, use_bfloat16, image_size) + return preprocess_for_train( + image_bytes, use_bfloat16, image_size, augment_name, + randaug_num_layers, randaug_magnitude) else: return preprocess_for_eval(image_bytes, use_bfloat16, image_size) diff --git a/tf_to_pytorch/convert_tf_to_pt/original_tf/utils.py b/tf_to_pytorch/convert_tf_to_pt/original_tf/utils.py index fa32ac3..61782ea 100644 --- a/tf_to_pytorch/convert_tf_to_pt/original_tf/utils.py +++ b/tf_to_pytorch/convert_tf_to_pt/original_tf/utils.py @@ -18,12 +18,15 @@ from __future__ import division from __future__ import print_function +import json import os +import sys + +from absl import logging import numpy as np -import tensorflow as tf +import tensorflow.compat.v1 as tf -from tensorflow.contrib.tpu.python.ops import tpu_ops -from tensorflow.contrib.tpu.python.tpu import tpu_function +from tensorflow.python.tpu import tpu_function # pylint:disable=g-direct-tensorflow-import def build_learning_rate(initial_lr, @@ -50,7 +53,7 @@ def build_learning_rate(initial_lr, assert False, 'Unknown lr_decay_type : %s' % lr_decay_type if warmup_epochs: - tf.logging.info('Learning rate warmup_epochs: %d' % warmup_epochs) + logging.info('Learning rate warmup_epochs: %d', warmup_epochs) warmup_steps = int(warmup_epochs * steps_per_epoch) warmup_lr = ( initial_lr * tf.cast(global_step, tf.float32) / tf.cast( @@ -67,18 +70,18 @@ def build_optimizer(learning_rate, momentum=0.9): """Build optimizer.""" if optimizer_name == 'sgd': - tf.logging.info('Using SGD optimizer') + logging.info('Using SGD optimizer') optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate) elif optimizer_name == 'momentum': - tf.logging.info('Using Momentum optimizer') + logging.info('Using Momentum optimizer') optimizer = tf.train.MomentumOptimizer( learning_rate=learning_rate, momentum=momentum) elif optimizer_name == 'rmsprop': - tf.logging.info('Using RMSProp optimizer') + logging.info('Using RMSProp optimizer') optimizer = tf.train.RMSPropOptimizer(learning_rate, decay, momentum, epsilon) else: - tf.logging.fatal('Unknown optimizer:', optimizer_name) + logging.fatal('Unknown optimizer: %s', optimizer_name) return optimizer @@ -104,7 +107,7 @@ def _cross_replica_average(self, t, num_shards_per_group): group_assignment = [[ x for x in range(num_shards) if x // num_shards_per_group == y ] for y in range(num_groups)] - return tpu_ops.cross_replica_sum(t, group_assignment) / tf.cast( + return tf.tpu.cross_replica_sum(t, group_assignment) / tf.cast( num_shards_per_group, t.dtype) def _moments(self, inputs, reduction_axes, keep_dims): @@ -116,41 +119,45 @@ def _moments(self, inputs, reduction_axes, keep_dims): if num_shards <= 8: # Skip cross_replica for 2x2 or smaller slices. num_shards_per_group = 1 else: - num_shards_per_group = max(8, num_shards // 4) - tf.logging.info('TpuBatchNormalization with num_shards_per_group %s', - num_shards_per_group) + num_shards_per_group = max(8, num_shards // 8) + logging.info('TpuBatchNormalization with num_shards_per_group %s', + num_shards_per_group) if num_shards_per_group > 1: - # Each group has multiple replicas: here we compute group mean/variance by - # aggregating per-replica mean/variance. - group_mean = self._cross_replica_average(shard_mean, num_shards_per_group) - group_variance = self._cross_replica_average(shard_variance, - num_shards_per_group) - - # Group variance needs to also include the difference between shard_mean - # and group_mean. 
- mean_distance = tf.square(group_mean - shard_mean) - group_variance += self._cross_replica_average(mean_distance, - num_shards_per_group) + # Compute variance using: Var[X]= E[X^2] - E[X]^2. + shard_square_of_mean = tf.math.square(shard_mean) + shard_mean_of_square = shard_variance + shard_square_of_mean + group_mean = self._cross_replica_average( + shard_mean, num_shards_per_group) + group_mean_of_square = self._cross_replica_average( + shard_mean_of_square, num_shards_per_group) + group_variance = group_mean_of_square - tf.math.square(group_mean) return (group_mean, group_variance) else: return (shard_mean, shard_variance) -def drop_connect(inputs, is_training, drop_connect_rate): - """Apply drop connect.""" +class BatchNormalization(tf.layers.BatchNormalization): + """Fixed default name of BatchNormalization to match TpuBatchNormalization.""" + + def __init__(self, name='tpu_batch_normalization', **kwargs): + super(BatchNormalization, self).__init__(name=name, **kwargs) + + +def drop_connect(inputs, is_training, survival_prob): + """Drop the entire conv with given survival probability.""" + # "Deep Networks with Stochastic Depth", https://arxiv.org/pdf/1603.09382.pdf if not is_training: return inputs - # Compute keep_prob - # TODO(tanmingxing): add support for training progress. - keep_prob = 1.0 - drop_connect_rate - - # Compute drop_connect tensor + # Compute tensor. batch_size = tf.shape(inputs)[0] - random_tensor = keep_prob + random_tensor = survival_prob random_tensor += tf.random_uniform([batch_size, 1, 1, 1], dtype=inputs.dtype) binary_tensor = tf.floor(random_tensor) - output = tf.div(inputs, keep_prob) * binary_tensor + # Unlike conventional way that multiply survival_prob at test time, here we + # divide survival_prob at training time, such that no addition compute is + # needed at test time. + output = tf.div(inputs, survival_prob) * binary_tensor return output @@ -164,12 +171,12 @@ def archive_ckpt(ckpt_eval, ckpt_objective, ckpt_path): with tf.gfile.GFile(saved_objective_path, 'r') as f: saved_objective = float(f.read()) if saved_objective > ckpt_objective: - tf.logging.info('Ckpt %s is worse than %s', ckpt_objective, saved_objective) + logging.info('Ckpt %s is worse than %s', ckpt_objective, saved_objective) return False filenames = tf.gfile.Glob(ckpt_path + '.*') if filenames is None: - tf.logging.info('No files to copy for checkpoint %s', ckpt_path) + logging.info('No files to copy for checkpoint %s', ckpt_path) return False # Clear the old folder. @@ -195,12 +202,204 @@ def archive_ckpt(ckpt_eval, ckpt_objective, ckpt_path): with tf.gfile.GFile(saved_objective_path, 'w') as f: f.write('%f' % ckpt_objective) - tf.logging.info('Copying checkpoint %s to %s', ckpt_path, dst_dir) + logging.info('Copying checkpoint %s to %s', ckpt_path, dst_dir) return True -# TODO(hongkuny): Consolidate this as a common library cross models. +def get_ema_vars(): + """Get all exponential moving average (ema) variables.""" + ema_vars = tf.trainable_variables() + tf.get_collection('moving_vars') + for v in tf.global_variables(): + # We maintain mva for batch norm moving mean and variance as well. + if 'moving_mean' in v.name or 'moving_variance' in v.name: + ema_vars.append(v) + return list(set(ema_vars)) + + class DepthwiseConv2D(tf.keras.layers.DepthwiseConv2D, tf.layers.Layer): """Wrap keras DepthwiseConv2D to tf.layers.""" pass + + +class EvalCkptDriver(object): + """A driver for running eval inference. 
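The renamed drop_connect carries over to PyTorch essentially unchanged apart from tensor plumbing. A sketch with the same semantics (scale by 1/survival_prob at train time so inference needs no extra work):

import torch

def drop_connect(x, training, survival_prob):
    if not training:
        return x
    keep = survival_prob + torch.rand(
        x.shape[0], 1, 1, 1, dtype=x.dtype, device=x.device)
    binary = torch.floor(keep)  # 1 with prob survival_prob, else 0
    return x / survival_prob * binary
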
+ + Attributes: + model_name: str. Model name to eval. + batch_size: int. Eval batch size. + image_size: int. Input image size, determined by model name. + num_classes: int. Number of classes, default to 1000 for ImageNet. + include_background_label: whether to include extra background label. + """ + + def __init__(self, + model_name, + batch_size=1, + image_size=224, + num_classes=1000, + include_background_label=False): + """Initialize internal variables.""" + self.model_name = model_name + self.batch_size = batch_size + self.num_classes = num_classes + self.include_background_label = include_background_label + self.image_size = image_size + + def restore_model(self, sess, ckpt_dir, enable_ema=True, export_ckpt=None): + """Restore variables from checkpoint dir.""" + sess.run(tf.global_variables_initializer()) + checkpoint = tf.train.latest_checkpoint(ckpt_dir) + if enable_ema: + ema = tf.train.ExponentialMovingAverage(decay=0.0) + ema_vars = get_ema_vars() + var_dict = ema.variables_to_restore(ema_vars) + ema_assign_op = ema.apply(ema_vars) + else: + var_dict = get_ema_vars() + ema_assign_op = None + + tf.train.get_or_create_global_step() + sess.run(tf.global_variables_initializer()) + saver = tf.train.Saver(var_dict, max_to_keep=1) + saver.restore(sess, checkpoint) + + if export_ckpt: + if ema_assign_op is not None: + sess.run(ema_assign_op) + saver = tf.train.Saver(max_to_keep=1, save_relative_paths=True) + saver.save(sess, export_ckpt) + + def build_model(self, features, is_training): + """Build model with input features.""" + del features, is_training + raise ValueError('Must be implemented by subclasses.') + + def get_preprocess_fn(self): + raise ValueError('Must be implemented by subclsses.') + + def build_dataset(self, filenames, labels, is_training): + """Build input dataset.""" + batch_drop_remainder = False + if 'condconv' in self.model_name and not is_training: + # CondConv layers can only be called with known batch dimension. Thus, we + # must drop all remaining examples that do not make up one full batch. + # To ensure all examples are evaluated, use a batch size that evenly + # divides the number of files. 
+ batch_drop_remainder = True + num_files = len(filenames) + if num_files % self.batch_size != 0: + tf.logging.warn('Remaining examples in last batch are not being ' + 'evaluated.') + filenames = tf.constant(filenames) + labels = tf.constant(labels) + dataset = tf.data.Dataset.from_tensor_slices((filenames, labels)) + + def _parse_function(filename, label): + image_string = tf.read_file(filename) + preprocess_fn = self.get_preprocess_fn() + image_decoded = preprocess_fn( + image_string, is_training, image_size=self.image_size) + image = tf.cast(image_decoded, tf.float32) + return image, label + + dataset = dataset.map(_parse_function) + dataset = dataset.batch(self.batch_size, + drop_remainder=batch_drop_remainder) + + iterator = dataset.make_one_shot_iterator() + images, labels = iterator.get_next() + return images, labels + + def run_inference(self, + ckpt_dir, + image_files, + labels, + enable_ema=True, + export_ckpt=None): + """Build and run inference on the target images and labels.""" + label_offset = 1 if self.include_background_label else 0 + with tf.Graph().as_default(), tf.Session() as sess: + images, labels = self.build_dataset(image_files, labels, False) + probs = self.build_model(images, is_training=False) + if isinstance(probs, tuple): + probs = probs[0] + + self.restore_model(sess, ckpt_dir, enable_ema, export_ckpt) + + prediction_idx = [] + prediction_prob = [] + for _ in range(len(image_files) // self.batch_size): + out_probs = sess.run(probs) + idx = np.argsort(out_probs)[::-1] + prediction_idx.append(idx[:5] - label_offset) + prediction_prob.append([out_probs[pid] for pid in idx[:5]]) + + # Return the top 5 predictions (idx and prob) for each image. + return prediction_idx, prediction_prob + + def eval_example_images(self, + ckpt_dir, + image_files, + labels_map_file, + enable_ema=True, + export_ckpt=None): + """Eval a list of example images. + + Args: + ckpt_dir: str. Checkpoint directory path. + image_files: List[str]. A list of image file paths. + labels_map_file: str. The labels map file path. + enable_ema: enable expotential moving average. + export_ckpt: export ckpt folder. + + Returns: + A tuple (pred_idx, and pred_prob), where pred_idx is the top 5 prediction + index and pred_prob is the top 5 prediction probability. + """ + classes = json.loads(tf.gfile.Open(labels_map_file).read()) + pred_idx, pred_prob = self.run_inference( + ckpt_dir, image_files, [0] * len(image_files), enable_ema, export_ckpt) + for i in range(len(image_files)): + print('predicted class for image {}: '.format(image_files[i])) + for j, idx in enumerate(pred_idx[i]): + print(' -> top_{} ({:4.2f}%): {} '.format(j, pred_prob[i][j] * 100, + classes[str(idx)])) + return pred_idx, pred_prob + + def eval_imagenet(self, ckpt_dir, imagenet_eval_glob, + imagenet_eval_label, num_images, enable_ema, export_ckpt): + """Eval ImageNet images and report top1/top5 accuracy. + + Args: + ckpt_dir: str. Checkpoint directory path. + imagenet_eval_glob: str. File path glob for all eval images. + imagenet_eval_label: str. File path for eval label. + num_images: int. Number of images to eval: -1 means eval the whole + dataset. + enable_ema: enable expotential moving average. + export_ckpt: export checkpoint folder. + + Returns: + A tuple (top1, top5) for top1 and top5 accuracy. 
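The top-5 selection in run_inference below is plain argsort bookkeeping; a tiny NumPy example of what it returns per image:

import numpy as np

out_probs = np.array([0.05, 0.6, 0.1, 0.2, 0.03, 0.02])
idx = np.argsort(out_probs)[::-1]
print(idx[:5])  # [1 3 2 0 4] -> class ids, shifted down by label_offset when
                # the checkpoint was trained with an extra background class.
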
+ """ + imagenet_val_labels = [int(i) for i in tf.gfile.GFile(imagenet_eval_label)] + imagenet_filenames = sorted(tf.gfile.Glob(imagenet_eval_glob)) + if num_images < 0: + num_images = len(imagenet_filenames) + image_files = imagenet_filenames[:num_images] + labels = imagenet_val_labels[:num_images] + + pred_idx, _ = self.run_inference( + ckpt_dir, image_files, labels, enable_ema, export_ckpt) + top1_cnt, top5_cnt = 0.0, 0.0 + for i, label in enumerate(labels): + top1_cnt += label in pred_idx[i][:1] + top5_cnt += label in pred_idx[i][:5] + if i % 100 == 0: + print('Step {}: top1_acc = {:4.2f}% top5_acc = {:4.2f}%'.format( + i, 100 * top1_cnt / (i + 1), 100 * top5_cnt / (i + 1))) + sys.stdout.flush() + top1, top5 = 100 * top1_cnt / num_images, 100 * top5_cnt / num_images + print('Final: top1_acc = {:4.2f}% top5_acc = {:4.2f}%'.format(top1, top5)) + return top1, top5 From 0253bfdb5850ed280cc6eca221bc77d14ae7d760 Mon Sep 17 00:00:00 2001 From: Luke Date: Thu, 19 Dec 2019 07:35:23 +0000 Subject: [PATCH 2/9] Added b8 model --- efficientnet_pytorch/model.py | 8 +++----- efficientnet_pytorch/utils.py | 2 ++ 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/efficientnet_pytorch/model.py b/efficientnet_pytorch/model.py index 3cb47d5..25e0a44 100755 --- a/efficientnet_pytorch/model.py +++ b/efficientnet_pytorch/model.py @@ -229,10 +229,8 @@ def get_image_size(cls, model_name): return res @classmethod - def _check_model_name_is_valid(cls, model_name, also_need_pretrained_weights=False): - """ Validates model name. None that pretrained weights are only available for - the first four models (efficientnet-b{i} for i in 0,1,2,3) at the moment. """ - num_models = 4 if also_need_pretrained_weights else 8 - valid_models = ['efficientnet-b'+str(i) for i in range(num_models)] + def _check_model_name_is_valid(cls, model_name): + """ Validates model name. 
""" + valid_models = ['efficientnet-b'+str(i) for i in range(9)] if model_name not in valid_models: raise ValueError('model_name should be one of: ' + ', '.join(valid_models)) diff --git a/efficientnet_pytorch/utils.py b/efficientnet_pytorch/utils.py index 198b3b4..93cc49f 100755 --- a/efficientnet_pytorch/utils.py +++ b/efficientnet_pytorch/utils.py @@ -170,6 +170,8 @@ def efficientnet_params(model_name): 'efficientnet-b5': (1.6, 2.2, 456, 0.4), 'efficientnet-b6': (1.8, 2.6, 528, 0.5), 'efficientnet-b7': (2.0, 3.1, 600, 0.5), + 'efficientnet-b8': (2.2, 3.6, 672, 0.5), + 'efficientnet-l2': (4.3, 5.3, 800, 0.5), } return params_dict[model_name] From bf2109cb7401567268dccd414629707263a7c789 Mon Sep 17 00:00:00 2001 From: Luke Date: Thu, 19 Dec 2019 07:35:48 +0000 Subject: [PATCH 3/9] Improved developer experience --- .../convert_tf_to_pt/load_tf_weights.py | 2 +- tf_to_pytorch/convert_tf_to_pt/run.sh | 18 ++++++++++++++++++ .../pretrained_tensorflow/download.sh | 3 ++- 3 files changed, 21 insertions(+), 2 deletions(-) create mode 100755 tf_to_pytorch/convert_tf_to_pt/run.sh diff --git a/tf_to_pytorch/convert_tf_to_pt/load_tf_weights.py b/tf_to_pytorch/convert_tf_to_pt/load_tf_weights.py index bb0e5b9..0722a68 100644 --- a/tf_to_pytorch/convert_tf_to_pt/load_tf_weights.py +++ b/tf_to_pytorch/convert_tf_to_pt/load_tf_weights.py @@ -149,7 +149,7 @@ def load_and_save_temporary_tensorflow_model(model_name, model_ckpt, example_img parser = argparse.ArgumentParser( description='Convert TF model to PyTorch model and save for easier future loading') parser.add_argument('--model_name', type=str, default='efficientnet-b0', - help='efficientnet-b{N}, where N is an integer 0 <= N <= 7') + help='efficientnet-b{N}, where N is an integer 0 <= N <= 8') parser.add_argument('--tf_checkpoint', type=str, default='pretrained_tensorflow/efficientnet-b0/', help='checkpoint file path') parser.add_argument('--output_file', type=str, default='pretrained_pytorch/efficientnet-b0.pth', diff --git a/tf_to_pytorch/convert_tf_to_pt/run.sh b/tf_to_pytorch/convert_tf_to_pt/run.sh new file mode 100755 index 0000000..7e227c5 --- /dev/null +++ b/tf_to_pytorch/convert_tf_to_pt/run.sh @@ -0,0 +1,18 @@ +python ../convert_tf_to_pt/load_tf_weights.py --model_name efficientnet-b0 --tf_checkpoint ../pretrained_tensorflow/efficientnet-b0/ --output_file ../pretrained_pytorch/efficientnet-b0.pth + +python ../convert_tf_to_pt/load_tf_weights.py --model_name efficientnet-b1 --tf_checkpoint ../pretrained_tensorflow/efficientnet-b1/ --output_file ../pretrained_pytorch/efficientnet-b1.pth + +python ../convert_tf_to_pt/load_tf_weights.py --model_name efficientnet-b2 --tf_checkpoint ../pretrained_tensorflow/efficientnet-b2/ --output_file ../pretrained_pytorch/efficientnet-b2.pth + +python ../convert_tf_to_pt/load_tf_weights.py --model_name efficientnet-b3 --tf_checkpoint ../pretrained_tensorflow/efficientnet-b3/ --output_file ../pretrained_pytorch/efficientnet-b3.pth + + +python ../convert_tf_to_pt/load_tf_weights.py --model_name efficientnet-b4 --tf_checkpoint ../pretrained_tensorflow/efficientnet-b4/ --output_file ../pretrained_pytorch/efficientnet-b4.pth + +python ../convert_tf_to_pt/load_tf_weights.py --model_name efficientnet-b5 --tf_checkpoint ../pretrained_tensorflow/efficientnet-b5/ --output_file ../pretrained_pytorch/efficientnet-b5.pth + +python ../convert_tf_to_pt/load_tf_weights.py --model_name efficientnet-b6 --tf_checkpoint ../pretrained_tensorflow/efficientnet-b6/ --output_file ../pretrained_pytorch/efficientnet-b6.pth + +python 
../convert_tf_to_pt/load_tf_weights.py --model_name efficientnet-b7 --tf_checkpoint ../pretrained_tensorflow/efficientnet-b7/ --output_file ../pretrained_pytorch/efficientnet-b7.pth + +python ../convert_tf_to_pt/load_tf_weights.py --model_name efficientnet-b8 --tf_checkpoint ../pretrained_tensorflow/efficientnet-b8/ --output_file ../pretrained_pytorch/efficientnet-b8.pth diff --git a/tf_to_pytorch/pretrained_tensorflow/download.sh b/tf_to_pytorch/pretrained_tensorflow/download.sh index fd033de..ba7d7be 100755 --- a/tf_to_pytorch/pretrained_tensorflow/download.sh +++ b/tf_to_pytorch/pretrained_tensorflow/download.sh @@ -1,5 +1,6 @@ #!/usr/bin/env bash + # This script accepts a single command-line argument, which specifies which model to download. # Only the b0, b1, b2, and b3 models have been released, so your command must be one of them. @@ -9,6 +10,6 @@ # ./download.sh efficientnet-b3 MODEL=$1 -wget https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckptsaug/${MODEL}.tar.gz +wget https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/advprop/${MODEL}.tar.gz tar xvf ${MODEL}.tar.gz rm ${MODEL}.tar.gz From 78ae36fd9821bd66e053cfc720d088514095ad9b Mon Sep 17 00:00:00 2001 From: Luke Date: Thu, 19 Dec 2019 07:36:29 +0000 Subject: [PATCH 4/9] Increased version --- efficientnet_pytorch/__init__.py | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/efficientnet_pytorch/__init__.py b/efficientnet_pytorch/__init__.py index e810cc2..2eb0ddd 100644 --- a/efficientnet_pytorch/__init__.py +++ b/efficientnet_pytorch/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.5.1" +__version__ = "0.6.0" from .model import EfficientNet from .utils import ( GlobalParams, diff --git a/setup.py b/setup.py index b9600f9..06771ca 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ EMAIL = 'lmelaskyriazi@college.harvard.edu' AUTHOR = 'Luke' REQUIRES_PYTHON = '>=3.5.0' -VERSION = '0.5.1' +VERSION = '0.6.0' # What packages are required for this module to be executed? REQUIRED = [ From 9353120d625c4969281e15237027f20a14af32cc Mon Sep 17 00:00:00 2001 From: Luke Date: Fri, 24 Jan 2020 04:28:58 +0000 Subject: [PATCH 5/9] Updated gitignore --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 3b36220..28c4bbe 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# Custom +tmp + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] From 459c157fbdabda8d9f3e4da0ce27772320880e1e Mon Sep 17 00:00:00 2001 From: Luke Date: Fri, 24 Jan 2020 04:29:24 +0000 Subject: [PATCH 6/9] Added advprop preprocessing to eval script --- examples/imagenet/main.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/examples/imagenet/main.py b/examples/imagenet/main.py index 31a62a2..a665e73 100644 --- a/examples/imagenet/main.py +++ b/examples/imagenet/main.py @@ -71,6 +71,8 @@ help='GPU id to use.') parser.add_argument('--image_size', default=224, type=int, help='image size') +parser.add_argument('--advprop', default=False, action='store_true', + help='use advprop or not') parser.add_argument('--multiprocessing-distributed', action='store_true', help='Use multi-processing distributed training to launch ' 'N processes per node, which has N GPUs. 
This is the ' @@ -134,7 +136,7 @@ def main_worker(gpu, ngpus_per_node, args): # create model if 'efficientnet' in args.arch: # NEW if args.pretrained: - model = EfficientNet.from_pretrained(args.arch) + model = EfficientNet.from_pretrained(args.arch, advprop=args.advprop) print("=> using pre-trained model '{}'".format(args.arch)) else: print("=> creating model '{}'".format(args.arch)) @@ -206,8 +208,11 @@ def main_worker(gpu, ngpus_per_node, args): # Data loading code traindir = os.path.join(args.data, 'train') valdir = os.path.join(args.data, 'val') - normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]) + if args.advprop: + normalize = transforms.Lambda(lambda img: img * 2.0 - 1.0) + else: + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) train_dataset = datasets.ImageFolder( traindir, From 578225c980036aae875f521e655db3cd57bc7b9b Mon Sep 17 00:00:00 2001 From: Luke Date: Fri, 24 Jan 2020 04:29:59 +0000 Subject: [PATCH 7/9] More utilities for downloading and converting --- tf_to_pytorch/convert_tf_to_pt/rename.sh | 5 +++++ tf_to_pytorch/convert_tf_to_pt/run.sh | 17 ++++++++--------- 2 files changed, 13 insertions(+), 9 deletions(-) create mode 100644 tf_to_pytorch/convert_tf_to_pt/rename.sh diff --git a/tf_to_pytorch/convert_tf_to_pt/rename.sh b/tf_to_pytorch/convert_tf_to_pt/rename.sh new file mode 100644 index 0000000..aa79113 --- /dev/null +++ b/tf_to_pytorch/convert_tf_to_pt/rename.sh @@ -0,0 +1,5 @@ +for i in 0 1 2 3 4 5 6 7 8 +do + X=$(sha256sum efficientnet-b${i}.pth | head -c 8) + mv efficientnet-b${i}.pth efficientnet-b${i}-${X}.pth +done diff --git a/tf_to_pytorch/convert_tf_to_pt/run.sh b/tf_to_pytorch/convert_tf_to_pt/run.sh index 7e227c5..f80d5f5 100755 --- a/tf_to_pytorch/convert_tf_to_pt/run.sh +++ b/tf_to_pytorch/convert_tf_to_pt/run.sh @@ -1,18 +1,17 @@ python ../convert_tf_to_pt/load_tf_weights.py --model_name efficientnet-b0 --tf_checkpoint ../pretrained_tensorflow/efficientnet-b0/ --output_file ../pretrained_pytorch/efficientnet-b0.pth -python ../convert_tf_to_pt/load_tf_weights.py --model_name efficientnet-b1 --tf_checkpoint ../pretrained_tensorflow/efficientnet-b1/ --output_file ../pretrained_pytorch/efficientnet-b1.pth +# python ../convert_tf_to_pt/load_tf_weights.py --model_name efficientnet-b1 --tf_checkpoint ../pretrained_tensorflow/efficientnet-b1/ --output_file ../pretrained_pytorch/efficientnet-b1.pth -python ../convert_tf_to_pt/load_tf_weights.py --model_name efficientnet-b2 --tf_checkpoint ../pretrained_tensorflow/efficientnet-b2/ --output_file ../pretrained_pytorch/efficientnet-b2.pth +# python ../convert_tf_to_pt/load_tf_weights.py --model_name efficientnet-b2 --tf_checkpoint ../pretrained_tensorflow/efficientnet-b2/ --output_file ../pretrained_pytorch/efficientnet-b2.pth -python ../convert_tf_to_pt/load_tf_weights.py --model_name efficientnet-b3 --tf_checkpoint ../pretrained_tensorflow/efficientnet-b3/ --output_file ../pretrained_pytorch/efficientnet-b3.pth +# python ../convert_tf_to_pt/load_tf_weights.py --model_name efficientnet-b3 --tf_checkpoint ../pretrained_tensorflow/efficientnet-b3/ --output_file ../pretrained_pytorch/efficientnet-b3.pth +# python ../convert_tf_to_pt/load_tf_weights.py --model_name efficientnet-b4 --tf_checkpoint ../pretrained_tensorflow/efficientnet-b4/ --output_file ../pretrained_pytorch/efficientnet-b4.pth -python ../convert_tf_to_pt/load_tf_weights.py --model_name efficientnet-b4 --tf_checkpoint ../pretrained_tensorflow/efficientnet-b4/ 
--output_file ../pretrained_pytorch/efficientnet-b4.pth
 
+# python ../convert_tf_to_pt/load_tf_weights.py --model_name efficientnet-b5 --tf_checkpoint ../pretrained_tensorflow/efficientnet-b5/ --output_file ../pretrained_pytorch/efficientnet-b5.pth
 
-python ../convert_tf_to_pt/load_tf_weights.py --model_name efficientnet-b5 --tf_checkpoint ../pretrained_tensorflow/efficientnet-b5/ --output_file ../pretrained_pytorch/efficientnet-b5.pth
 
+# python ../convert_tf_to_pt/load_tf_weights.py --model_name efficientnet-b6 --tf_checkpoint ../pretrained_tensorflow/efficientnet-b6/ --output_file ../pretrained_pytorch/efficientnet-b6.pth
 
-python ../convert_tf_to_pt/load_tf_weights.py --model_name efficientnet-b6 --tf_checkpoint ../pretrained_tensorflow/efficientnet-b6/ --output_file ../pretrained_pytorch/efficientnet-b6.pth
 
+# python ../convert_tf_to_pt/load_tf_weights.py --model_name efficientnet-b7 --tf_checkpoint ../pretrained_tensorflow/efficientnet-b7/ --output_file ../pretrained_pytorch/efficientnet-b7.pth
 
-python ../convert_tf_to_pt/load_tf_weights.py --model_name efficientnet-b7 --tf_checkpoint ../pretrained_tensorflow/efficientnet-b7/ --output_file ../pretrained_pytorch/efficientnet-b7.pth
 
-python ../convert_tf_to_pt/load_tf_weights.py --model_name efficientnet-b8 --tf_checkpoint ../pretrained_tensorflow/efficientnet-b8/ --output_file ../pretrained_pytorch/efficientnet-b8.pth
+# python ../convert_tf_to_pt/load_tf_weights.py --model_name efficientnet-b8 --tf_checkpoint ../pretrained_tensorflow/efficientnet-b8/ --output_file ../pretrained_pytorch/efficientnet-b8.pth

From e22b46e35e31417b8ebe425d6bc2da335e70b652 Mon Sep 17 00:00:00 2001
From: Luke
Date: Fri, 24 Jan 2020 04:59:17 +0000
Subject: [PATCH 8/9] Updated readme

---
 README.md | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/README.md b/README.md
index 9f8d6ca..2a4c15b 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,28 @@
 # EfficientNet PyTorch
+
+_IMPORTANT NOTE_: In the latest update, I switched hosting providers for the pretrained models, as hosting them with the previous provider had become extremely expensive. This _will_ break old versions of the library. I apologize, but I cannot afford to keep serving the models from the old provider. Everything should work properly once you update the library:
+```
+pip install --upgrade efficientnet-pytorch
+```
+
+### Update (January 23, 2020)
+
+This update adds a new category of pre-trained model based on adversarial training, called _advprop_. Note that the advprop pretrained models require slightly different preprocessing from standard ImageNet preprocessing, so they are not used by default. To load a model with advprop, use:
+```
+model = EfficientNet.from_pretrained("efficientnet-b0", advprop=True)
+```
+There is also a new, large `efficientnet-b8` pretrained model, which is only available in advprop form. When using any of these models, replace your ImageNet preprocessing code as follows:
+```
+if advprop:  # for models using advprop pretrained weights
+    normalize = transforms.Lambda(lambda img: img * 2.0 - 1.0)
+else:
+    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                     std=[0.229, 0.224, 0.225])
+
+```
+This update also addresses multiple other issues ([#115](https://github.com/lukemelas/EfficientNet-PyTorch/issues/115), [#128](https://github.com/lukemelas/EfficientNet-PyTorch/issues/128)).
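+
+For reference, a complete advprop preprocessing pipeline might look like the following (a sketch, not part of the library; `get_image_size` returns each model's input resolution):
+```
+from torchvision import transforms
+from efficientnet_pytorch import EfficientNet
+
+image_size = EfficientNet.get_image_size('efficientnet-b8')  # 672 for b8
+transform = transforms.Compose([
+    transforms.Resize(image_size),
+    transforms.CenterCrop(image_size),
+    transforms.ToTensor(),
+    # advprop models expect inputs scaled to [-1, 1] instead of ImageNet mean/std
+    transforms.Lambda(lambda img: img * 2.0 - 1.0),
+])
+```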
+ ### Update (October 15, 2019) This update allows you to choose whether to use a memory-efficient Swish activation. The memory-efficient version is chosen by default, but it cannot be used when exporting using PyTorch JIT. For this purpose, we have also included a standard (export-friendly) swish activation function. To switch to the export-friendly version, simply call `model.set_swish(memory_efficient=False)` after loading your desired model. This update addresses issues [#88](https://github.com/lukemelas/EfficientNet-PyTorch/pull/88) and [#89](https://github.com/lukemelas/EfficientNet-PyTorch/pull/89). From 396b06bc1489a21317ef81ac530274fe64cddc68 Mon Sep 17 00:00:00 2001 From: Luke Date: Fri, 24 Jan 2020 05:00:19 +0000 Subject: [PATCH 9/9] Add advprop and switch hosting providers --- efficientnet_pytorch/model.py | 11 ++--------- efficientnet_pytorch/utils.py | 35 +++++++++++++++++++++++++---------- 2 files changed, 27 insertions(+), 19 deletions(-) diff --git a/efficientnet_pytorch/model.py b/efficientnet_pytorch/model.py index 25e0a44..f2bc2af 100755 --- a/efficientnet_pytorch/model.py +++ b/efficientnet_pytorch/model.py @@ -206,22 +206,15 @@ def from_name(cls, model_name, override_params=None): return cls(blocks_args, global_params) @classmethod - def from_pretrained(cls, model_name, num_classes=1000, in_channels = 3): + def from_pretrained(cls, model_name, advprop=False, num_classes=1000, in_channels=3): model = cls.from_name(model_name, override_params={'num_classes': num_classes}) - load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000)) + load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000), advprop=advprop) if in_channels != 3: Conv2d = get_same_padding_conv2d(image_size = model._global_params.image_size) out_channels = round_filters(32, model._global_params) model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) return model - @classmethod - def from_pretrained(cls, model_name, num_classes=1000): - model = cls.from_name(model_name, override_params={'num_classes': num_classes}) - load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000)) - - return model - @classmethod def get_image_size(cls, model_name): cls._check_model_name_is_valid(model_name) diff --git a/efficientnet_pytorch/utils.py b/efficientnet_pytorch/utils.py index 93cc49f..3f3e56f 100755 --- a/efficientnet_pytorch/utils.py +++ b/efficientnet_pytorch/utils.py @@ -295,20 +295,35 @@ def get_model_params(model_name, override_params): return blocks_args, global_params -url_map = { - 'efficientnet-b0': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b0-355c32eb.pth', - 'efficientnet-b1': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b1-f1951068.pth', - 'efficientnet-b2': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b2-8bb594d6.pth', - 'efficientnet-b3': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b3-5fb5a3c3.pth', - 'efficientnet-b4': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b4-6ed6700e.pth', - 'efficientnet-b5': 
'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b5-b6417697.pth', - 'efficientnet-b6': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b6-c76e70fd.pth', - 'efficientnet-b7': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b7-dcc49843.pth', +url_map_aa = { + 'efficientnet-b0': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b0-355c32eb.pth', + 'efficientnet-b1': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b1-f1951068.pth', + 'efficientnet-b2': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b2-8bb594d6.pth', + 'efficientnet-b3': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b3-5fb5a3c3.pth', + 'efficientnet-b4': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b4-6ed6700e.pth', + 'efficientnet-b5': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b5-b6417697.pth', + 'efficientnet-b6': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b6-c76e70fd.pth', + 'efficientnet-b7': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b7-dcc49843.pth', } -def load_pretrained_weights(model, model_name, load_fc=True): +url_map_advprop = { + 'efficientnet-b0': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b0-b64d5a18.pth', + 'efficientnet-b1': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b1-0f3ce85a.pth', + 'efficientnet-b2': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b2-6e9d97e5.pth', + 'efficientnet-b3': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b3-cdd7c0f4.pth', + 'efficientnet-b4': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b4-44fb3a87.pth', + 'efficientnet-b5': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b5-86493f6b.pth', + 'efficientnet-b6': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b6-ac80338e.pth', + 'efficientnet-b7': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b7-4652b6dd.pth', + 'efficientnet-b8': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b8-22a8fe65.pth', +} + + +def load_pretrained_weights(model, model_name, load_fc=True, advprop=False): """ Loads pretrained weights, and downloads if loading for the first time. 
""" + # AutoAugment or Advprop (different preprocessing) + url_map = url_map_advprop if advprop else url_map_aa state_dict = model_zoo.load_url(url_map[model_name]) if load_fc: model.load_state_dict(state_dict)