Flash Attention for Neuron #939
base: main
Conversation
Maybe wait until this PR is checked in. From what I can tell, your PR also does not fix the remat bug. #942 (review)
def _mha_forward(query, key, value, bias, causal, softmax_scale, dropout_rate):
    # Get the batch size, sequence lengths, number of heads, and hidden dimension
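A hedged sketch of what that first step might look like, assuming a `[batch, seq_len, num_heads, per_head_dim]` input layout (the layout and the dummy shapes below are assumptions for illustration, not taken from the diff):

```python
import numpy as np

# Dummy arrays standing in for the real query/key tensors (shapes are illustrative).
query = np.zeros((2, 128, 8, 64))  # [batch, q_seq_len, num_heads, per_head_dim]
key = np.zeros((2, 256, 8, 64))    # [batch, kv_seq_len, num_heads, per_head_dim]

# Unpack the batch size, sequence lengths, number of heads, and hidden dimension.
batch_size, q_seq_len, num_heads, per_head_dim = query.shape
_, kv_seq_len, _, _ = key.shape
```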
Nit: end comments with a period (here and everywhere).
Force-pushed from 8a92182 to 73a2808
key: Tensor,
value: Tensor,
bias: Tensor,
causal: bool = False,
Can we support segment IDs? Or, even better, a more general masking fn (with optimized handling).
If not, I am fine with leaving a TODO here, but it is a hard blocker for enabling it for our internal training.
Can we do segment IDs in a separate PR? That involves non-trivial work and needs some time.
Sure. In that case I may ask for more: let's do a general mask then, since we want things beyond causal.
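For reference, the segment-ID semantics being asked for can be sketched in plain NumPy (a minimal illustration of the masking rule, not the NKI implementation; the shapes and broadcast layout are assumptions):

```python
import numpy as np

def segment_id_mask(q_segment_ids, kv_segment_ids):
    """Allow attention only between tokens in the same segment.

    q_segment_ids: [batch, q_len]; kv_segment_ids: [batch, kv_len].
    Returns a boolean mask of shape [batch, 1, q_len, kv_len]
    (the singleton axis broadcasts over heads).
    """
    mask = q_segment_ids[:, :, None] == kv_segment_ids[:, None, :]
    return mask[:, None, :, :]

# Two packed sequences in one row: tokens 0-1 are segment 1, tokens 2-3 are segment 2.
seg = np.array([[1, 1, 2, 2]])
mask = segment_id_mask(seg, seg)
```

A general masking fn would subsume this: instead of segment IDs, the kernel would take any `[q_len, kv_len]`-broadcastable boolean predicate and fold it into the logits.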
Thanks for all the reviews @ruomingp @kelvin-zou. I resolved all the comments; please let me know if any more changes are needed.
Force-pushed from 73a2808 to c226d03
I rebased the PR to avoid merge conflicts; can I please get a new approval? Thank you!
@apoorvtintin I see quite a few unit tests failed; can you take a look?
Force-pushed from 42720ad to f7f06fd
Can we disable the pytype check in the file with an annotation? The CI reports:
FAILED: /home/runner/work/axlearn/axlearn/.pytype/pyi/axlearn/common/flash_attention/neuron_attention.pyi
For more details, see https://google.github.io/pytype/errors.html#import-error
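pytype supports in-file disable comments for a named error class; a sketch of how the failing import might be annotated (the `neuronxcc` module name is illustrative, and the try/except only keeps the sketch runnable where the package is absent):

```python
# pytype: disable=import-error
# The directive above tells pytype to skip import-error checks in this file;
# `import-error` is the error class named in the CI log above.

try:
    import neuronxcc  # Neuron compiler package; unresolvable in the CI sandbox.
except ImportError:
    neuronxcc = None  # fall back so the module still imports without the package
```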
Force-pushed from 925e0fe to 2d38cb5
@Ruixuan I updated the PR with
Force-pushed from cefc3c0 to b510e96
Force-pushed from b510e96 to 2c9a285
Resolved all comments and fixed CI failures. @Ruixuan @kelvin-zou @ruomingp, can we re-trigger the CI and merge this?
This PR adds support for a flash attention kernel for Neuron, implemented through the Neuron Kernel Interface (NKI).
The flash attention kernel works with TRN1 and TRN2.
This PR is a newer version of #883 from a different fork. All comments from the previous PR are addressed in this one, and it adds dropout support.
Segment ID support in the flash attention kernel is in progress and will be available at a later date.
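As a reference for what the kernel computes, here is a plain NumPy sketch of multi-head attention with the same knobs seen in the diff (bias, causal, softmax_scale); the `[batch, heads, seq, dim]` layout and the dropout-free path are assumptions for illustration, not the NKI implementation:

```python
import numpy as np

def reference_mha(query, key, value, bias=None, causal=False, softmax_scale=1.0):
    # query/key/value: [batch, heads, seq, dim] (assumed layout).
    logits = np.einsum("bhqd,bhkd->bhqk", query, key) * softmax_scale
    if bias is not None:
        logits = logits + bias
    if causal:
        q_len, k_len = logits.shape[-2:]
        # Mask out keys strictly after the query position.
        future = np.triu(np.ones((q_len, k_len), dtype=bool), k=1)
        logits = np.where(future, -1e30, logits)
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerically stable softmax
    probs = np.exp(logits)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    return np.einsum("bhqk,bhkd->bhqd", probs, value)
```

The flash attention kernel produces the same result without ever materializing the full `[q_len, kv_len]` logits matrix, which is what makes it memory-efficient on TRN1/TRN2.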