diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
index 86d2289887fae..295793cb08242 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
@@ -1582,7 +1582,8 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
   void asyncLaunchDots(scf::ForOp forOp);
   void emitConsumerRelease(Value mbarTensor, const ConsumerReleaseInfo &info,
                            int numStages);
-
+  bool selfDepend(tt::DotOp op, scf::ForOp forOp, Operation **firstUse);
+  void removeExtraWait(tt::nvidia_gpu::DotWaitOp dotWaitOp, bool hasDotWait0);
   ConsumerReleaseMap consumerReleaseMap;
 };
 
@@ -1612,13 +1613,89 @@ void PipelinePass::updateConsumerReleaseInfo(Operation *oldOp, Operation *newOp,
   }
 }
 
+// Returns true if `dotOp`'s operands depend on the dot's own result from the
+// previous iteration through the loop-carried iter_args; `firstUse` is set to
+// the first user of that iter_arg.
+bool PipelinePass::selfDepend(tt::DotOp dotOp, scf::ForOp forOp,
+                              Operation **firstUse) {
+  std::function<bool(Value, int, scf::ForOp)> dependOn =
+      [&dependOn](Value v, int argId, scf::ForOp forOp) {
+        auto op = v.getDefiningOp();
+        if (isa<BlockArgument>(v)) {
+          auto iterArgs = forOp.getRegionIterArgs();
+          auto iter = std::find(iterArgs.begin(), iterArgs.end(), v);
+          if (iter != iterArgs.end())
+            return std::distance(iterArgs.begin(), iter) == argId;
+        } else {
+          if (!op)
+            return false;
+          for (auto operand : op->getOperands()) {
+            if (dependOn(operand, argId, forOp))
+              return true;
+          }
+        }
+        return false;
+      };
+  auto result = dotOp.getResult();
+  auto yieldOp = forOp.getBody()->getTerminator();
+  int argIdx = -1;
+  auto iter = std::find(yieldOp->getOperands().begin(),
+                        yieldOp->getOperands().end(), result);
+  if (iter != yieldOp->getOperands().end())
+    argIdx = std::distance(yieldOp->getOperands().begin(), iter);
+  if (argIdx == -1)
+    return false;
+  for (auto operand : dotOp.getOperands()) {
+    if (dependOn(operand, argIdx, forOp)) {
+      auto iterArgs = forOp.getRegionIterArgs();
+      *firstUse = iterArgs[argIdx].use_begin().getUser();
+      return true;
+    }
+  }
+  return false;
+}
+
+void PipelinePass::removeExtraWait(tt::nvidia_gpu::DotWaitOp dotWaitOp,
+                                   bool hasDotWait0) {
+  if (hasDotWait0) {
+    for (auto &item : consumerReleaseMap) {
+      auto &m = item.second.consumerStageMap;
+      if (m.count(dotWaitOp)) {
+        m.erase(dotWaitOp);
+      }
+    }
+    dotWaitOp->erase();
+  }
+}
+
 void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
   Block *loop = forOp.getBody();
-
+  auto getBlockNumInFor = [](Operation *op, scf::ForOp forOp) {
+    if (!op)
+      return -1l;
+    auto lastOp = op;
+    while (op->getBlock()->getParentOp() != forOp) {
+      lastOp = op;
+      op = op->getBlock()->getParentOp();
+    }
+    return std::distance(lastOp->getBlock()->getParent()->begin(),
+                         lastOp->getBlock()->getIterator());
+  };
   /// XXX(Keren): Clean up the following duplicate code with checkDotOp
   /// dots to be pipelined
+  bool hasSyncDot = false;
+  bool hasDotWait0 = false;
+  SmallVector<tt::DotOp> allDots;
   SmallVector<tt::DotOp> dots;
   SmallVector<unsigned> resultNeedSync;
+  for (Operation &op : *loop) {
+    if (auto dotWaitOp = dyn_cast<tt::nvidia_gpu::DotWaitOp>(&op)) {
+      auto attr = dotWaitOp->getAttrOfType<IntegerAttr>("pendings");
+      auto pendingCount = attr.getInt();
+      if (pendingCount == 0)
+        hasDotWait0 = true;
+    }
+    if (auto dotOp = dyn_cast<tt::DotOp>(&op)) {
+      allDots.push_back(dotOp);
+    }
+  }
   for (Operation &op : *loop) {
     if (auto dotOp = dyn_cast<tt::DotOp>(&op)) {
       auto resTy = dotOp.getResult().getType().dyn_cast<RankedTensorType>();
@@ -1635,9 +1712,22 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
       if (!isa<scf::YieldOp>(*dot.getUsers().begin()))
         valid = false;
 
-      // C should be a block argument
-      auto CArg = dotOp.getOperand(2).dyn_cast<BlockArgument>();
-      if (!CArg || !CArg.hasOneUse())
+      Operation *firstUse = nullptr;
+      selfDepend(dotOp, forOp, &firstUse);
+      bool selfDirectDepend = (dotOp == firstUse);
+      for (auto tempInAll : allDots) {
+        auto iter = std::find(dots.begin(), dots.end(), tempInAll);
+        if (iter != dots.end())
+          continue;
+        auto db = getBlockNumInFor(tempInAll, forOp);
+        auto fb = getBlockNumInFor(firstUse, forOp);
+        if (db < fb ||
+            (db == fb && db >= 0 && tempInAll->isBeforeInBlock(firstUse)))
+          hasSyncDot = true;
+      }
+      auto CArg = dotOp.getOperand(2);
+      if (!(selfDirectDepend || (!selfDirectDepend && hasSyncDot)) ||
+          !CArg.hasOneUse())
         valid = false;
 
       if (valid) {
@@ -1662,6 +1752,7 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
   // TODO: merge this with the rest of the pipelining transformation and look at
   // a better representation for async dots.
   tt::DotOp lastDot = dots.back();
+  auto loc = lastDot.getLoc();
   builder.setInsertionPointAfter(lastDot);
   auto dotWait = builder.create<tt::nvidia_gpu::DotWaitOp>(
       lastDot.getLoc(), lastDot.getResult(), dots.size());
@@ -1678,16 +1769,40 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
     dotOp->erase();
   }
 
+  hasDotWait0 = hasDotWait0 || hasSyncDot;
+
   // 2. If there's any outstanding DotAsyncOps, we need to wait for them.
   builder.setInsertionPointAfter(forOp);
-  for (unsigned resultIndex : resultNeedSync) {
-    Value result = forOp->getResult(resultIndex);
+  SmallVector<Type> resultTypes(resultNeedSync.size());
+  SmallVector<Value> yieldThenValues(resultNeedSync.size());
+  SmallVector<Value> yieldElseValues(resultNeedSync.size());
+  for (int i = 0; i < resultNeedSync.size(); ++i) {
+    resultTypes[i] = forOp->getResult(resultNeedSync[i]).getType();
+    yieldThenValues[i] = forOp->getResult(resultNeedSync[i]);
+    yieldElseValues[i] = forOp->getResult(resultNeedSync[i]);
+  }
+  // Guard the trailing dot_wait with an scf.if so it only runs when the loop
+  // executed at least one iteration.
+  Value loopNotEmpty = builder.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::slt, forOp.getLowerBound(),
+      forOp.getUpperBound());
+  auto ifOp = builder.create<scf::IfOp>(loc, resultTypes, loopNotEmpty,
+                                        /*hasElse*/ true);
+  builder.setInsertionPointToStart(ifOp.thenBlock());
+  for (int i = 0; i < resultNeedSync.size(); ++i) {
+    Value result = forOp->getResult(resultNeedSync[i]);
     if (result.use_empty())
       continue;
     auto dotWait = builder.create<tt::nvidia_gpu::DotWaitOp>(
        forOp.getLoc(), result, 0);
-    result.replaceAllUsesExcept(dotWait.getResult(), dotWait);
+    result.replaceAllUsesExcept(ifOp.getResult(i), dotWait);
+    yieldThenValues[i] = dotWait.getResult();
   }
+  auto yieldOpThen = builder.create<scf::YieldOp>(loc, yieldThenValues);
+  builder.setInsertionPointToEnd(ifOp.elseBlock());
+  auto yieldOpElse = builder.create<scf::YieldOp>(loc, yieldElseValues);
+
+  // 3. Potentially remove the redundant dot_wait after dot_async if there are
+  // multiple DotOps in the loop.
+  removeExtraWait(dotWait, hasDotWait0);
 }
 
 Value PipelinePass::getRemoteCTAId(OpBuilder &b, Location loc,
diff --git a/python/test/unit/operators/test_flash_attention.py b/python/test/unit/operators/test_flash_attention.py
index fabf454dc1759..18c254d9d19cc 100644
--- a/python/test/unit/operators/test_flash_attention.py
+++ b/python/test/unit/operators/test_flash_attention.py
@@ -5,13 +5,15 @@
 import triton.ops
 
 
-@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(2, 4, 512, 16),
-                                                 (2, 4, 512, 32),
-                                                 (2, 4, 512, 64),
-                                                 (2, 4, 512, 128)])
-@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize('causal', [True, False])
-@pytest.mark.parametrize('seq_par', [True, False])
+@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [
+    (2, 4, 512, 16),
+    (2, 4, 512, 32),
+    (2, 4, 512, 64),
+    (2, 4, 512, 128)
+])
+@pytest.mark.parametrize('dtype', [torch.float16])
+@pytest.mark.parametrize('causal', [False])
+@pytest.mark.parametrize('seq_par', [False])
 def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
     import os
     enable_tma = os.environ.get('ENABLE_TMA', 'not found').lower()
@@ -53,3 +55,65 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
     torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0)
     torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0)
     torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0)
+
+try:
+    from flash_attn.flash_attn_interface import flash_attn_func
+    HAS_FLASH = True
+except BaseException:
+    HAS_FLASH = False
+
+BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
+# vary seq length for fixed head and batch=4
+configs = [triton.testing.Benchmark(
+    x_names=['N_CTX'],
+    x_vals=[2**i for i in range(10, 14)],
+    line_arg='provider',
+    line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
+    line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
+    styles=[('red', '-'), ('blue', '-')],
+    ylabel='ms',
+    plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-{causal}-{seq_par}',
+    args={
+        'H': N_HEADS,
+        'BATCH': BATCH,
+        'D_HEAD': D_HEAD,
+        'dtype': torch.float16,
+        'mode': mode,
+        'causal': causal,
+        'seq_par': seq_par}
+) for mode in ['fwd', 'bwd'] for causal in [True, False] for seq_par in [True, False]]
+
+
+@triton.testing.perf_report(configs)
+def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, causal, seq_par, provider, dtype=torch.float16, device="cuda"):
+    assert mode in ['fwd', 'bwd']
+    warmup = 25
+    rep = 100
+    sm_scale = 1.3
+    q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
+    k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
+    v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
+    if provider == "triton":
+        fn = lambda: triton.ops.attention(q, k, v, causal, sm_scale, seq_par)
+        if mode == 'bwd':
+            o = fn()
+            do = torch.randn_like(o)
+            fn = lambda: o.backward(do, retain_graph=True)
+        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+        return ms
+    if provider == "flash":
+        lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
+        cu_seqlens = torch.zeros(
+            (BATCH + 1,), device=device, dtype=torch.int32)
+        cu_seqlens[1:] = lengths.cumsum(0)
+        fn = lambda: flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=sm_scale, causal=causal)
+        if mode == 'bwd':
+            o = fn()
+            do = torch.randn_like(o)
+            fn = lambda: o.backward(do, retain_graph=True)
+        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+        return ms
+
+
+# only works on post-Ampere GPUs right now
+# bench_flash_attention.run(save_path='.', print_data=True)
diff --git a/test/TritonGPU/pipeline-hopper-remove-wait.mlir b/test/TritonGPU/pipeline-hopper-remove-wait.mlir
new file mode 100644
index 0000000000000..777e36546b82c
--- /dev/null
+++ b/test/TritonGPU/pipeline-hopper-remove-wait.mlir
@@ -0,0 +1,92 @@
+// RUN: ENABLE_TMA=1 ENABLE_MMA_V3=1 triton-opt %s -split-input-file -tritongpu-pipeline=compute-capability=90 -canonicalize | FileCheck %s
+
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
+#blocked3 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
+#blocked4 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
+#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
+#mma1 = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 128, 16]}>
+#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
+#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
+module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+  tt.func public @two_dependent_dot(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg9: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg10: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg11: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg12: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg13: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg14: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg15: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg16: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg17: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg18: i32, %arg19: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg20: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg21: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
+    %cst = arith.constant dense<0xFF800000> : tensor<128x64xf32, #mma>
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
+    %c0_i32 = arith.constant 0 : i32
+    %c64_i32 = arith.constant 64 : i32
+    %cst_1 = arith.constant dense<0xFF800000> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+    %cst_2 = arith.constant dense<0.000000e+00> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+    %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma1>
+    %c1_i32 = arith.constant 1 : i32
+    %cst_4 = arith.constant 1.44269502 : f32
+    %c128_i32 = arith.constant 128 : i32
+    %c1_i64 = arith.constant 1 : i64
+    %c128_i64 = arith.constant 128 : i64
+    %0 = tt.get_program_id x : i32
+    %1 = tt.get_program_id y : i32
+    %2 = arith.muli %1, %arg7 : i32
+    %3 = arith.divsi %2, %arg8 : i32
+    %4 = arith.extsi %arg21 : i32 to i64
+    %5 = arith.extsi %arg11 : i32 to i64
+    %6 = tt.make_tensor_ptr %arg1, [%c128_i64, %4], [%c1_i64, %5], [%c0_i32, %3] {order = array<i32: 0, 1>} : !tt.ptr<tensor<128x64xf16>, 1>
+    %7 = arith.extsi %arg14 : i32 to i64
+    %8 = tt.make_tensor_ptr %arg2, [%4, %c128_i64], [%7, %c1_i64], [%3, %c0_i32] {order = array<i32: 1, 0>} : !tt.ptr<tensor<64x128xf16>, 1>
+    %9 = arith.muli %0, %c128_i32 : i32
+    %10 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
+    %11 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+    %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked3>
+    %13 = tt.splat %9 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
+    %14 = tt.splat %9 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+    %15 = tt.splat %9 : (i32) -> tensor<128xi32, #blocked3>
+    %16 = arith.addi %13, %10 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
+    %17 = arith.addi %14, %11 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+    %18 = arith.addi %15, %12 : tensor<128xi32, #blocked3>
+    %19 = arith.mulf %arg3, %cst_4 : f32
+    %20 = tt.addptr %arg0, %2 : !tt.ptr<f16>, i32
+    %21 = tt.expand_dims %16 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi32, #blocked2>
+    %22 = tt.expand_dims %17 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xi32, #mma>
+    %23 = tt.splat %arg8 : (i32) -> tensor<128x1xi32, #blocked2>
+    %24 = arith.muli %21, %23 : tensor<128x1xi32, #blocked2>
+    %25 = tt.splat %20 : (!tt.ptr<f16>) -> tensor<128x1x!tt.ptr<f16>, #blocked2>
+    %26 = tt.addptr %25, %24 : tensor<128x1x!tt.ptr<f16>, #blocked2>, tensor<128x1xi32, #blocked2>
+    %27 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
+    %28 = tt.expand_dims %27 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi32, #blocked2>
+    %29 = tt.broadcast %26 : (tensor<128x1x!tt.ptr<f16>, #blocked2>) -> tensor<128x128x!tt.ptr<f16>, #blocked2>
+    %30 = tt.broadcast %28 : (tensor<1x128xi32, #blocked2>) -> tensor<128x128xi32, #blocked2>
+    %31 = tt.addptr %29, %30 : tensor<128x128x!tt.ptr<f16>, #blocked2>, tensor<128x128xi32, #blocked2>
+    %32 = tt.load %31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #blocked2>
+    %33 = tt.splat %19 : (f32) -> tensor<128x128xf32, #blocked2>
+    %34 = arith.extf %32 : tensor<128x128xf16, #blocked2> to tensor<128x128xf32, #blocked2>
+    %35 = arith.mulf %34, %33 : tensor<128x128xf32, #blocked2>
+    %36 = arith.truncf %35 : tensor<128x128xf32, #blocked2> to tensor<128x128xf16, #blocked2>
+    %37 = arith.addi %0, %c1_i32 : i32
+    %38 = arith.muli %37, %c128_i32 : i32
+    %42:5 = scf.for %arg22 = %c0_i32 to %38 step %c64_i32 iter_args(%arg23 = %cst_3, %arg24 = %cst_2, %arg25 = %cst_1, %arg26 = %6, %arg27 = %8) -> (tensor<128x128xf32, #mma1>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr<tensor<128x64xf16>, 1>, !tt.ptr<tensor<64x128xf16>, 1>) : i32 {
+      %59 = tt.load %arg26 {boundaryCheck = array<i32: 0, 1>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<128x64xf16>, 1> -> tensor<128x64xf16, #blocked4>
+      %60 = tt.load %arg27 {boundaryCheck = array<i32: 0, 1>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<64x128xf16>, 1> -> tensor<64x128xf16, #blocked2>
+      %66 = triton_gpu.convert_layout %36 : (tensor<128x128xf16, #blocked2>) -> tensor<128x128xf16, #shared>
+      %67 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked4>) -> tensor<128x64xf16, #shared1>
+      %68 = tt.dot %66, %67, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf16, #shared> * tensor<128x64xf16, #shared1> -> tensor<128x64xf32, #mma>
+      %81 = arith.truncf %68 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma>
+      %82 = triton_gpu.convert_layout %60 : (tensor<64x128xf16, #blocked2>) -> tensor<64x128xf16, #shared>
+      %83 = triton_gpu.convert_layout %81 : (tensor<128x64xf16, #mma>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
+      // CHECK: triton_nvidia_gpu.dot_async
+      // CHECK-NOT: triton_nvidia_gpu.dot_wait
+      // CHECK: scf.yield
+      %84 = tt.dot %83, %82, %arg23 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<64x128xf16, #shared> -> tensor<128x128xf32, #mma1>
+      %85 = arith.mulf %arg24, %arg25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+      %87 = arith.addf %85, %arg25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+      %88 = tt.advance %arg26, [%c0_i32, %c64_i32] : !tt.ptr<tensor<128x64xf16>, 1>
+      %89 = tt.advance %arg27, [%c64_i32, %c0_i32] : !tt.ptr<tensor<64x128xf16>, 1>
+      scf.yield %84, %87, %arg25, %88, %89 : tensor<128x128xf32, #mma1>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr<tensor<128x64xf16>, 1>, !tt.ptr<tensor<64x128xf16>, 1>
+    }
+    %54 = arith.addi %3, %9 : i32
+    %55 = arith.extsi %arg17 : i32 to i64
+    %56 = tt.make_tensor_ptr %arg5, [%4, %c128_i64], [%55, %c1_i64], [%54, %c0_i32] {order = array<i32: 1, 0>} : !tt.ptr<tensor<128x128xf16>, 1>
+    %57 = arith.truncf %42#0 : tensor<128x128xf32, #mma1> to tensor<128x128xf16, #mma1>
+    %58 = triton_gpu.convert_layout %57 : (tensor<128x128xf16, #mma1>) -> tensor<128x128xf16, #blocked2>
+    tt.store %56, %58 {boundaryCheck = array<i32: 0, 1>, cache = 1 : i32, evict = 1 : i32} : !tt.ptr<tensor<128x128xf16>, 1>, tensor<128x128xf16, #blocked2>
+    tt.return
+  }
+}
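
For reference, the `two_dependent_dot` shape exercised by the new test is what an attention-style inner loop lowers to: two chained dots per iteration, where the first dot only feeds the second and only the second accumulates into a loop-carried value, so a synchronous dot already sits between the async dot and the first use of its previous-iteration result. Below is a minimal Triton sketch of that shape; it is illustrative only and not part of this patch, and the kernel name, tile sizes, and the pre-transposed K layout are assumptions made for the example.

# Illustrative sketch (assumed names and layouts): a loop with two chained
# tl.dot ops; only the second dot feeds the loop-carried accumulator.
import triton
import triton.language as tl


@triton.jit
def two_dependent_dots(q_ptr, k_ptr, v_ptr, o_ptr, N_CTX: tl.constexpr,
                       BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                       BLOCK_D: tl.constexpr):
    offs_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_D)
    # Q and the output are assumed (N_CTX, BLOCK_D) row-major; load one Q row block.
    q = tl.load(q_ptr + offs_m[:, None] * BLOCK_D + offs_d[None, :])
    acc = tl.zeros((BLOCK_M, BLOCK_D), dtype=tl.float32)
    for start_n in range(0, N_CTX, BLOCK_N):
        offs_n = start_n + tl.arange(0, BLOCK_N)
        # K is assumed stored pre-transposed as (BLOCK_D, N_CTX); V as (N_CTX, BLOCK_D).
        k = tl.load(k_ptr + offs_d[:, None] * N_CTX + offs_n[None, :])
        v = tl.load(v_ptr + offs_n[:, None] * BLOCK_D + offs_d[None, :])
        qk = tl.dot(q, k)            # first dot: its result only feeds the next dot
        p = qk.to(tl.float16)
        acc += tl.dot(p, v)          # second dot: feeds the loop-carried accumulator
    tl.store(o_ptr + offs_m[:, None] * BLOCK_D + offs_d[None, :], acc.to(tl.float16))

# Example launch, e.g. with float16 CUDA tensors q, o of shape (1024, 64),
# k of shape (64, 1024) and v of shape (1024, 64):
#   two_dependent_dots[(1024 // 128,)](q, k, v, o, N_CTX=1024,
#                                      BLOCK_M=128, BLOCK_N=64, BLOCK_D=64)

Compiled with ENABLE_TMA=1 ENABLE_MMA_V3=1 for compute capability 90, this produces a TTGIR loop analogous to the .mlir test above, which is the case the new removeExtraWait path targets: the per-iteration dot_wait after the dot_async can be dropped because another dot already separates it from the first use of its result.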