Flash Attention for Neuron #939

Open

apoorvtintin wants to merge 1 commit into main from mainline_upstream_fa

Conversation

apoorvtintin (Contributor) commented on Jan 21, 2025

This PR adds support for a flash attention kernel for Neuron, implemented through the Neuron Kernel Interface (NKI).

The flash attention kernel works on both TRN1 and TRN2.

This PR is a newer version of #883 from a different fork; all comments from the previous PR are addressed in this one, and it has dropout support.

Dropout and segment ID support in the flash attention kernel is in progress and will be available at a later date.
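For orientation, the sketch below shows one common way an NKI kernel is surfaced to JAX: a jax.custom_vjp wrapper whose forward and backward rules dispatch to the NKI kernels. _mha_forward is quoted later in this review; _mha_backward, the argument order, and the (output, residuals) return convention are illustrative assumptions, not the PR's exact code.

    # Rough sketch only; assumes _mha_forward/_mha_backward wrap the NKI forward
    # and backward flash-attention kernels and that _mha_forward returns
    # (output, residuals).
    from functools import partial
    import jax

    @partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6))
    def neuron_flash_attention(query, key, value, bias, causal, softmax_scale, dropout_rate):
        out, _ = _mha_forward(query, key, value, bias, causal, softmax_scale, dropout_rate)
        return out

    def _fwd(query, key, value, bias, causal, softmax_scale, dropout_rate):
        # Forward rule: run the NKI kernel and keep residuals for the backward pass.
        return _mha_forward(query, key, value, bias, causal, softmax_scale, dropout_rate)

    def _bwd(causal, softmax_scale, dropout_rate, residuals, grad_output):
        # Backward rule: return gradients for (query, key, value, bias).
        return _mha_backward(causal, softmax_scale, dropout_rate, residuals, grad_output)

    neuron_flash_attention.defvjp(_fwd, _bwd)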

kelvin-zou (Contributor) left a comment:

Maybe wait until this PR is checked in. From what I can tell, the remat bug is also not fixed in your PR. See #942 (review).

Resolved review threads on axlearn/common/flash_attention/neuron_attention.py (outdated) and axlearn/common/flash_attention/utils.py.


def _mha_forward(query, key, value, bias, causal, softmax_scale, dropout_rate):
# Get the batch size, sequence lengths, number of heads, and hidden dimension
A reviewer (Contributor) commented:

Nit: end comments with a period (here and everywhere).

Resolved review thread on axlearn/common/flash_attention/neuron_attention.py (outdated).
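As an aside on the code comment quoted above, here is a minimal sketch of the shape unpacking it describes, assuming a (batch, seq_len, num_heads, head_dim) layout; the PR's actual layout may differ.

    # Sketch only: get the batch size, sequence lengths, number of heads, and
    # head dimension, assuming (batch, seq_len, num_heads, head_dim) inputs.
    batch_size, q_seq_len, num_heads, head_dim = query.shape
    _, kv_seq_len, _, _ = key.shape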
apoorvtintin force-pushed the mainline_upstream_fa branch 4 times, most recently from 8a92182 to 73a2808 on January 23, 2025 02:32.
key: Tensor,
value: Tensor,
bias: Tensor,
causal: bool = False,
Contributor:

Can we support segment IDs? A more general masking fn (with optimized handling) would be even better.

Contributor:

If not, I am fine with leaving a TODO here, but it is a hard blocker for enabling it for our internal training.

Reply:

Can we do segment IDs in a separate PR? That involves non-trivial work and needs some time.

Contributor:

Sure; in that case I may ask for more: let's do a general mask then, since we want things beyond causal.
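For reference on what is being requested in this thread, here is a minimal dense sketch of segment-ID masking at the logits level. The function name and shapes are illustrative; this is a reference construction, not the optimized in-kernel handling the reviewers are asking for.

    import jax.numpy as jnp

    def segment_id_bias(q_segment_ids, kv_segment_ids, causal=False):
        # q_segment_ids: [batch, q_len]; kv_segment_ids: [batch, kv_len].
        # A query position may only attend to key positions in the same segment.
        allowed = q_segment_ids[:, :, None] == kv_segment_ids[:, None, :]
        if causal:
            q_len, kv_len = q_segment_ids.shape[1], kv_segment_ids.shape[1]
            allowed = allowed & jnp.tril(jnp.ones((q_len, kv_len), dtype=bool))[None]
        # Broadcast over heads and return an additive bias: 0 where attention is
        # allowed, a large negative value where it is masked out.
        return jnp.where(allowed[:, None, :, :], 0.0, -1e9)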

apoorvtintin (Contributor, Author):

Thanks for all the reviews @ruomingp @kelvin-zou. I resolved all the comments; please let me know if any more changes are needed.

apoorvtintin (Contributor, Author):

I rebased the PR to avoid merge conflicts; could I please get a new approval? Thank you!

kelvin-zou (Contributor) left a comment:

@apoorvtintin I see quite a few unit tests failing; can you take a look?

apoorvtintin force-pushed the mainline_upstream_fa branch 2 times, most recently from 42720ad to f7f06fd on January 30, 2025 16:16.
Ruixuan commented on Feb 5, 2025:

Can we disable the pytype check in this file with an annotation like # pytype: disable=import-error? The CI currently fails with:

FAILED: /home/runner/work/axlearn/axlearn/.pytype/pyi/axlearn/common/flash_attention/neuron_attention.pyi
/opt/hostedtoolcache/Python/3.10.16/x64/bin/python -m pytype.single --imports_info /home/runner/work/axlearn/axlearn/.pytype/imports/axlearn.common.flash_attention.neuron_attention.imports --module-name axlearn.common.flash_attention.neuron_attention --platform linux -V 3.10 -o /home/runner/work/axlearn/axlearn/.pytype/pyi/axlearn/common/flash_attention/neuron_attention.pyi --analyze-annotated --nofail --quick /home/runner/work/axlearn/axlearn/axlearn/common/flash_attention/neuron_attention.py
File "/home/runner/work/axlearn/axlearn/axlearn/common/flash_attention/neuron_attention.py", line 9, in : Can't find module 'jax_neuronx'. [import-error]
File "/home/runner/work/axlearn/axlearn/axlearn/common/flash_attention/neuron_attention.py", line 10, in : Can't find module 'neuronxcc'. [import-error]
File "/home/runner/work/axlearn/axlearn/axlearn/common/flash_attention/neuron_attention.py", line 12, in : Can't find module 'neuronxcc.nki.kernels.attention'. [import-error]

For more details, see https://google.github.io/pytype/errors.html#import-error
[10/540] check axlearn.cloud.common.types
[11/540] infer axlearn.cloud.common.config
[12/540] check axlearn.common.utils
ninja: build stopped: subcommand failed.
Computing dependencies
Analyzing 535 sources with 0 local dependencies
Leaving directory '.pytype'
Error: Process completed with exit code 1.

apoorvtintin force-pushed the mainline_upstream_fa branch 3 times, most recently from 925e0fe to 2d38cb5 on February 6, 2025 04:54.
apoorvtintin (Contributor, Author):

@Ruixuan I updated the PR with # pytype: disable=import-error and verified that pytype passes. Could I get a new approval and a CI re-trigger? Thank you!
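For reference, the suppression typically looks like the following near the top of neuron_attention.py. The module paths come from the pytype log above; the specific names imported are illustrative assumptions.

    # Neuron-only dependencies are not installed in the CI environment, so
    # pytype's import check is suppressed for them.
    # pytype: disable=import-error
    import jax_neuronx
    import neuronxcc.nki.language as nl
    from neuronxcc.nki.kernels.attention import flash_fwd
    # pytype: enable=import-error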

apoorvtintin force-pushed the mainline_upstream_fa branch 2 times, most recently from cefc3c0 to b510e96 on February 7, 2025 02:31.
apoorvtintin (Contributor, Author):

Resolved all comments and fixed the CI failures. @Ruixuan @kelvin-zou @ruomingp, can we re-trigger the CI and merge this?
