-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
remove extra wgmma_wait_group in flash attentinon #2399
Conversation
auto CArg = dotOp.getOperand(2).dyn_cast<BlockArgument>(); | ||
if (!CArg || !CArg.hasOneUse()) | ||
auto CArg = dotOp.getOperand(2); | ||
auto depend = dependOnTensorBlockArgument(dotOp.getOperand(2)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we need check more conditions here.
For example
for (%0, %1, %2) {
%3 = add %1, %2
%4 = dot %0, %1, %3
.... some other ops
scf.yield %x, %y, %4
}
It is a undefined behavior if we async launch the DotOp in the loop. Because at the time point we issue the AddOp, its operand %2 will be not ready yet.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually it only accept conditions
- DotOp does not depends on operations with async semantics.
- DotOp depends on operations with async semantics, and with proper sync/fence operations between them.
- DotOp depends on itself, as a corner case of 2.
|
||
Operation *dotWait = nullptr; | ||
if (!hasSyncDot && !hasWaitGroup) { | ||
builder.setInsertionPointAfter(dot.getDefiningOp()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general, DotWaitOp is coupled with DotAsyncOp. If hasSyncDot=true
, it will not insert DotWaitOp
, but it will replace DotOp
with DotAsyncOp
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And so as the outstanding DotWaitOp
.
marking this as draft until no longer WIP |
|
32ccff2
to
6d0895c
Compare
6d0895c
to
1892bf4
Compare
@@ -1662,6 +1742,11 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) { | |||
// TODO: merge this with the rest of the pipelining transformation and look at | |||
// a better representation for async dots. | |||
tt::DotOp lastDot = dots.back(); | |||
for (auto dotOp : allDots) { | |||
if (dotOp != lastDot) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do you want to check if there is a dotOp in allDots
but not in dots
?
Maybe find(dots.begin(), dots.end(), dotOp) != dots.end()
will be more reasonable.
if (arg == result) { | ||
break; | ||
} else { | ||
continue; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: we don't need else {continue;}
here.
auto op = v.getDefiningOp(); | ||
// root and not BlockArgument | ||
|
||
auto iterArgs = forOp.getInitArgs(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have a little misunderstanding. Do you have a concrete example to explain why we check forOp.getInitArgs()
here but not forOp.getRegionIterArgs()
?
ea3528e
to
7bd7185
Compare
|
7bd7185
to
fd38935
Compare
|
fd38935
to
70838c9
Compare
70838c9
to
5df7a6f
Compare
This is great! |
|
Value loopNotEmpty = builder.create<arith::CmpIOp>( | ||
loc, arith::CmpIPredicate::slt, forOp.getLowerBound(), | ||
forOp.getUpperBound()); | ||
auto ifOp = builder.create<scf::IfOp>(loc, resultTypes, loopNotEmpty, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there any reason for adding back this workaround? Does it break something to not have this if
?
…lang#2399) Co-authored-by: dongdongl <[email protected]>
Performance on H100 SXM5 HBM3
SM: 1830 MHZ
-[BATCH]-[N_HEADS]-[D_HEAD]-[mode]-[casual]-[seq_par]
fused-attention-batch4-head48-d64-fwd-True-True:
N_CTX Before After
0 1024.0 0.124135 0.120228
1 2048.0 0.391899 0.378483
2 4096.0 1.381547 1.340469
3 8192.0 5.182845 5.013851
fused-attention-batch4-head48-d64-fwd-True-False:
N_CTX Triton After
0 1024.0 0.124390 0.120371
1 2048.0 0.392177 0.379371
2 4096.0 1.380980 1.337169
3 8192.0 5.191552 5.013165
fused-attention-batch4-head48-d64-fwd-False-True:
N_CTX Triton After
0 1024.0 0.159597 0.153621
1 2048.0 0.563483 0.540353
2 4096.0 2.141308 2.053801
3 8192.0 8.244150 7.886597
fused-attention-batch4-head48-d64-fwd-False-False:
N_CTX Triton After
0 1024.0 0.160086 0.153992
1 2048.0 0.563423 0.541524
2 4096.0 2.146597 2.057333
3 8192.0 8.229922 7.886234