Commit 1e87509

if dropout > 0.0 disable Flash until pytorch fix. don't assert fail sigh
1 parent d8b1a94 commit 1e87509

1 file changed: +2 −3 lines

model.py

@@ -49,9 +49,9 @@ def __init__(self, config):
         self.n_embd = config.n_embd
         self.dropout = config.dropout
         # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
-        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
+        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and self.dropout == 0.0
         if not self.flash:
-            print("WARNING: using slow attention, install PyTorch nightly for fast Flash Attention")
+            print("WARNING: using slow attention. Flash Attention atm needs PyTorch nightly and dropout=0.0")
             # causal mask to ensure that attention is only applied to the left in the input sequence
             self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                              .view(1, 1, config.block_size, config.block_size))
@@ -68,7 +68,6 @@ def forward(self, x):
         # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
         if self.flash:
             # efficient attention using Flash Attention CUDA kernels
-            assert self.dropout == 0.0, "need dropout=0.0 for now, PyTorch team is working on fix in #92917"
             y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
         else:
             # manual implementation of attention
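
The effect of the change: instead of asserting when dropout > 0.0, the model now simply skips Flash Attention and takes the manual causal-attention path. Below is a minimal, self-contained sketch of that gating pattern, not the repository's actual class; the class name CausalSelfAttentionSketch, the attend method, and the (B, nh, T, hs) tensor shapes are illustrative assumptions.

import math
import torch
import torch.nn.functional as F

class CausalSelfAttentionSketch(torch.nn.Module):
    # hypothetical minimal module illustrating the flash-vs-manual gating from this commit
    def __init__(self, block_size, dropout):
        super().__init__()
        self.dropout = dropout
        # use the fused kernel only if PyTorch provides it AND dropout is disabled (this commit's condition)
        self.flash = hasattr(F, 'scaled_dot_product_attention') and self.dropout == 0.0
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention atm needs PyTorch nightly and dropout=0.0")
            # causal mask for the manual path: lower-triangular ones, broadcastable over (B, nh, T, T)
            self.register_buffer("bias", torch.tril(torch.ones(block_size, block_size))
                                              .view(1, 1, block_size, block_size))

    def attend(self, q, k, v):
        # q, k, v assumed to be shaped (B, nh, T, hs)
        B, nh, T, hs = q.shape
        if self.flash:
            # fused kernel; dropout_p is 0.0 here by construction
            return F.scaled_dot_product_attention(q, k, v, attn_mask=None,
                                                  dropout_p=self.dropout, is_causal=True)
        # manual implementation: masked softmax over scaled dot products, with attention dropout
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(hs))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = F.dropout(att, p=self.dropout, training=self.training)
        return att @ v

# usage with made-up sizes: dropout > 0.0, so the sketch takes the manual path
attn = CausalSelfAttentionSketch(block_size=8, dropout=0.1)
q = k = v = torch.randn(2, 4, 8, 16)   # (B=2, nh=4, T=8, hs=16)
y = attn.attend(q, k, v)               # -> (2, 4, 8, 16)

Deciding at construction time via self.dropout == 0.0, rather than asserting inside forward, keeps training runs with dropout > 0.0 working; the trade-off is that they quietly use the slower path until the upstream PyTorch fix lands.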
