This MATLAB script defines a custom attention layer class, `attentionLayer`, for use in deep learning models, particularly sequence-to-sequence tasks and transformer-based architectures.
Key features:
- Implements a multi-head attention mechanism
- Supports various input formats
- Optional causal masking
- Compatible with MATLAB's Deep Learning Toolbox
The `attentionLayer` class is defined as a subclass of `nnet.layer.Layer` and `nnet.layer.Formattable`.
The layer has the following properties:
- `Nhead`: Number of attention heads
- `InFormat`: Input format specification
- `UseMask`: Flag for using causal masking
Learnable parameters:
- `Wq`: Query weight matrix
- `Wk`: Key weight matrix
- `Wv`: Value weight matrix
- `Wo`: Output weight matrix
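For reference, the standard multi-head attention computation that these four weight matrices point to can be written as follows. This is a hedged sketch, not the script's confirmed math. It assumes that each input sequence is a matrix X with one column per time step, that there are h heads (`Nhead`) with per-head query and value sizes d_q and d_v, and that scores are scaled by 1/sqrt(d_q). Q_i, K_i, and V_i denote the rows of Q, K, and V that belong to head i. The mask M is applied only when `UseMask` is true; otherwise M = 0.

```latex
\begin{aligned}
Q &= W_q X, \qquad K = W_k X, \qquad V = W_v X \\
S_i &= \frac{Q_i^{\top} K_i}{\sqrt{d_q}} + M, \qquad
M_{t t'} = \begin{cases} 0 & t' \le t \\ -\infty & t' > t \end{cases} \\
P_i &= \operatorname{softmax}_{\text{row}}(S_i), \qquad H_i = V_i P_i^{\top} \\
Y &= W_o \begin{bmatrix} H_1 \\ \vdots \\ H_h \end{bmatrix}
\end{aligned}
```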
Constructor arguments:
- `InputDim`: Dimension of the input
- `QueryDim`: Dimension of the query
- `ValueDim`: Dimension of the value
- `OutputDim`: Dimension of the output
- `NumberOfHead`: Number of attention heads
- `InputFormat`: Format of the input tensor (e.g., "CBT", "CTB", "BTC")
- `UseMask`: Boolean flag for causal masking (default: false)
- `Name`: Name of the layer (default: "attentionLayer")
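The following sketch illustrates the reshaping step. It permutes a plain numeric array from an arbitrary three-letter layout into "CBT" order using `ismember` and `permute`. The layout string and array sizes are hypothetical, and the sketch does not necessarily reflect how the layer handles `InputFormat` internally. Formatted `dlarray` objects are already stored in S, C, B, T, U dimension order, so a manual permutation like this matters mainly for unformatted data.

```matlab
% Hypothetical sketch: reorder a plain array stored in a three-letter
% layout (here "BTC") into channel-batch-time ("CBT") order.
fmt = 'BTC';                       % assumed input layout, like InputFormat
X = rand(4, 10, 512);              % B = 4 sequences, T = 10 steps, C = 512 channels

% perm(k) is the position in fmt of the k-th letter of 'CBT'
[~, perm] = ismember('CBT', fmt);  % for 'BTC' this gives [3 1 2]
Xcbt = permute(X, perm);           % size is now 512 x 4 x 10

disp(size(Xcbt))
```

The same two lines handle "CTB" or any other ordering of the letters C, B, and T, because `ismember` looks up each target dimension in whatever layout string it is given.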
The `predict` method performs the forward pass of the attention layer:
- Reshapes the input tensor based on the specified input format
- Computes the query, key, and value matrices
- Applies multi-head attention
- Projects the attention result with the output weight matrix `Wo` to produce the layer output
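The attention step can be sketched on plain arrays for a single sequence. This is a minimal illustration, not the layer's `predict` code. It makes the following assumptions:

- The input is channels by time.
- The weight shapes give each head a contiguous block of rows.
- Scores are scaled by 1/sqrt(queryDim).
- The causal mask sets scores for future time steps to -Inf before the softmax.

All sizes and the random weights are hypothetical.

```matlab
% Minimal multi-head scaled dot-product attention on one C x T sequence.
% Sizes, weight shapes, and random weights are illustrative assumptions.
rng(0);
inputDim = 16; queryDim = 4; valueDim = 4; outputDim = 16;
numHeads = 4; T = 6; useMask = true;

X  = randn(inputDim, T);                     % channels x time
Wq = randn(numHeads*queryDim, inputDim);     % hypothetical weight shapes
Wk = randn(numHeads*queryDim, inputDim);
Wv = randn(numHeads*valueDim, inputDim);
Wo = randn(outputDim, numHeads*valueDim);

% Row-wise softmax, stabilized by subtracting each row's maximum
softmaxRows = @(S) exp(S - max(S, [], 2)) ./ sum(exp(S - max(S, [], 2)), 2);

Q = Wq*X;  K = Wk*X;  V = Wv*X;              % query, key, value projections
H = zeros(numHeads*valueDim, T);             % stacked head outputs
P = cell(1, numHeads);                       % attention weights per head
for i = 1:numHeads
    qi = (i-1)*queryDim + (1:queryDim);      % rows of Q and K for head i
    vi = (i-1)*valueDim + (1:valueDim);      % rows of V and H for head i
    S = (Q(qi, :)' * K(qi, :)) / sqrt(queryDim);  % T x T, row = query step
    if useMask
        S(triu(true(T), 1)) = -Inf;          % block attention to later steps
    end
    P{i} = softmaxRows(S);                   % each row sums to 1
    H(vi, :) = V(vi, :) * P{i}';             % weighted sum of values
end
Y = Wo*H;                                    % output projection, outputDim x T
disp(size(Y))
```

With `useMask = true`, each `P{i}` is zero above the diagonal. As a result, the output at step t mixes only values from steps 1 through t.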
Requirements:
- MATLAB
- Deep Learning Toolbox
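The layer accepts several input tensor layouts (set through `InputFormat`) and automatically reshapes the input accordingly. Causal masking (`UseMask = true`) stops each time step from attending to later steps, which autoregressive models require. The implementation uses MATLAB's `dlarray`, so computations run on the GPU when the input data is a `gpuArray`.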
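The sketch below prepares input in the form the layer works with: numeric data wrapped in a formatted `dlarray` and placed on the GPU when `canUseGPU` reports that one is available. The sizes are illustrative and were chosen to match an `InputDim` of 512 and the "CBT" format. Running on the GPU also requires Parallel Computing Toolbox and a supported device.

```matlab
% Sketch: build a formatted "CBT" input, on the GPU if one is usable.
% Sizes are illustrative: C = 512 channels, B = 4 sequences, T = 10 steps.
data = rand(512, 4, 10, 'single');
if canUseGPU
    data = gpuArray(data);          % same values, stored on the GPU
end
X = dlarray(data, 'CBT');           % formatted deep learning array
disp(dims(X))
```

Given a `dlnetwork` such as `net` in the example below, `Y = predict(net, X);` would then run the forward pass on whichever device holds `X`.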
Example
To create an instance of `attentionLayer`:
```matlab
layer = attentionLayer(InputDim, QueryDim, ValueDim, OutputDim, NumberOfHead, InputFormat, UseMask, Name)
```
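```matlab
% Create an attention layer: InputDim = 512, QueryDim = 64, ValueDim = 64,
% OutputDim = 512, 8 heads, "CBT" input format, causal masking enabled
attLayer = attentionLayer(512, 64, 64, 512, 8, "CBT", true, "MyAttentionLayer");

% Use the layer in a network (add other layers as needed)
layers = [
    sequenceInputLayer(512)   % sequence input with 512 channels ("CBT" data)
    attLayer
    % ... subsequent layers
];

% Create and train the network
net = dlnetwork(layers);
% ... (training code)
```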
For more information on using custom layers in MATLAB, refer to the MATLAB Deep Learning Toolbox documentation.