burn_attention
Flash Attention v3 implementation for the Burn deep learning framework.
Overview
This crate provides an efficient implementation of Flash Attention v3, an attention algorithm that reduces memory usage from quadratic to linear in the sequence length. The implementation supports the following backends:
- WGPU (default): Cross-platform GPU support via WebGPU
- CubeCL: High-performance compute kernels
- CUDA: Direct CUDA support for NVIDIA GPUs
Features
- ✅ Standard scaled dot-product attention
- ✅ Causal masking for autoregressive models
- ✅ Custom attention masks
- ✅ Configurable softmax scaling
- ✅ Multiple backend support (WGPU, CubeCL, CUDA)
- ✅ Comprehensive test suite
- ✅ Criterion benchmarks for performance testing
Installation
Add this to your Cargo.toml:
[dependencies]
burn_attention = "0.1"
Feature Flags
- wgpu (default): Enable WGPU backend
- cubecl: Enable CubeCL backend
- cuda: Enable CUDA backend
Example with CUDA support:
[dependencies]
burn_attention = { version = "0.1", features = ["cuda"] }
Usage
Basic Example
use burn::backend::NdArray;
use burn::tensor::{Distribution, Tensor};
use burn_attention::FlashAttentionV3;

type Backend = NdArray;

fn main() {
    let device = Default::default();

    // Create input tensors
    let batch_size = 2;
    let num_heads = 8;
    let seq_len = 128;
    let head_dim = 64;

    let query = Tensor::<Backend, 4>::random(
        [batch_size, num_heads, seq_len, head_dim],
        Distribution::Normal(0.0, 1.0),
        &device,
    );
    let key = Tensor::<Backend, 4>::random(
        [batch_size, num_heads, seq_len, head_dim],
        Distribution::Normal(0.0, 1.0),
        &device,
    );
    let value = Tensor::<Backend, 4>::random(
        [batch_size, num_heads, seq_len, head_dim],
        Distribution::Normal(0.0, 1.0),
        &device,
    );

    // Compute attention
    let output = FlashAttentionV3::forward(query, key, value, None, false);
    println!("Output shape: {:?}", output.dims());
}
Causal Attention
For autoregressive models, enable causal masking so that each query position attends only to itself and earlier positions:
let output = FlashAttentionV3::forward(query, key, value, None, true);
Custom Configuration
use burn_attention::FlashAttentionV3Config;

let config = FlashAttentionV3Config {
    causal: true,
    dropout_p: 0.1,
    softmax_scale: Some(0.125),
    block_size_q: 128,
    block_size_k: 128,
};

let output = FlashAttentionV3::forward_with_config(
    query,
    key,
    value,
    None,
    config,
);
Benchmarks
Run benchmarks with:
cargo bench
This will run throughput benchmarks for various sequence lengths and batch sizes.
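As a rough sketch of what such a benchmark could look like, the following Criterion harness reuses the shapes from the basic example; the file path benches/attention.rs, the benchmark name, and the use of the NdArray backend are illustrative assumptions, not the crate's actual benchmark code:

use burn::backend::NdArray;
use burn::tensor::{Distribution, Tensor};
use burn_attention::FlashAttentionV3;
use criterion::{criterion_group, criterion_main, Criterion};

type Backend = NdArray;

fn bench_forward(c: &mut Criterion) {
    let device = Default::default();
    // Same layout as the basic example: [batch, heads, seq_len, head_dim].
    let shape = [2, 8, 128, 64];
    let query = Tensor::<Backend, 4>::random(shape, Distribution::Normal(0.0, 1.0), &device);
    let key = Tensor::<Backend, 4>::random(shape, Distribution::Normal(0.0, 1.0), &device);
    let value = Tensor::<Backend, 4>::random(shape, Distribution::Normal(0.0, 1.0), &device);

    c.bench_function("flash_attention_v3_forward_seq128", |b| {
        b.iter(|| {
            // Cloning keeps the inputs available across iterations.
            FlashAttentionV3::forward(
                query.clone(),
                key.clone(),
                value.clone(),
                None,
                false,
            )
        })
    });
}

criterion_group!(benches, bench_forward);
criterion_main!(benches);

Criterion benchmarks also need a [[bench]] entry with harness = false in Cargo.toml.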
Testing
Run the test suite:
cargo test
The test suite includes:
- Unit tests for basic functionality
- Numerical correctness tests comparing against a reference implementation (a minimal sketch of such a reference follows this list)
- Property-based tests for attention output
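The reference used by the test suite is not shown here; a minimal sketch of plain O(n²) scaled dot-product attention built from Burn tensor ops, usable as a comparison baseline, could look like this (the function name reference_attention is illustrative):

use burn::tensor::activation::softmax;
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

/// Plain O(n^2) scaled dot-product attention used as a numerical reference.
fn reference_attention<B: Backend>(
    query: Tensor<B, 4>,
    key: Tensor<B, 4>,
    value: Tensor<B, 4>,
) -> Tensor<B, 4> {
    let [_, _, _, head_dim] = query.dims();
    let scale = 1.0 / (head_dim as f64).sqrt();

    // [batch, heads, seq_q, seq_k] attention scores.
    let scores = query.matmul(key.swap_dims(2, 3)).mul_scalar(scale);
    // Softmax over the key dimension, then weight the values.
    softmax(scores, 3).matmul(value)
}

Running this and FlashAttentionV3::forward on the same inputs and comparing the outputs element-wise within a small tolerance is one way to check numerical agreement.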
Implementation Details
This implementation follows the Flash Attention v3 algorithm with optimizations for:
- Memory Efficiency: Tiled computation to reduce memory usage
- Numerical Stability: Online softmax computation (sketched below)
- Performance: Kernel fusion and optimized memory access patterns
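As a backend-agnostic illustration of the online softmax idea (plain Rust over f32 slices, not the crate's actual kernel), the running maximum and normalizer can be updated block by block so the full score row never has to be materialized at once:

/// Numerically stable streaming softmax over score blocks.
fn online_softmax(blocks: &[&[f32]]) -> Vec<f32> {
    let mut m = f32::NEG_INFINITY; // running maximum
    let mut l = 0.0_f32;           // running sum of exp(score - m)
    let mut out: Vec<f32> = Vec::new();

    for block in blocks {
        let block_max = block.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let new_m = m.max(block_max);

        // Rescale everything accumulated so far to the new maximum.
        let correction = (m - new_m).exp();
        l *= correction;
        for v in out.iter_mut() {
            *v *= correction;
        }

        // Accumulate the current block.
        for &s in *block {
            let e = (s - new_m).exp();
            l += e;
            out.push(e);
        }
        m = new_m;
    }

    // Final normalization: out[i] / l equals softmax over all scores.
    out.iter().map(|&v| v / l).collect()
}

In the full Flash Attention algorithm the same rescaling factor is also applied to the partially accumulated output rows, which is what allows the attention weights to be consumed block by block without ever storing the full seq_len × seq_len score matrix.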
Tensor Shapes
- Query: [batch_size, num_heads, seq_len_q, head_dim]
- Key: [batch_size, num_heads, seq_len_k, head_dim]
- Value: [batch_size, num_heads, seq_len_k, head_dim]
- Output: [batch_size, num_heads, seq_len_q, head_dim]
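Inputs are expected already split into heads. Assuming Burn's reshape and swap_dims tensor ops, one way to go from a [batch_size, seq_len, d_model] projection to the layout above is a small helper like this (split_heads is an illustrative name, not part of this crate):

use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

/// Split a [batch, seq_len, num_heads * head_dim] tensor into the
/// [batch, num_heads, seq_len, head_dim] layout expected by the attention API.
fn split_heads<B: Backend>(x: Tensor<B, 3>, num_heads: usize) -> Tensor<B, 4> {
    let [batch, seq_len, d_model] = x.dims();
    let head_dim = d_model / num_heads;
    x.reshape([batch, seq_len, num_heads, head_dim])
        .swap_dims(1, 2)
}

Merging heads back after attention is the same swap followed by a reshape to [batch, seq_len, d_model].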
License
This project is licensed under either of:
- Apache License, Version 2.0 (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
- MIT license (LICENSE-MIT or http://opensource.org/licenses/MIT)
at your option.
Contributing
Contributions are welcome! Please feel free to submit a Pull Request.