Pallas/Triton segfault on H100 #17356
Comments
For completeness, here is the crashing HLO:
Seems like a Triton crash, but could you provide the segfault output with ASAN?
I have not run with ASAN, but the debug build fails on an OOB access:
I think I narrowed it down to this PR: triton-lang/triton#4492. After this PR, the new TTGIR for your test looks like this:
Running this through triton-opt -convert-triton-gpu-to-llvm repros the crash. The interesting part I think is here:
The PR tries to chain dots more efficiently: instead of doing an extra local_alloc and putting the result back into shared memory, it keeps the result in registers for the next MMA. However, as far as I know, Hopper wgmma with the LHS in registers is not yet supported in Triton, so this can't work. I'm not sure exactly what the correct solution is yet; maybe it needs a check to skip this in specific situations. But I do know that @ggengnv is working on exactly this (openxla/triton#18), so perhaps they can spot the issue quicker than I can?
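For readers unfamiliar with the term, "chained dots" here means that the result of one matmul is used as the left-hand operand of the next. The snippet below is only a minimal sketch of that data flow in plain Python, not the original repro and not Triton code; the function names `matmul` and `chained_dot` are hypothetical and exist only for illustration.

```python
# Illustration only: plain Python lists stand in for GPU tiles.
# `matmul` and `chained_dot` are hypothetical names, not Triton or JAX APIs.

def matmul(lhs, rhs):
    """Multiply two row-major matrices given as lists of lists."""
    inner = len(rhs)
    assert all(len(row) == inner for row in lhs), "inner dimensions must match"
    cols = len(rhs[0])
    return [
        [sum(row[k] * rhs[k][j] for k in range(inner)) for j in range(cols)]
        for row in lhs
    ]


def chained_dot(a, b, c):
    """Compute (a @ b) @ c; the first product is the LHS of the second dot."""
    acc = matmul(a, b)     # first dot: produces the accumulator
    return matmul(acc, c)  # second dot: consumes that accumulator as its LHS
```

On the GPU, the question discussed in this thread is where `acc` lives between the two dots: written back to shared memory, or kept in registers.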
@vwbaker I think the name of my PR is a bit misleading :) - WGMMA with LHS in registers does already exist in Triton, specifically for chained MMAs. This means keeping MMA1's accumulator in registers, possibly doing some casting (and maybe shuffling? I haven't looked at the code closely), and then using these registers as the LHS for MMA2. My PR was for loading from shmem into registers for MMA. The TTGIR level optimization for chained MMAs is done in The LLVM lowering for this is in OTOH, the PR that you linked pertains to "MMA to MMA" layout conversion ("MMA layout" being accumulator layout). It might have modified the LLVM lowering for this following TTGIR:
which was later passed to this MMA -> dot conversion
and then to WGMMA, where the crash happened. I'm not familiar with that part of the codebase, but I think it's likely a bug was added in the PR in
This looks like the same root cause as triton-lang/triton#4502. The repro there is very similar and the same workaround (changing from
triton-lang#4492 started causing an issue where chained MMAs on hopper would segfault with 8 warps. It seems that previously this was checked, but the check got removed in this PR and it's still unsupported. Adding back this check means these MMAs will have to go back to shared memory, but it's better than segfaulting until it's actually supported. Resolves openxla/xla#17356
This should be resolved by triton-lang/triton#4803 and will be merged into openxla/xla in next week's integrate.
I can confirm this is fixed now.
After the commit cb304cf, JAX crashes in Triton on H100 with the following repro:
The stack trace:
Here is the JAX version: