Implementation of the Vision Transformer (ViT) in MLX. For further explanation and details on the ViT architecture, check out Yannic Kilcher's video.
import mlx.core as mx
from vit_mlx import ViT

v = ViT(
    image_size=256,
    patch_size=32,
    num_classes=1000,
    dim=1024,
    depth=6,
    heads=16,
    mlp_dim=2048,
    dropout=0.1,
    emb_dropout=0.1
)

img = mx.random.normal((1, 3, 256, 256))  # batch of one 256x256 RGB image
preds = v(img)  # (1, 1000)
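The output is a row of raw class logits, assuming the classification head is a plain linear layer as in the PyTorch original. A quick follow-up to turn them into probabilities and pick the top class:

probs = mx.softmax(preds, axis=-1)     # (1, 1000), each row sums to 1
top_class = mx.argmax(probs, axis=-1)  # index of the most likely class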
- image_size: int or tuple[int, int]. Image size. If you have rectangular images, make sure your image size is the maximum of the width and height.
- patch_size: int. Size of patches. image_size must be divisible by patch_size. The number of patches is n = (image_size // patch_size) ** 2, and n must be greater than 16 (a sanity check follows this list).
- num_classes: int. Number of classes to classify.
- dim: int. Last dimension of the output tensor after the linear transformation nn.Linear(..., dim).
- depth: int. Number of Transformer blocks.
- heads: int. Number of heads in the multi-head attention layer.
- mlp_dim: int. Dimension of the MLP (FeedForward) layer.
- channels: int, default 3. Number of image channels.
- dropout: float in [0, 1], default 0. Dropout rate.
- emb_dropout: float in [0, 1], default 0. Embedding dropout rate.
- pool: string, either cls (CLS token pooling) or mean (mean pooling).
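To make these constraints concrete, here is a short sanity check of the patch count, followed by a constructor call that exercises the optional arguments (the values are arbitrary and only for illustration):

image_size, patch_size = 256, 32
assert image_size % patch_size == 0, "image_size must be divisible by patch_size"
n = (image_size // patch_size) ** 2  # 64 patches here
assert n > 16

v = ViT(
    image_size=(256, 128),  # rectangular input as a (height, width) tuple
    patch_size=32,          # yields (256 // 32) * (128 // 32) = 32 patches
    num_classes=1000,
    dim=1024,
    depth=6,
    heads=16,
    mlp_dim=2048,
    channels=1,             # e.g. grayscale images
    pool='mean',            # mean pooling instead of the CLS token
)
img = mx.random.normal((1, 1, 256, 128))
preds = v(img)  # (1, 1000)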
In an updated version of the paper, the authors introduced a simplified ViT. They used a fixed 2D sinusoidal positional encoding instead of a learned one, switched to global average pooling, removed the dropout, increased the batch size to 4096, and used RandAugment and MixUp augmentations.
import mlx.core as mx
from vit_mlx import SimpleViT  # assuming the simplified model is exposed as SimpleViT, mirroring the PyTorch original

v = SimpleViT(
    image_size=256,
    patch_size=32,
    num_classes=1000,
    dim=1024,
    depth=6,
    heads=16,
    mlp_dim=2048,
)

img = mx.random.normal((1, 3, 256, 256))
preds = v(img)  # (1, 1000)
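For reference, here is a minimal sketch of a fixed 2D sin-cos positional encoding written with MLX ops. The name posemb_sincos_2d mirrors the PyTorch implementation; it is illustrative only, not necessarily how vit-mlx structures its code:

import mlx.core as mx

def posemb_sincos_2d(h, w, dim, temperature=10000.0):
    # (y, x) coordinate of each patch, flattened row-major to length h * w
    y = mx.repeat(mx.arange(h), w)  # 0,0,...,0,1,1,...
    x = mx.tile(mx.arange(w), h)    # 0,1,...,w-1,0,1,...
    assert dim % 4 == 0, "feature dimension must be a multiple of 4"
    # one frequency per (sin, cos) channel pair, shared by the x and y halves
    omega = mx.arange(dim // 4) / (dim // 4 - 1)
    omega = 1.0 / mx.power(temperature, omega)
    y = y[:, None] * omega[None, :]  # (h * w, dim // 4)
    x = x[:, None] * omega[None, :]
    # concatenating sin/cos of x and y gives one (dim,) vector per patch
    return mx.concatenate([mx.sin(x), mx.cos(x), mx.sin(y), mx.cos(y)], axis=1)

pe = posemb_sincos_2d(8, 8, 1024)  # (64, 1024) for a 256-pixel image with 32-pixel patches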
The original PyTorch implementation from Dr. Phil 'Lucid' Wang can be found here.