Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove extra wgmma_wait_group in flash attentinon #2399

Merged
merged 2 commits into from
Oct 26, 2023

Conversation

donproc
Copy link
Contributor

@donproc donproc commented Sep 27, 2023

Performance on H100 SXM5 HBM3
SM: 1830 MHZ

-[BATCH]-[N_HEADS]-[D_HEAD]-[mode]-[casual]-[seq_par]

fused-attention-batch4-head48-d64-fwd-True-True:
N_CTX Before After
0 1024.0 0.124135 0.120228
1 2048.0 0.391899 0.378483
2 4096.0 1.381547 1.340469
3 8192.0 5.182845 5.013851
fused-attention-batch4-head48-d64-fwd-True-False:
N_CTX Triton After
0 1024.0 0.124390 0.120371
1 2048.0 0.392177 0.379371
2 4096.0 1.380980 1.337169
3 8192.0 5.191552 5.013165
fused-attention-batch4-head48-d64-fwd-False-True:
N_CTX Triton After
0 1024.0 0.159597 0.153621
1 2048.0 0.563483 0.540353
2 4096.0 2.141308 2.053801
3 8192.0 8.244150 7.886597
fused-attention-batch4-head48-d64-fwd-False-False:
N_CTX Triton After
0 1024.0 0.160086 0.153992
1 2048.0 0.563423 0.541524
2 4096.0 2.146597 2.057333
3 8192.0 8.229922 7.886234

@donproc donproc requested a review from ptillet as a code owner September 27, 2023 00:45
@donproc donproc changed the title remove extra wgmma_wait_group in flash attentinon [WIP]remove extra wgmma_wait_group in flash attentinon Sep 27, 2023
auto CArg = dotOp.getOperand(2).dyn_cast<BlockArgument>();
if (!CArg || !CArg.hasOneUse())
auto CArg = dotOp.getOperand(2);
auto depend = dependOnTensorBlockArgument(dotOp.getOperand(2));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we need check more conditions here.
For example

for (%0, %1, %2) {
    %3 = add %1, %2
    %4 = dot %0, %1, %3
    .... some other ops
    scf.yield %x, %y, %4
}

It is a undefined behavior if we async launch the DotOp in the loop. Because at the time point we issue the AddOp, its operand %2 will be not ready yet.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually it only accept conditions

  1. DotOp does not depends on operations with async semantics.
  2. DotOp depends on operations with async semantics, and with proper sync/fence operations between them.
  3. DotOp depends on itself, as a corner case of 2.


Operation *dotWait = nullptr;
if (!hasSyncDot && !hasWaitGroup) {
builder.setInsertionPointAfter(dot.getDefiningOp());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general, DotWaitOp is coupled with DotAsyncOp. If hasSyncDot=true, it will not insert DotWaitOp, but it will replace DotOp with DotAsyncOp

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And so as the outstanding DotWaitOp.

@ptillet ptillet marked this pull request as draft September 28, 2023 17:32
@ptillet
Copy link
Collaborator

ptillet commented Sep 28, 2023

marking this as draft until no longer WIP

@github-actions
Copy link

github-actions bot commented Oct 5, 2023

⚠️ This PR does not produce bitwise identical kernels as the branch it's merged against. Please check artifacts for details. Download the output file here.

@donproc donproc force-pushed the pipeline_remove branch 4 times, most recently from 32ccff2 to 6d0895c Compare October 17, 2023 00:37
@donproc donproc changed the title [WIP]remove extra wgmma_wait_group in flash attentinon remove extra wgmma_wait_group in flash attentinon Oct 17, 2023
@donproc donproc marked this pull request as ready for review October 17, 2023 01:12
@@ -1662,6 +1742,11 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
// TODO: merge this with the rest of the pipelining transformation and look at
// a better representation for async dots.
tt::DotOp lastDot = dots.back();
for (auto dotOp : allDots) {
if (dotOp != lastDot)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you want to check if there is a dotOp in allDots but not in dots?
Maybe find(dots.begin(), dots.end(), dotOp) != dots.end() will be more reasonable.

if (arg == result) {
break;
} else {
continue;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we don't need else {continue;} here.

auto op = v.getDefiningOp();
// root and not BlockArgument

auto iterArgs = forOp.getInitArgs();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a little misunderstanding. Do you have a concrete example to explain why we check forOp.getInitArgs() here but not forOp.getRegionIterArgs()?

@donproc donproc force-pushed the pipeline_remove branch 6 times, most recently from ea3528e to 7bd7185 Compare October 23, 2023 16:33
@github-actions
Copy link

⚠️ This PR does not produce bitwise identical kernels as the branch it's merged against. Please check artifacts for details. Download the output file here.

@github-actions
Copy link

⚠️ This PR does not produce bitwise identical kernels as the branch it's merged against. Please check artifacts for details. Download the output file here.

@ptillet
Copy link
Collaborator

ptillet commented Oct 26, 2023

This is great!

@ptillet ptillet enabled auto-merge (squash) October 26, 2023 16:30
@ptillet ptillet merged commit 0469d5f into triton-lang:main Oct 26, 2023
4 checks passed
@github-actions
Copy link

⚠️ This PR does not produce bitwise identical kernels as the branch it's merged against. Please check artifacts for details. Download the output file here.

Value loopNotEmpty = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, forOp.getLowerBound(),
forOp.getUpperBound());
auto ifOp = builder.create<scf::IfOp>(loc, resultTypes, loopNotEmpty,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any reason for adding back this workaround? Does it break something to not have this if?

pingzhuu pushed a commit to siliconflow/triton that referenced this pull request Apr 2, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants