remove extra wait in pipeline pass
dongdongl committed Oct 23, 2023
1 parent b0c166b commit ea3528e
Showing 3 changed files with 297 additions and 16 deletions.
142 changes: 133 additions & 9 deletions lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
@@ -1582,7 +1582,9 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
void asyncLaunchDots(scf::ForOp forOp);
void emitConsumerRelease(Value mbarTensor, const ConsumerReleaseInfo &info,
int numStages);

bool dependAsync(tt::DotOp op, scf::ForOp forOp, Operation **firstUse);
int getDependencyIndex(Value v, scf::ForOp forOp);
void removeExtraWait(tt::nvidia_gpu::DotWaitOp dotWaitOp, bool hasDotWait0);
ConsumerReleaseMap consumerReleaseMap;
};

@@ -1612,15 +1614,94 @@ void PipelinePass::updateConsumerReleaseInfo(Operation *oldOp, Operation *newOp,
}
}

int PipelinePass::getDependencyIndex(Value v, scf::ForOp forOp) {
auto op = v.getDefiningOp();
// If v is one of the loop's iter-args, return its index; otherwise recurse.
auto iterArgs = forOp.getRegionIterArgs();
if (isa<BlockArgument>(v)) {
auto iter = std::find(iterArgs.begin(), iterArgs.end(), v);
if (iter != iterArgs.end())
return std::distance(iterArgs.begin(), iter);
} else {
if (!op)
return -1;
for (auto operand : op->getOperands()) {
auto idx = getDependencyIndex(operand, forOp);
if (idx != -1)
return idx;
}
}
// -1 indicates no dependency
return -1;
}

bool PipelinePass::dependAsync(tt::DotOp dotOp, scf::ForOp forOp,
Operation **firstUse) {
// No dependency on other async ops; we only check the C operand here.
auto CArg = dotOp.getOperand(2);
auto result = dotOp.getResult();
auto yieldOp = forOp.getBody()->getTerminator();
int argIdx = -1;
auto iter = std::find(yieldOp->getOperands().begin(),
yieldOp->getOperands().end(), result);
if (iter != yieldOp->getOperands().end())
argIdx = std::distance(yieldOp->getOperands().begin(), iter);
if (argIdx == -1)
return false;
for (auto operand : dotOp.getOperands()) {
auto dependIdx = getDependencyIndex(operand, forOp);
if (argIdx == dependIdx) {
auto iterArgs = forOp.getRegionIterArgs();
*firstUse = iterArgs[dependIdx].use_begin().getUser();
return true;
}
}
return false;
}

void PipelinePass::removeExtraWait(tt::nvidia_gpu::DotWaitOp dotWaitOp,
bool hasDotWait0) {
if (hasDotWait0) {
for (auto &item : consumerReleaseMap) {
auto &m = item.second.consumerStageMap;
if (m.count(dotWaitOp)) {
m.erase(dotWaitOp);
}
}
dotWaitOp->erase();
}
}

void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
Block *loop = forOp.getBody();

auto getBlockNumInFor = [](Operation *op, scf::ForOp forOp) {
if (!op)
return -1l;
auto lastOp = op;
while (op->getBlock()->getParentOp() != forOp) {
lastOp = op;
op = op->getBlock()->getParentOp();
}
return std::distance(lastOp->getBlock()->getParent()->begin(),
lastOp->getBlock()->getIterator());
};
/// XXX(Keren): Clean up the following duplicate code with checkDotOp
/// dots to be pipelined
bool depend = false;
bool hasSyncDot = false;
bool hasDotWait0 = false;
SmallVector<tt::DotOp> allDots;
SmallVector<tt::DotOp> dots;
SmallVector<unsigned> resultNeedSync;
for (Operation &op : *loop) {
if (auto dotWaitOp = dyn_cast<tt::nvidia_gpu::DotWaitOp>(&op)) {
auto attr = dotWaitOp->getAttrOfType<IntegerAttr>("pendings");
auto pendingCount = attr.getInt();
if (pendingCount == 0)
hasDotWait0 = true;
}
if (auto dotOp = dyn_cast<tt::DotOp>(&op)) {
allDots.push_back(dotOp);
auto resTy = dotOp.getResult().getType().dyn_cast<RankedTensorType>();
if (auto resEnc = resTy.getEncoding().dyn_cast<ttg::MmaEncodingAttr>()) {
if (resEnc && resEnc.isHopper()) {
@@ -1634,10 +1715,22 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
valid = false;
if (!isa<scf::YieldOp>(*dot.getUsers().begin()))
valid = false;

// C should be a block argument
auto CArg = dotOp.getOperand(2).dyn_cast<BlockArgument>();
if (!CArg || !CArg.hasOneUse())
Operation *firstUse = nullptr;
depend = dependAsync(dotOp, forOp, &firstUse);
for (auto tempInAll : allDots) {
if (tempInAll == dotOp)
continue;
auto iter = std::find(dots.begin(), dots.end(), tempInAll);
if (iter != dots.end())
continue;
auto db = getBlockNumInFor(tempInAll, forOp);
auto fb = getBlockNumInFor(firstUse, forOp);
if (db < fb ||
(db == fb && db >= 0 && tempInAll->isBeforeInBlock(firstUse)))
hasSyncDot = true;
}
auto CArg = dotOp.getOperand(2);
if ((depend && !hasSyncDot) || !CArg.hasOneUse())
valid = false;

if (valid) {
@@ -1662,6 +1755,7 @@
// TODO: merge this with the rest of the pipelining transformation and look at
// a better representation for async dots.
tt::DotOp lastDot = dots.back();
auto loc = lastDot.getLoc();
builder.setInsertionPointAfter(lastDot);
auto dotWait = builder.create<tt::nvidia_gpu::DotWaitOp>(
lastDot.getLoc(), lastDot.getResult(), dots.size());
@@ -1678,16 +1772,46 @@
dotOp->erase();
}

// Check whether there are still sync dots left in the loop.
for (Operation &op : *loop) {
if (auto dotOp = dyn_cast<tt::DotOp>(&op)) {
hasSyncDot = true;
}
}
hasDotWait0 = hasDotWait0 || hasSyncDot;

// 2. If there are any outstanding DotAsyncOps, we need to wait for them.
builder.setInsertionPointAfter(forOp);
for (unsigned resultIndex : resultNeedSync) {
Value result = forOp->getResult(resultIndex);
SmallVector<Type> resultTypes(resultNeedSync.size());
SmallVector<Value> yieldThenValues(resultNeedSync.size());
SmallVector<Value> yieldElseValues(resultNeedSync.size());
for (int i = 0; i < resultNeedSync.size(); ++i) {
resultTypes[i] = forOp->getResult(resultNeedSync[i]).getType();
yieldThenValues[i] = forOp->getResult(resultNeedSync[i]);
yieldElseValues[i] = forOp->getResult(resultNeedSync[i]);
}
Value loopNotEmpty = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, forOp.getLowerBound(),
forOp.getUpperBound());
auto ifOp = builder.create<scf::IfOp>(loc, resultTypes, loopNotEmpty,
/*hasElse*/ true);
builder.setInsertionPointToStart(ifOp.thenBlock());
for (int i = 0; i < resultNeedSync.size(); ++i) {
Value result = forOp->getResult(resultNeedSync[i]);
if (result.use_empty())
continue;
auto dotWait =
builder.create<tt::nvidia_gpu::DotWaitOp>(forOp.getLoc(), result, 0);
result.replaceAllUsesExcept(dotWait.getResult(), dotWait);
result.replaceAllUsesExcept(ifOp.getResult(i), dotWait);
yieldThenValues[i] = dotWait.getResult();
}
auto yieldOpThen = builder.create<scf::YieldOp>(loc, yieldThenValues);
builder.setInsertionPointToEnd(ifOp.elseBlock());
auto yieldOpElse = builder.create<scf::YieldOp>(loc, yieldElseValues);

// 3. Potentially remove the redundant dot_wait after dot_async when there are
// multiple DotOps in the loop.
removeExtraWait(dotWait, hasDotWait0);
}

Value PipelinePass::getRemoteCTAId(OpBuilder &b, Location loc,
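In plain terms, the new getDependencyIndex helper walks backwards from a value through its defining ops until it reaches one of the loop's region iter-args and returns that argument's index (or -1 when the value does not depend on a loop-carried argument); dependAsync then uses it to detect a dot whose operands depend on the iter-arg that carries the dot's own yielded result (a cross-iteration self-dependency), and records the first user of that iter-arg. The following is a minimal Python sketch of the recursive walk over a toy def-use chain; the Value and Op records and their field names are illustrative assumptions, not the MLIR C++ API used in the diff above.

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Value:
    # Index into the loop's iter-args if this value is a loop-carried
    # block argument, otherwise None.
    iter_arg_index: Optional[int] = None
    # The operation that produced this value, if any.
    defining_op: Optional["Op"] = None


@dataclass
class Op:
    operands: List["Value"] = field(default_factory=list)


def get_dependency_index(v: Value) -> int:
    # Return the iter-arg index that v (transitively) depends on, or -1.
    if v.iter_arg_index is not None:        # reached a loop-carried argument
        return v.iter_arg_index
    if v.defining_op is None:               # defined outside the chain: no dependency
        return -1
    for operand in v.defining_op.operands:  # depth-first over the def-use chain
        idx = get_dependency_index(operand)
        if idx != -1:
            return idx
    return -1


# Toy chain: acc is iter-arg #2, x = f(acc), and x feeds the dot's C operand.
acc = Value(iter_arg_index=2)
x = Value(defining_op=Op(operands=[acc]))
assert get_dependency_index(x) == 2
assert get_dependency_index(Value()) == -1

The real pass performs the same walk over mlir::Value and Operation, and additionally returns -1 for block arguments that are not among the loop's iter-args.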
79 changes: 72 additions & 7 deletions python/test/unit/operators/test_flash_attention.py
@@ -5,13 +5,15 @@
import triton.ops


@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(2, 4, 512, 16),
(2, 4, 512, 32),
(2, 4, 512, 64),
(2, 4, 512, 128)])
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('causal', [True, False])
@pytest.mark.parametrize('seq_par', [True, False])
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [
(2, 4, 512, 16),
(2, 4, 512, 32),
(2, 4, 512, 64),
(2, 4, 512, 128)
])
@pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('seq_par', [False])
def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
import os
enable_tma = os.environ.get('ENABLE_TMA', 'not found').lower()
@@ -57,3 +59,66 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0)
# temporary env var control end
os.putenv("ENABLE_TMA", enable_tma)


try:
from flash_attn.flash_attn_interface import flash_attn_func
HAS_FLASH = True
except BaseException:
HAS_FLASH = False

BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4
configs = [triton.testing.Benchmark(
x_names=['N_CTX'],
x_vals=[2**i for i in range(10, 14)],
line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-{causal}-{seq_par}',
args={
'H': N_HEADS,
'BATCH': BATCH,
'D_HEAD': D_HEAD,
'dtype': torch.float16,
'mode': mode,
'causal': causal,
'seq_par': seq_par}
) for mode in ['fwd', 'bwd'] for causal in [True, False] for seq_par in [True, False]]


@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, causal, seq_par, provider, dtype=torch.float16, device="cuda"):
assert mode in ['fwd', 'bwd']
warmup = 25
rep = 100
sm_scale = 1.3
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
if provider == "triton":
fn = lambda: triton.ops.attention(q, k, v, causal, sm_scale, seq_par)
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms
if provider == "flash":
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
cu_seqlens = torch.zeros(
(BATCH + 1,), device=device, dtype=torch.int32)
cu_seqlens[1:] = lengths.cumsum(0)
fn = lambda: flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=sm_scale, causal=causal)
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms


# only works on post-Ampere GPUs right now
# bench_flash_attention.run(save_path='.', print_data=True)
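A quick way to exercise the sweep defined by configs above is the call shown in the commented-out line; the sketch below simply wraps that same call in a __main__ guard so that importing the module under pytest stays side-effect free. The guard is an addition for illustration, not part of the commit.

# Assumes a post-Ampere GPU, per the note above.
# save_path='.' writes the results next to the script; print_data echoes the table.
if __name__ == '__main__':
    bench_flash_attention.run(save_path='.', print_data=True)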