Implementation of segmentation neural networks for PyTorch, with new features such as:
- 🐟 No backbones, the architectures remain simple to understand and train.
- 💾 Memory-efficient version (trade-off between memory and speed).
- 🖼 Works with any input size (not only powers of 2 anymore).
- 👁 Different types of upsampling (transposed convolution, upsampling and pixel shuffle).
- 🏊♀️ Different types of pooling (max-pooling, avg-pooling, blur-pooling).
- 🏗 The depth and width of the models are fully configurable.
- 🔬 Early-transition can be enabled when the input images are big.
- 👸🏼 The activation functions of all layers can be modified to something trendier.
For the Tiramisu architecture:
- 🎉 Won a competition (Adipocyte Cell Imaging Challenge)! Preprint of the winners is here.
- 🎉 Was used in a NeurIPS paper! Abstract and paper are here.
Support for the following neural networks:
The package can be installed from the repository with:
> pip3 install octopytorch
You can try the model in Python with:
from functools import partial
import torch
from torch import nn
import octopytorch as octo
module_bank = octo.DEFAULT_MODULE_BANK.copy()
# Dropout
module_bank[octo.ModuleType.DROPOUT] = partial(nn.Dropout2d, p=0.2, inplace=True)
# Every activation in the model is going to be a GELU (Gaussian Error Linear
# Units function). GELU(x) = x * Φ(x)
# See: https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
module_bank[octo.ModuleType.ACTIVATION] = nn.GELU
# Example for segmentation:
module_bank[octo.ModuleType.ACTIVATION_FINAL] = partial(nn.LogSoftmax, dim=1)
# Example for regression (default):
#module_bank[octo.ModuleType.ACTIVATION_FINAL] = nn.Identity
model = octo.models.Tiramisu(
    in_channels=3,            # RGB images
    out_channels=5,           # 5-channel output (5 classes)
    init_conv_filters=48,     # Number of channels outputted by the 1st convolution
    structure=(
        [4, 4, 4, 4, 4],      # Down blocks
        4,                    # Bottleneck layers
        [4, 4, 4, 4, 4],      # Up blocks
    ),
    growth_rate=12,           # Growth rate of the DenseLayers
    compression=1.0,          # No compression
    early_transition=False,   # No early transition
    include_top=True,         # Includes last layer and activation
    checkpoint=False,         # No memory checkpointing
    module_bank=module_bank,  # Modules to use
)
# Initializes all the convolutional kernel weights.
model.initialize_kernels(nn.init.kaiming_uniform_, conv=True)
# Shows some information about the model.
model.summary()
This example tiramisu network has a depth of len(down_blocks) = 5, meaning that the input images should be at least 32x32 pixels (i.e. 2^5=32).
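The relation between depth and minimum input size can be checked with a few lines of plain Python. This is only a sketch of the arithmetic stated above; `min_input_size` is a hypothetical helper and not part of the package.

```python
def min_input_size(down_blocks):
    """Smallest side length (in pixels) for a network with this many down blocks.

    Hypothetical helper: assumes one halving of the resolution per down block,
    as described in the text above.
    """
    return 2 ** len(down_blocks)


down_blocks = [4, 4, 4, 4, 4]
print(min_input_size(down_blocks))  # 32
```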
The parameters of the constructor are as follows:
- in_channels: The number of channels of the input image (e.g. 1 for grayscale, 3 for RGB).
- out_channels: The number of output channels (e.g. C for C classes).
- init_conv_filters: The number of filters in the very first convolution.
- structure: Divided into three parts (down blocks, bottleneck and up blocks), which describe the depth of the neural network (how many levels there are) and how many DenseLayers each of those levels has.
- growth_rate: The number of output channels of each convolution in the DenseLayers. At each convolution, the DenseLayer grows by this many channels.
- compression: The compression factor applied to the DenseLayers to reduce the memory footprint and computational complexity of the model (1.0 means no compression).
- early_transition: Optimization where the input is downscaled by a factor of two after the first layer by using a down-transition (without skip-connection) early on.
- include_top: If True, includes the top layer (the last convolution and activation). If False, the model returns the embeddings for each pixel.
- checkpoint: Activates memory checkpointing, a memory efficient version of the Tiramisu. See: https://arxiv.org/pdf/1707.06990.pdf
- module_bank: The bank of layers the Tiramisu uses to build itself. See next subsection for details.
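To get a feel for how `init_conv_filters`, `structure`, `growth_rate` and `compression` interact, the following sketch tracks the channel count along the down path. It assumes the usual dense-block bookkeeping (each DenseLayer adds `growth_rate` channels to its input, and compression is applied with a floor at each down-transition). The helper `down_path_channels` is hypothetical, and the package's exact rounding may differ.

```python
import math


def down_path_channels(init_conv_filters, down_blocks, growth_rate, compression=1.0):
    """Channel count at the output of each down block (hypothetical helper).

    Assumption: every DenseLayer adds `growth_rate` channels, and the
    down-transition keeps floor(channels * compression) channels.
    """
    channels = init_conv_filters
    per_level = []
    for n_layers in down_blocks:
        channels += n_layers * growth_rate
        per_level.append(channels)
        channels = math.floor(channels * compression)
    return per_level


# Values from the example above: 48 initial filters, growth rate of 12.
print(down_path_channels(48, [4, 4, 4, 4, 4], 12))                   # [96, 144, 192, 240, 288]
print(down_path_channels(48, [4, 4, 4, 4, 4], 12, compression=0.5))  # [96, 96, 96, 96, 96]
```

Under these assumptions, a compression of 0.5 keeps the width constant in this example, which illustrates why compression is an effective lever on memory.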
The Tiramisu base layers (e.g. Conv2D, activation functions, etc.) can be set to different types of layers. This was introduced to wrap many arguments of the main class under the same object and increase the flexibility to change layers.
The layers that can be redefined are:
- CONV: Convolution operations in the full model. Change with care.
- CONV_INIT: Initial (1st) convolution operation. Note: Kernel size must be provided.
- CONV_FINAL: Final convolution. Will be set to a 1x1 kernel and reduce output to C classes.
- BATCHNORM: Batch normalization in the full model.
- POOLING: Pooling operation. Note: must reduce input size by a factor of two. If the size is odd, round up to the closest integer.
- DROPOUT: Dropout. The p value must be provided through partial.
- UPSAMPLE: Upsampling operation (must be by a factor of two)
- ACTIVATION: Activation function to use everywhere
- ACTIVATION_FINAL: Act. function at the last layer (e.g. softmax, nn.Identity)
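The POOLING requirement (halve the size, rounding up when it is odd) can be illustrated with plain arithmetic. The sketch below only reproduces that rule; `sizes_per_level` is a hypothetical helper. In PyTorch, `nn.MaxPool2d(2, ceil_mode=True)` rounds this way, but whether the package's default pooling is configured like that is an assumption.

```python
import math


def sizes_per_level(size, levels):
    """Side length after each pooling step, rounding odd sizes up (hypothetical helper)."""
    sizes = [size]
    for _ in range(levels):
        size = math.ceil(size / 2)
        sizes.append(size)
    return sizes


print(sizes_per_level(75, 5))  # [75, 38, 19, 10, 5, 3]
print(sizes_per_level(32, 5))  # [32, 16, 8, 4, 2, 1]
```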
Notes:
- For pooling common options are nn.MaxPool2d, nn.AvgPool2d, or even tiramisu.layers.blurpool.BlurPool2d.
- For upsampling, there are some presets: UPSAMPLE_NEAREST (default), UPSAMPLE_PIXELSHUFFLE, UPSAMPLE_TRANSPOSE (known to produce artifacts).
- The layers can be set to nn.Identity to be bypassed (e.g. if one wants to remove the dropout layer, or the final activation).
- The partial function can prefill some of the arguments to be used in the model.
- Make sure the features you are interested in approximately fit the receptive field. For instance, if you have an object that measures 50 pixels, you need approximately 6 levels of resolution in the down/up blocks (since 2^6=64 > 50). Alternatively, use early transition, which downsamples the input by a factor of two.
- If you need to reduce the memory footprint, trying out the memory-efficient version and enabling the early transition is a great way to start. Then use compression, reduce the growth rate, and finally reduce the number of dense blocks in the down/up blocks.
- Use upsampling instead of transposed convolution, seriously. Transposed convolutions are hard to manage and may create a lot of gridding artifacts.
- Use blurpooling if you want the neural network to be shift-invariant (good accuracy even when shifting the input).
- The model creates border artifacts at the edge, which can be mitigated by changing the padding_mode argument of the Conv2d in the module bank. For instance, using "reflect" instead of "zeros" will create a smooth continuation in the boundaries instead of an edge.
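As a rough illustration of the notes on `partial`, `nn.Identity` and `padding_mode`, the sketch below uses only PyTorch and the standard library. How the package calls the factories stored in the module bank is an assumption here, so the registration is shown only as a comment, by analogy with the DROPOUT entry in the example above.

```python
from functools import partial

import torch
from torch import nn

# Prefill padding_mode; the remaining arguments are supplied when the layer is built.
reflect_conv = partial(nn.Conv2d, padding_mode="reflect")
# Assumed registration, by analogy with the DROPOUT entry shown earlier:
# module_bank[octo.ModuleType.CONV] = reflect_conv

conv = reflect_conv(in_channels=3, out_channels=8, kernel_size=3, padding=1)

# nn.Identity ignores any constructor arguments, so it can stand in for a
# layer that would normally receive some (e.g. dropout).
bypass = nn.Identity(p=0.2, inplace=True)

x = torch.randn(1, 3, 33, 33)
y = conv(x)
out_shape = tuple(y.shape)               # spatial size is preserved: (1, 8, 33, 33)
unchanged = torch.equal(bypass(y), y)    # Identity returns its input untouched
```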
- PyTorch, version >= 1.4.0 (required for the memory-efficient version)
See also the list of contributors who participated in this project. For contributing, make sure the code passes the checks of Pylama, Bandit and Mypy. Additionally, the code is formatted with Black.
This project is licensed under the MIT License - see the LICENSE.md file for details.
Many thanks to @RaphaelaHeil for her much appreciated advice on best practices.