# Log-linear Sparse Attention (LLSA)

Yifan Zhou1, Zeqi Xiao1, Tianyi Wei1, Shuai Yang2, Xingang Pan1

1 S-Lab, Nanyang Technological University
2 Wangxuan Institute of Computer Technology, Peking University
Paper: https://arxiv.org/abs/2512.16615
Official PyTorch implementation of Log-linear Sparse Attention (LLSA).

## News

- [12/2025]: Code released.
## Installation

- Clone the repository and install the package:

  ```shell
  git clone git@github.com:SingleZombie/LLSA.git
  cd LLSA
  pip install -e .
  ```

- Install PyTorch in your Python environment.

- Install the required Python libraries:

  ```shell
  pip install -r requirements.txt
  ```
## Usage

Replace the standard scaled dot-product attention with LLSA.

For sequence length < 16384 (128x128):

```diff
  from llsa.kernel.torch_op.flash_sparse_attention_res_1 import llsa_l1

- attn_output = F.scaled_dot_product_attention(
-     query, key, value, dropout_p=0.0, is_causal=False)
+ attn_output = llsa_l1(query, key, value, block_size=16)
```

For sequence length >= 16384 (128x128):

```diff
  from llsa.kernel.torch_op.flash_sparse_attention_res_2 import llsa_l2

- attn_output = F.scaled_dot_product_attention(
-     query, key, value, dropout_p=0.0, is_causal=False)
+ attn_output = llsa_l2(query, key, value, block_size=16)
```

Note:

- The current implementation supports only non-causal attention.
- The token length and `topk` must be powers of 2.
- Future updates will address these limitations.
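As a sketch of how the swap might be wired up (a hypothetical helper, not part of this repo), the wrapper below dispatches to `llsa_l1` or `llsa_l2` by sequence length and falls back to standard SDPA when LLSA is not installed, so it stays runnable anywhere:

```python
import torch
import torch.nn.functional as F


def attention(query, key, value, block_size=16):
    """Dispatch to LLSA by sequence length; fall back to SDPA (hypothetical)."""
    seq_len = query.shape[-2]
    try:
        if seq_len < 16384:
            from llsa.kernel.torch_op.flash_sparse_attention_res_1 import llsa_l1
            return llsa_l1(query, key, value, block_size=block_size)
        from llsa.kernel.torch_op.flash_sparse_attention_res_2 import llsa_l2
        return llsa_l2(query, key, value, block_size=block_size)
    except ImportError:
        # LLSA unavailable: use the standard non-causal attention
        return F.scaled_dot_product_attention(
            query, key, value, dropout_p=0.0, is_causal=False)


q = torch.randn(1, 4, 256, 64)  # (batch, heads, tokens, head_dim); 256 is a power of 2
out = attention(q, q, q)
```

Remember the constraint above: the token length (256 here) must be a power of 2.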
For non-sequential data (e.g., images or videos), we recommend reordering the data so that similar tokens have adjacent indices. Please check the example in `src/llsa/models/rope_dit_transformer_2d.py`:
```python
import torch
from einops import rearrange


def gen_permuatations(log_num_tokens: int):
    # num_tokens = 4 ** (1 + log_num_tokens)
    perm = torch.tensor([[0, 1], [2, 3]])
    base_num = 4
    for i in range(log_num_tokens):
        length = perm.shape[-1]
        # Replicate the current grid into a 2x2 layout of four offset copies
        perm = perm[None, :, :].expand(
            4, -1, -1) + torch.arange(0, 4)[:, None, None] * base_num
        perm = perm.reshape(2, 2, length, length)
        perm = rearrange(perm, 'a b c d -> a c b d')
        perm = perm.reshape(length * 2, length * 2)
        base_num *= 4
    perm = perm.flatten()
    # Scatter indices to obtain the inverse permutation
    inv_perm = torch.empty_like(perm)
    inv_perm[perm] = torch.arange(len(perm))
    return perm, inv_perm
```
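To see what this permutation looks like on a small grid, here is a standalone, torch-only re-implementation of the same logic (the `einops.rearrange` call replaced by an equivalent `permute`) applied to a 4x4 token grid:

```python
import torch


def gen_permutations(log_num_tokens: int):
    # Same construction as gen_permuatations above, without einops
    perm = torch.tensor([[0, 1], [2, 3]])
    base_num = 4
    for _ in range(log_num_tokens):
        length = perm.shape[-1]
        perm = perm[None, :, :].expand(
            4, -1, -1) + torch.arange(0, 4)[:, None, None] * base_num
        perm = perm.reshape(2, 2, length, length)
        perm = perm.permute(0, 2, 1, 3)  # 'a b c d -> a c b d'
        perm = perm.reshape(length * 2, length * 2)
        base_num *= 4
    perm = perm.flatten()
    inv_perm = torch.empty_like(perm)
    inv_perm[perm] = torch.arange(len(perm))
    return perm, inv_perm


perm, inv_perm = gen_permutations(1)  # 4x4 grid, 16 tokens
print(perm.tolist())
```

Each 2x2 spatial tile ends up with contiguous token indices (the flattened order starts [0, 1, 4, 5, 2, 3, 6, 7, ...]), i.e. a Z-order (Morton) traversal, so spatially similar tokens land next to each other in the sequence.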
```python
class DiT():
    def __init__(self):
        ...
        self.fwd_perms = {}
        self.bwd_perms = {}
        for log2_scale in range(2, 10):
            img_size = 2 ** log2_scale
            log_scale_m1 = log2_scale - 1
            perm, inv_perm = gen_permuatations(int(log_scale_m1))
            self.fwd_perms[img_size] = inv_perm
            self.bwd_perms[img_size] = perm

    def forward(self, hidden_states):
        ...
        batch_size, _, height, width = hidden_states.shape
        patch_height = height // self.patch_size
        # Reorder tokens so that spatially close tokens get adjacent indices
        fwd_perm = self.fwd_perms[patch_height].to(hidden_states.device)
        hidden_states = hidden_states[:, fwd_perm, :]
        # attention blocks
        ...
        # Restore the original token order
        bwd_perm = self.bwd_perms[patch_height]
        hidden_states = hidden_states[:, bwd_perm, :]
        ...
```

In the paper, we validate LLSA on pure pixel DiT (no VAE, no patchification) generation at resolutions up to 128x128.
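Since `fwd_perms` stores the inverse permutation and `bwd_perms` the forward one, applying both is an exact round trip. A minimal standalone check (with an arbitrary permutation standing in for the Z-order one):

```python
import torch

batch, num_tokens, dim = 2, 16, 8
hidden_states = torch.randn(batch, num_tokens, dim)

# Any permutation satisfies the round-trip property; a random one
# stands in for the Z-order permutation here.
perm = torch.randperm(num_tokens)
inv_perm = torch.empty_like(perm)
inv_perm[perm] = torch.arange(num_tokens)

reordered = hidden_states[:, inv_perm, :]  # as with fwd_perms above
restored = reordered[:, perm, :]           # as with bwd_perms above
assert torch.equal(restored, hidden_states)
```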
## Sampling

```shell
python gen_fid_log.py configs/training/train_rope_dit_S_32_rms.json ffhq_32.pth
python gen_fid_log.py configs/training/train_rope_dit_S_128_rms_ft_llsa_l2.json ffhq_128.pth
```

## Training

```shell
python train.py configs/training/train_rope_dit_S_32_rms.json
# or
# accelerate launch train.py configs/training/train_rope_dit_S_32_rms.json

python train.py configs/training/train_rope_dit_S_128_rms_ft_llsa_l2.json
# We provide a DiT with full attention for comparison
# python train.py configs/training/train_rope_dit_S_128_rms_ft.json
```

## Evaluation

```shell
python test_fid.py configs/training/train_rope_dit_S_128_rms_ft_llsa_l2.json ffhq_128.pth \
    --n_sample_data_batch 200 \
    --valid_batch_size 50
```

## Acknowledgements

- Diffusers: Our project is built on diffusers.
