Deduplicate getValuesFromDotOperandLayoutStruct function #20127

Merged
1 commit merged on Dec 4, 2024
Deduplicate getValuesFromDotOperandLayoutStruct function
PiperOrigin-RevId: 702710979
Google-ML-Automation committed Dec 4, 2024
commit 955376ba86e97f84ad2f649006d3a08bbb22fb87
114 changes: 10 additions & 104 deletions xla/service/gpu/fusions/triton/xla_triton_sparse_passes.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/

#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <map>
@@ -76,6 +77,8 @@ using ::mlir::triton::gpu::getShapePerCTATile;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ttn::OperandsAndConstraints;

// TODO: b/350928208 - Declare these functions in the header files of the
// corresponding C++ files and include them here instead of forward-declaring.
// The functions below are defined in AccelerateMatmul.cpp.
namespace mlir::triton::gpu {
SmallVector<unsigned, 3> getWarpsPerTile(
Expand All @@ -94,6 +97,13 @@ int64_t getSwizzlingFromLayout(const triton::gpu::SharedEncodingAttr &layout,
ttn::WGMMAEltType getMmaRetType(Value);
ttn::WGMMAEltType getMmaOperandType(Value, bool);

// The functions below are defined in MMAv2.cpp.
using ValueTableV2 = std::map<std::array<int, 3>, Value>;
ValueTableV2 getValuesFromDotOperandLayoutStruct(
const LLVMTypeConverter *typeConverter, Location loc,
ConversionPatternRewriter &rewriter, Value value, int batch, int repOuter,
int repK, RankedTensorType type);
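
The deduplication relies on an out-of-line forward declaration: this file only declares getValuesFromDotOperandLayoutStruct, and the linker resolves calls against the single definition in MMAv2.cpp. Below is a minimal, self-contained sketch of that pattern, using hypothetical file and function names that are not taken from the XLA sources:

// math_utils.cpp -- hypothetical translation unit that owns the definition.
#include <string>

std::string FormatValue(int id) { return "value #" + std::to_string(id); }

// caller.cpp -- reuses the definition above instead of duplicating its body.
// The declaration must match the definition exactly; the linker resolves the
// call when both object files are linked into the same binary.
#include <iostream>
#include <string>

std::string FormatValue(int id);  // defined in math_utils.cpp

int main() {
  std::cout << FormatValue(42) << "\n";
  return 0;
}

The TODO above describes the longer-term cleanup: moving such declarations into a shared header, so the declared signature cannot silently drift away from the definition.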

namespace mlir::triton::xla {
namespace {

@@ -540,115 +550,11 @@ struct SparseLocalLoadToLLVMPass
}
};

using ValueTableV2 = std::map<std::array<int, 3>, Value>;

constexpr int kContractingFactor = 2; // implied by N:M (2:4)
constexpr int kCore = 2; // number of core matrices per batch
constexpr int kCoreTile = kCore * kContractingFactor;

// ----- Ampere implementation.
// This replicates the logic in the MMAV2 implementation.
ValueTableV2 getValuesFromDotOperandLayoutStruct(
const LLVMTypeConverter *typeConverter, Location loc,
ConversionPatternRewriter &rewriter, Value value, int batch, int repOuter,
int repK, RankedTensorType type) {
auto elems = unpackLLElements(loc, value, rewriter);
auto eltTy = typeConverter->convertType(type.getElementType());
int offset{};
ValueTableV2 vals;
auto bitwidth = eltTy.getIntOrFloatBitWidth();
auto numElemsPerVec = 32 / bitwidth;
auto vecTy = vec_ty(eltTy, numElemsPerVec);

auto packVec = [&](std::array<int, 3> dstIdx) {
Value vec = undef(vecTy);
for (auto i = 0; i < numElemsPerVec; ++i) {
vec = insert_element(vec, bitcast(elems[offset + i], eltTy), i32_val(i));
}
vals[dstIdx] = bitcast(vec, i32_ty);
offset += numElemsPerVec;
};

auto dot = cast<DotOperandEncodingAttr>(type.getEncoding());
auto kWidth = dot.getKWidth();
auto largeK = bitwidth * kWidth > 32;
if (largeK) {
// For layouts with a large K dimension, the original register layout needs
// to be divided into multiple MMAs, where each MMA has contiguous 32 bits
// along the K dimension per thread.
// Using kWidth = 8 and bitwidth = 2 as an example,
// we split the MMA into 4 sub-MMAs, each with a stride 4 x 32-bit along the
// K dimension.
llvm::SmallVector<unsigned> si;

if (dot.getOpIdx() == 0) {
// Original register layout:
//
// [0, 1, 2, 3, 4, 5, 6, 7], [16, 17, 18, 19, 20, 21, 22, 23]
// [8, 9, 10, 11, 12, 13, 14, 15], [24, 25, 26, 27, 28, 29, 30, 31]
//
// Each element in the layout is a single bf16.
//
// To derive four independent MMA operations, a stride of 4 is applied to
// the original register layout:
//
// 1st MMA: [[0, 1], [8, 9], [16, 17], [24, 25]]
// 2nd MMA: [[2, 3], [10, 11], [18, 19], [26, 27]]
// 3rd MMA: [[4, 5], [12, 13], [20, 21], [28, 29]]
// 4th MMA: [[6, 7], [14, 15], [22, 23], [30, 31]]
for (size_t kRep = 0; kRep < kWidth / numElemsPerVec; ++kRep)
for (size_t tile = 0; tile < 4; ++tile)
for (size_t e = 0; e < numElemsPerVec; ++e) {
si.push_back(kRep * numElemsPerVec + tile * kWidth + e);
}
} else {
// Original register layout:
//
// [0, 1, 2, 3, 4, 5, 6, 7]^T, [8, 9, 10, 11, 12, 13, 14, 15]^T
//
// A stride of 4 is applied to derive four independent MMA operations:
//
// 1st MMA: [[0, 1], [8, 9]]
// 2nd MMA: [[2, 3], [10, 11]]
// 3rd MMA: [[4, 5], [12, 13]]
// 4th MMA: [[6, 7], [14, 15]]
for (size_t kRep = 0; kRep < kWidth / numElemsPerVec; ++kRep)
for (size_t tile = 0; tile < 2; ++tile)
for (size_t e = 0; e < numElemsPerVec; ++e) {
si.push_back(kRep * numElemsPerVec + tile * kWidth + e);
}
}

auto step = si.size();
SmallVector<Value> perm(step);
for (auto i = 0; i < elems.size() / step; ++i) {
for (auto j = 0; j < step; ++j) {
perm[j] = elems[i * step + si[j]];
}
std::copy(perm.begin(), perm.end(), elems.begin() + i * step);
}
}

if (dot.getOpIdx() == 0) {
for (auto b = 0; b < batch; ++b)
for (auto m = 0; m < repOuter; ++m)
for (auto k = 0; k < repK; ++k) {
packVec({b, 2 * m, 2 * k});
packVec({b, 2 * m + 1, 2 * k});
packVec({b, 2 * m, 2 * k + 1});
packVec({b, 2 * m + 1, 2 * k + 1});
}
} else {
for (auto b = 0; b < batch; ++b)
for (auto n = 0; n < repOuter; ++n)
for (auto k = 0; k < repK; ++k) {
packVec({b, n, 2 * k});
packVec({b, n, 2 * k + 1});
}
}
return vals;
}
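
To make the stride-4 splitting described in the comments above concrete, here is a minimal standalone sketch (plain C++, no Triton or MLIR dependencies) that reproduces the operand-A permutation indices. It assumes the bf16 example from the comment, i.e. kWidth = 8 and numElemsPerVec = 32 / 16 = 2:

// Recomputes the shuffle vector `si` for operand A (opIdx == 0) and prints
// the register indices that feed each of the four sub-MMAs.
#include <cstdio>
#include <vector>

int main() {
  const int kWidth = 8;          // registers per thread along K (bf16 example)
  const int numElemsPerVec = 2;  // one 32-bit vector holds two bf16 values
  const int numTiles = 4;        // operand A uses four core-matrix tiles

  std::vector<int> si;
  for (int kRep = 0; kRep < kWidth / numElemsPerVec; ++kRep)
    for (int tile = 0; tile < numTiles; ++tile)
      for (int e = 0; e < numElemsPerVec; ++e)
        si.push_back(kRep * numElemsPerVec + tile * kWidth + e);

  // Expected output, matching the 1st..4th MMA rows in the comment above:
  //   sub-MMA 0: 0 1 8 9 16 17 24 25
  //   sub-MMA 1: 2 3 10 11 18 19 26 27
  //   sub-MMA 2: 4 5 12 13 20 21 28 29
  //   sub-MMA 3: 6 7 14 15 22 23 30 31
  const int indicesPerMma = numTiles * numElemsPerVec;
  for (int rep = 0; rep < kWidth / numElemsPerVec; ++rep) {
    std::printf("sub-MMA %d:", rep);
    for (int j = 0; j < indicesPerMma; ++j)
      std::printf(" %d", si[rep * indicesPerMma + j]);
    std::printf("\n");
  }
  return 0;
}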

std::string getMmaSpPtxInstruction(Type type) {
if (type.isF16()) {
return "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32";
9 changes: 8 additions & 1 deletion xla/service/gpu/tests/sparse_dot_to_llvm_ampere.mlir
@@ -8,7 +8,14 @@
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} {
// CHECK-LABEL: sparse_dot_to_llvm_ampere
tt.func @sparse_dot_to_llvm_ampere(%A_dot: tensor<32x32xf16, #dot_operand_a>, %B_dot: tensor<64x32xf16, #dot_operand_b>, %meta_reg: tensor<32x4xi16, #dot_meta_enc>) {
// CHECK-COUNT-4: mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32
// CHECK: mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32
// CHECK-SAME: (f32, f32, f32, f32, i32, i32, i32, i32, i32, i32, i32, i32, i32) -> !llvm.struct<(f32, f32, f32, f32)>
// CHECK: mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32
// CHECK-SAME: (f32, f32, f32, f32, i32, i32, i32, i32, i32, i32, i32, i32, i32) -> !llvm.struct<(f32, f32, f32, f32)>
// CHECK: mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32
// CHECK-SAME: (f32, f32, f32, f32, i32, i32, i32, i32, i32, i32, i32, i32, i32) -> !llvm.struct<(f32, f32, f32, f32)>
// CHECK: mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32
// CHECK-SAME: (f32, f32, f32, f32, i32, i32, i32, i32, i32, i32, i32, i32, i32) -> !llvm.struct<(f32, f32, f32, f32)>
%acc = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
%D = triton_xla.sparse_dot %A_dot, %B_dot, %acc, %meta_reg : tensor<32x32xf16, #dot_operand_a> meta tensor<32x4xi16, #dot_meta_enc> * tensor<64x32xf16, #dot_operand_b> -> tensor<32x32xf32, #mma>
tt.return