[OPTIMIZER] Remove extra wgmma_wait_group in flash attention (triton-lang#2399)

Co-authored-by: dongdongl <[email protected]>
donproc and dongdongl authored Oct 26, 2023
1 parent cfae7e2 commit 0469d5f
Showing 3 changed files with 278 additions and 8 deletions.
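
Background for the change, as a quick mental model: an MMA v3 (wgmma) dot is launched asynchronously and later synchronized by a dot_wait whose `pendings` attribute says how many async groups may still be outstanding. In the flash-attention loop the second dot consumes the first dot's result, so the loop already contains a dot_wait 0 (a full drain) before the accumulator is ever read; the dot_wait the pipeliner inserts right after the last async dot therefore adds no synchronization, and the new removeExtraWait helper erases it. The toy Python model below only illustrates that argument; it is not Triton code, and every name in it is invented for the sketch.

def simulate(insert_extra_wait, iterations=3):
    """Track outstanding async dot groups and record how many are still
    pending at every point where an async result is actually read."""
    outstanding = 0
    observed_at_reads = []
    for _ in range(iterations):
        outstanding += 1                        # dot_async qk: commit one wgmma group
        outstanding = 0                         # dot_wait 0: the second dot reads qk's result
        observed_at_reads.append(outstanding)   # qk result consumed here
        outstanding += 1                        # dot_async pv: accumulator, read next iteration
        if insert_extra_wait:
            outstanding = min(outstanding, 1)   # the extra dot_wait this commit removes
    outstanding = 0                             # dot_wait 0 after the loop, before the store
    observed_at_reads.append(outstanding)       # accumulator consumed here
    return observed_at_reads


# Every consumer observes the same synchronization either way, so the extra
# wait inside the loop can be erased.
assert simulate(True) == simulate(False)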
131 changes: 123 additions & 8 deletions lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
@@ -1582,7 +1582,8 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
void asyncLaunchDots(scf::ForOp forOp);
void emitConsumerRelease(Value mbarTensor, const ConsumerReleaseInfo &info,
int numStages);

bool selfDepend(tt::DotOp op, scf::ForOp forOp, Operation **firstUse);
void removeExtraWait(tt::nvidia_gpu::DotWaitOp dotWaitOp, bool hasDotWait0);
ConsumerReleaseMap consumerReleaseMap;
};

@@ -1612,13 +1613,89 @@ void PipelinePass::updateConsumerReleaseInfo(Operation *oldOp, Operation *newOp,
}
}

bool PipelinePass::selfDepend(tt::DotOp dotOp, scf::ForOp forOp,
Operation **firstUse) {
std::function<bool(Value, int, scf::ForOp)> dependOn =
[&dependOn](Value v, int argId, scf::ForOp forOp) {
auto op = v.getDefiningOp();
if (isa<BlockArgument>(v)) {
auto iterArgs = forOp.getRegionIterArgs();
auto iter = std::find(iterArgs.begin(), iterArgs.end(), v);
if (iter != iterArgs.end())
return std::distance(iterArgs.begin(), iter) == argId;
} else {
if (!op)
return false;
for (auto operand : op->getOperands()) {
if (dependOn(operand, argId, forOp))
return true;
}
}
return false;
};
auto result = dotOp.getResult();
auto yieldOp = forOp.getBody()->getTerminator();
int argIdx = -1;
auto iter = std::find(yieldOp->getOperands().begin(),
yieldOp->getOperands().end(), result);
if (iter != yieldOp->getOperands().end())
argIdx = std::distance(yieldOp->getOperands().begin(), iter);
if (argIdx == -1)
return false;
for (auto operand : dotOp.getOperands()) {
if (dependOn(operand, argIdx, forOp)) {
auto iterArgs = forOp.getRegionIterArgs();
*firstUse = iterArgs[argIdx].use_begin().getUser();
return true;
}
}
return false;
}

void PipelinePass::removeExtraWait(tt::nvidia_gpu::DotWaitOp dotWaitOp,
bool hasDotWait0) {
if (hasDotWait0) {
for (auto &item : consumerReleaseMap) {
auto &m = item.second.consumerStageMap;
if (m.count(dotWaitOp)) {
m.erase(dotWaitOp);
}
}
dotWaitOp->erase();
}
}

void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
Block *loop = forOp.getBody();

auto getBlockNumInFor = [](Operation *op, scf::ForOp forOp) {
if (!op)
return -1l;
auto lastOp = op;
while (op->getBlock()->getParentOp() != forOp) {
lastOp = op;
op = op->getBlock()->getParentOp();
}
return std::distance(lastOp->getBlock()->getParent()->begin(),
lastOp->getBlock()->getIterator());
};
/// XXX(Keren): Clean up the following duplicate code with checkDotOp
/// dots to be pipelined
bool hasSyncDot = false;
bool hasDotWait0 = false;
SmallVector<tt::DotOp> allDots;
SmallVector<tt::DotOp> dots;
SmallVector<unsigned> resultNeedSync;
for (Operation &op : *loop) {
if (auto dotWaitOp = dyn_cast<tt::nvidia_gpu::DotWaitOp>(&op)) {
auto attr = dotWaitOp->getAttrOfType<IntegerAttr>("pendings");
auto pendingCount = attr.getInt();
if (pendingCount == 0)
hasDotWait0 = true;
}
if (auto dotOp = dyn_cast<tt::DotOp>(&op)) {
allDots.push_back(dotOp);
}
}
for (Operation &op : *loop) {
if (auto dotOp = dyn_cast<tt::DotOp>(&op)) {
auto resTy = dotOp.getResult().getType().dyn_cast<RankedTensorType>();
@@ -1635,9 +1712,22 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
if (!isa<scf::YieldOp>(*dot.getUsers().begin()))
valid = false;

// C should be a block argument
auto CArg = dotOp.getOperand(2).dyn_cast<BlockArgument>();
if (!CArg || !CArg.hasOneUse())
Operation *firstUse = nullptr;
selfDepend(dotOp, forOp, &firstUse);
bool selfDirectDepend = (dotOp == firstUse);
for (auto tempInAll : allDots) {
auto iter = std::find(dots.begin(), dots.end(), tempInAll);
if (iter != dots.end())
continue;
auto db = getBlockNumInFor(tempInAll, forOp);
auto fb = getBlockNumInFor(firstUse, forOp);
if (db < fb ||
(db == fb && db >= 0 && tempInAll->isBeforeInBlock(firstUse)))
hasSyncDot = true;
}
auto CArg = dotOp.getOperand(2);
if (!(selfDirectDepend || (!selfDirectDepend && hasSyncDot)) ||
!CArg.hasOneUse())
valid = false;

if (valid) {
@@ -1662,6 +1752,7 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
// TODO: merge this with the rest of the pipelining transformation and look at
// a better representation for async dots.
tt::DotOp lastDot = dots.back();
auto loc = lastDot.getLoc();
builder.setInsertionPointAfter(lastDot);
auto dotWait = builder.create<tt::nvidia_gpu::DotWaitOp>(
lastDot.getLoc(), lastDot.getResult(), dots.size());
@@ -1678,16 +1769,40 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
dotOp->erase();
}

hasDotWait0 = hasDotWait0 || hasSyncDot;

// 2. If there are any outstanding DotAsyncOps, we need to wait for them.
builder.setInsertionPointAfter(forOp);
for (unsigned resultIndex : resultNeedSync) {
Value result = forOp->getResult(resultIndex);
SmallVector<Type> resultTypes(resultNeedSync.size());
SmallVector<Value> yieldThenValues(resultNeedSync.size());
SmallVector<Value> yieldElseValues(resultNeedSync.size());
for (int i = 0; i < resultNeedSync.size(); ++i) {
resultTypes[i] = forOp->getResult(resultNeedSync[i]).getType();
yieldThenValues[i] = forOp->getResult(resultNeedSync[i]);
yieldElseValues[i] = forOp->getResult(resultNeedSync[i]);
}
Value loopNotEmpty = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, forOp.getLowerBound(),
forOp.getUpperBound());
auto ifOp = builder.create<scf::IfOp>(loc, resultTypes, loopNotEmpty,
/*hasElse*/ true);
builder.setInsertionPointToStart(ifOp.thenBlock());
for (int i = 0; i < resultNeedSync.size(); ++i) {
Value result = forOp->getResult(resultNeedSync[i]);
if (result.use_empty())
continue;
auto dotWait =
builder.create<tt::nvidia_gpu::DotWaitOp>(forOp.getLoc(), result, 0);
result.replaceAllUsesExcept(dotWait.getResult(), dotWait);
result.replaceAllUsesExcept(ifOp.getResult(i), dotWait);
yieldThenValues[i] = dotWait.getResult();
}
auto yieldOpThen = builder.create<scf::YieldOp>(loc, yieldThenValues);
builder.setInsertionPointToEnd(ifOp.elseBlock());
auto yieldOpElse = builder.create<scf::YieldOp>(loc, yieldElseValues);

// 3. Potentially remove the redundant dot_wait after dot_async if there are
// multiple DotOps in the loop
removeExtraWait(dotWait, hasDotWait0);
}

Value PipelinePass::getRemoteCTAId(OpBuilder &b, Location loc,
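
The heart of the new eligibility test above is selfDepend: a dot qualifies when its result is yielded into a loop-carried slot and one of its own operands transitively reads that same iter_arg back, i.e. it is the accumulator dot of the loop. The short Python sketch below restates that check on a toy value graph; it is an illustrative paraphrase under assumed data structures, not the MLIR API used in the pass.

from dataclasses import dataclass, field

@dataclass
class Op:
    name: str
    operands: list = field(default_factory=list)  # items: Op, ("iter_arg", k), or plain strings

def depends_on_iter_arg(value, k):
    """Does `value` transitively read the k-th loop-carried argument?"""
    if isinstance(value, tuple) and value[0] == "iter_arg":
        return value[1] == k
    if isinstance(value, Op):
        return any(depends_on_iter_arg(o, k) for o in value.operands)
    return False  # constants or values defined outside the loop

def self_depend(dot, yield_operands):
    """Toy-graph mirror of PipelinePass::selfDepend."""
    if dot not in yield_operands:
        return False
    k = yield_operands.index(dot)          # dot's result is loop-carried at slot k
    return any(depends_on_iter_arg(o, k) for o in dot.operands)

# Flash-attention shape: acc is iter_arg 0, the second dot consumes the first
# dot's result plus acc, and its result is yielded back into slot 0.
acc = ("iter_arg", 0)
qk = Op("dot_qk", operands=["q", "k"])
pv = Op("dot_pv", operands=[qk, "v", acc])
assert self_depend(pv, yield_operands=[pv])       # accumulator dot: self-dependent
assert not self_depend(qk, yield_operands=[pv])   # first dot: not yielded, not self-dependent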
63 changes: 63 additions & 0 deletions python/test/unit/operators/test_flash_attention.py
@@ -53,3 +53,66 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0)
torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0)
torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0)


try:
from flash_attn.flash_attn_interface import flash_attn_func
HAS_FLASH = True
except BaseException:
HAS_FLASH = False

BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4
configs = [triton.testing.Benchmark(
x_names=['N_CTX'],
x_vals=[2**i for i in range(10, 14)],
line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-{causal}-{seq_par}',
args={
'H': N_HEADS,
'BATCH': BATCH,
'D_HEAD': D_HEAD,
'dtype': torch.float16,
'mode': mode,
'causal': causal,
'seq_par': seq_par}
) for mode in ['fwd', 'bwd'] for causal in [True, False] for seq_par in [True, False]]


@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, causal, seq_par, provider, dtype=torch.float16, device="cuda"):
assert mode in ['fwd', 'bwd']
warmup = 25
rep = 100
sm_scale = 1.3
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
if provider == "triton":
fn = lambda: triton.ops.attention(q, k, v, causal, sm_scale, seq_par)
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms
if provider == "flash":
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
cu_seqlens = torch.zeros(
(BATCH + 1,), device=device, dtype=torch.int32)
cu_seqlens[1:] = lengths.cumsum(0)
fn = lambda: flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=sm_scale, causal=causal)
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms


# only works on post-Ampere GPUs right now
# bench_flash_attention.run(save_path='.', print_data=True)
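
For completeness, a hedged usage sketch of the benchmark defined above: it is the run call left commented out in the file, wrapped in a device-capability guard that matches the "post-Ampere" note. The guard itself is an illustration, not part of the commit, and the MMA v3 path this commit touches additionally requires Hopper (SM90).

import torch

if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    # Prints a table of Triton (and Flash, if installed) timings in ms and
    # writes the plots into the current directory.
    bench_flash_attention.run(save_path='.', print_data=True)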
92 changes: 92 additions & 0 deletions test/TritonGPU/pipeline-hopper-remove-wait.mlir
@@ -0,0 +1,92 @@
// RUN: ENABLE_TMA=1 ENABLE_MMA_V3=1 triton-opt %s -split-input-file -tritongpu-pipeline=compute-capability=90 -canonicalize | FileCheck %s


#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked3 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
#blocked4 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
#mma1 = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 128, 16]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @two_dependent_dot(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32} , %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32} , %arg2: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32} , %arg3: f32 , %arg4: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32} , %arg5: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32} , %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg9: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg10: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg11: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg12: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg13: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg14: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg15: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg16: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg17: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg18: i32 , %arg19: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg20: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg21: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} ) attributes {noinline = false} {
%cst = arith.constant dense<0xFF800000> : tensor<128x64xf32, #mma>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
%c0_i32 = arith.constant 0 : i32
%c64_i32 = arith.constant 64 : i32
%cst_1 = arith.constant dense<0xFF800000> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%cst_2 = arith.constant dense<0.000000e+00> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%cst_3 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma1>
%c1_i32 = arith.constant 1 : i32
%cst_4 = arith.constant 1.44269502 : f32
%c128_i32 = arith.constant 128 : i32
%c1_i64 = arith.constant 1 : i64
%c128_i64 = arith.constant 128 : i64
%0 = tt.get_program_id x : i32
%1 = tt.get_program_id y : i32
%2 = arith.muli %1, %arg7 : i32
%3 = arith.divsi %2, %arg8 : i32
%4 = arith.extsi %arg21 : i32 to i64
%5 = arith.extsi %arg11 : i32 to i64
%6 = tt.make_tensor_ptr %arg1, [%c128_i64, %4], [%c1_i64, %5], [%c0_i32, %3] {order = array<i32: 0, 1>} : <tensor<128x64xf16, #blocked>, 1>
%7 = arith.extsi %arg14 : i32 to i64
%8 = tt.make_tensor_ptr %arg2, [%4, %c128_i64], [%7, %c1_i64], [%3, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x128xf16, #blocked1>, 1>
%9 = arith.muli %0, %c128_i32 : i32
%10 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%11 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked3>
%13 = tt.splat %9 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%14 = tt.splat %9 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%15 = tt.splat %9 : (i32) -> tensor<128xi32, #blocked3>
%16 = arith.addi %13, %10 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%17 = arith.addi %14, %11 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%18 = arith.addi %15, %12 : tensor<128xi32, #blocked3>
%19 = arith.mulf %arg3, %cst_4 : f32
%20 = tt.addptr %arg0, %2 : !tt.ptr<f16, 1>, i32
%21 = tt.expand_dims %16 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi32, #blocked2>
%22 = tt.expand_dims %17 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xi32, #mma>
%23 = tt.splat %arg8 : (i32) -> tensor<128x1xi32, #blocked2>
%24 = arith.muli %21, %23 : tensor<128x1xi32, #blocked2>
%25 = tt.splat %20 : (!tt.ptr<f16, 1>) -> tensor<128x1x!tt.ptr<f16, 1>, #blocked2>
%26 = tt.addptr %25, %24 : tensor<128x1x!tt.ptr<f16, 1>, #blocked2>, tensor<128x1xi32, #blocked2>
%27 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%28 = tt.expand_dims %27 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi32, #blocked2>
%29 = tt.broadcast %26 : (tensor<128x1x!tt.ptr<f16, 1>, #blocked2>) -> tensor<128x128x!tt.ptr<f16, 1>, #blocked2>
%30 = tt.broadcast %28 : (tensor<1x128xi32, #blocked2>) -> tensor<128x128xi32, #blocked2>
%31 = tt.addptr %29, %30 : tensor<128x128x!tt.ptr<f16, 1>, #blocked2>, tensor<128x128xi32, #blocked2>
%32 = tt.load %31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #blocked2>
%33 = tt.splat %19 : (f32) -> tensor<128x128xf32, #blocked2>
%34 = arith.extf %32 : tensor<128x128xf16, #blocked2> to tensor<128x128xf32, #blocked2>
%35 = arith.mulf %34, %33 : tensor<128x128xf32, #blocked2>
%36 = arith.truncf %35 : tensor<128x128xf32, #blocked2> to tensor<128x128xf16, #blocked2>
%37 = arith.addi %0, %c1_i32 : i32
%38 = arith.muli %37, %c128_i32 : i32
%42:5 = scf.for %arg22 = %c0_i32 to %38 step %c64_i32 iter_args(%arg23 = %cst_3, %arg24 = %cst_2, %arg25 = %cst_1, %arg26 = %6, %arg27 = %8) -> (tensor<128x128xf32, #mma1>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr<tensor<128x64xf16, #blocked>, 1>, !tt.ptr<tensor<64x128xf16, #blocked1>, 1>) : i32 {
%59 = tt.load %arg26 {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<128x64xf16, #blocked>, 1> -> tensor<128x64xf16, #blocked4>
%60 = tt.load %arg27 {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<64x128xf16, #blocked1>, 1> -> tensor<64x128xf16, #blocked2>
%66 = triton_gpu.convert_layout %36 : (tensor<128x128xf16, #blocked2>) -> tensor<128x128xf16, #shared>
%67 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked4>) -> tensor<128x64xf16, #shared1>
%68 = tt.dot %66, %67, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf16, #shared> * tensor<128x64xf16, #shared1> -> tensor<128x64xf32, #mma>
%81 = arith.truncf %68 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma>
%82 = triton_gpu.convert_layout %60 : (tensor<64x128xf16, #blocked2>) -> tensor<64x128xf16, #shared>
%83 = triton_gpu.convert_layout %81 : (tensor<128x64xf16, #mma>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
// CHECK-LABEL: triton_nvidia_gpu.dot_async
// CHECK-LABEL-NOT: triton_nvidia_gpu.dot_wait
%84 = tt.dot %83, %82, %arg23 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<64x128xf16, #shared> -> tensor<128x128xf32, #mma1>
%85 = arith.mulf %arg24, %arg25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%87 = arith.addf %85, %arg25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%88 = tt.advance %arg26, [%c0_i32, %c64_i32] : <tensor<128x64xf16, #blocked>, 1>
%89 = tt.advance %arg27, [%c64_i32, %c0_i32] : <tensor<64x128xf16, #blocked1>, 1>
scf.yield %84, %87, %arg25, %88, %89 : tensor<128x128xf32, #mma1>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr<tensor<128x64xf16, #blocked>, 1>, !tt.ptr<tensor<64x128xf16, #blocked1>, 1>
}
%54 = arith.addi %3, %9 : i32
%55 = arith.extsi %arg17 : i32 to i64
%56 = tt.make_tensor_ptr %arg5, [%4, %c128_i64], [%55, %c1_i64], [%54, %c0_i32] {order = array<i32: 1, 0>} : <tensor<128x128xf16, #blocked>, 1>
%57 = arith.truncf %42#0 : tensor<128x128xf32, #mma1> to tensor<128x128xf16, #mma1>
%58 = triton_gpu.convert_layout %57 : (tensor<128x128xf16, #mma1>) -> tensor<128x128xf16, #blocked2>
tt.store %56, %58 {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32} : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<128x128xf16, #blocked2>
tt.return
}
}
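
To reproduce the check outside of lit, the RUN line at the top of the test can be driven by hand. The sketch below assumes triton-opt and FileCheck are on PATH from a local LLVM/Triton build and that it is run from the repository root; neither assumption comes from the commit itself.

import os
import subprocess

test = "test/TritonGPU/pipeline-hopper-remove-wait.mlir"
env = dict(os.environ, ENABLE_TMA="1", ENABLE_MMA_V3="1")

# Run the pipeline pass exactly as the RUN line does ...
opt = subprocess.run(
    ["triton-opt", test, "-split-input-file",
     "-tritongpu-pipeline=compute-capability=90", "-canonicalize"],
    env=env, capture_output=True, text=True, check=True)

# ... and feed the output to FileCheck against the same file's CHECK lines.
subprocess.run(["FileCheck", test], input=opt.stdout, text=True, check=True)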
