@@ -49,9 +49,9 @@ def __init__(self, config):
         self.n_embd = config.n_embd
         self.dropout = config.dropout
         # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
-        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
+        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and self.dropout == 0.0
         if not self.flash:
-            print("WARNING: using slow attention, install PyTorch nightly for fast Flash Attention")
+            print("WARNING: using slow attention. Flash Attention atm needs PyTorch nightly and dropout=0.0")
             # causal mask to ensure that attention is only applied to the left in the input sequence
             self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                         .view(1, 1, config.block_size, config.block_size))
@@ -68,7 +68,6 @@ def forward(self, x):
         # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
         if self.flash:
             # efficient attention using Flash Attention CUDA kernels
-            assert self.dropout == 0.0, "need dropout=0.0 for now, PyTorch team is working on fix in #92917"
             y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
         else:
             # manual implementation of attention
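For reference, a minimal self-contained sketch of the dispatch these two hunks implement: use the fused scaled_dot_product_attention kernel only when it exists and dropout is 0.0, otherwise fall back to the manual masked attention that follows the else: branch above. The standalone causal_attention function and its argument names are illustrative assumptions, not part of the nanoGPT module itself.

    import math
    import torch
    import torch.nn.functional as F

    def causal_attention(q, k, v, bias, dropout_p=0.0, training=False):
        # q, k, v: (B, nh, T, hs); bias: (1, 1, block_size, block_size) lower-triangular causal mask
        T, hs = q.size(-2), q.size(-1)
        # same gating as the diff: the fused kernel is only used when it is available and dropout is disabled
        flash = hasattr(F, 'scaled_dot_product_attention') and dropout_p == 0.0
        if flash:
            # fused kernels; is_causal=True applies the causal mask internally,
            # so no explicit bias buffer is needed on this path
            return F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=True)
        # manual implementation: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(hs))
        att = att.masked_fill(bias[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = F.dropout(att, p=dropout_p, training=training)
        return att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)

    # example shapes (hypothetical): batch 2, 4 heads, sequence length 8, head size 16
    q = k = v = torch.randn(2, 4, 8, 16)
    bias = torch.tril(torch.ones(8, 8)).view(1, 1, 8, 8)
    y = causal_attention(q, k, v, bias, dropout_p=0.0)  # -> (2, 4, 8, 16)

Because the availability check now also requires dropout == 0.0 in __init__, the runtime assert in forward becomes redundant, which is why the second hunk removes it.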