Deduplicate getValuesFromDotOperandLayoutStruct function
PiperOrigin-RevId: 702661150
Google-ML-Automation committed Dec 4, 2024
1 parent dad8d04 commit 5af0f1c
Showing 7 changed files with 75 additions and 138 deletions.
xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc (4 changes: 1 addition & 3 deletions)
@@ -645,9 +645,7 @@ void AddLoweringPasses(mlir::OpPassManager& pm,
   pm.addPass(CreateExpandFloatOpsPass());
   pm.addPass(mlir::createLowerAffinePass());
   pm.addPass(mlir::createConvertSCFToCFPass());
-  bool is_amd = std::holds_alternative<se::RocmComputeCapability>(
-      device.gpu_compute_capability());
-  pm.addPass(CreateLowerToLLVMPass(is_amd));
+  pm.addPass(CreateLowerToLLVMPass(device));
   pm.addPass(mlir::createReconcileUnrealizedCastsPass());
 }
 
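Note: the NVIDIA/AMD decision is no longer made at this call site. The emitter now hands the whole se::DeviceDescription to the pass, and lower_to_llvm.cc (below) derives the backend from its gpu_compute_capability().
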
xla/service/gpu/fusions/transforms/lower_tensors.cc (24 changes: 13 additions & 11 deletions)

@@ -60,6 +60,7 @@ limitations under the License.
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
#include "xla/service/gpu/fusions/transforms/passes.h"
#include "xla/stream_executor/device_description.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
@@ -69,7 +70,6 @@ namespace xla {
 namespace gpu {
 namespace {
 
-#define GEN_PASS_DECL_LOWERTENSORSPASS
 #define GEN_PASS_DEF_LOWERTENSORSPASS
 #include "xla/service/gpu/fusions/transforms/passes.h.inc"
 
@@ -1055,16 +1055,20 @@ class LowerTensorsPass : public impl::LowerTensorsPassBase<LowerTensorsPass> {
   explicit LowerTensorsPass(const LowerTensorsPassOptions& options)
       : LowerTensorsPassBase(options) {}
 
-  void runOnOperation() override {
-    se::GpuDeviceInfoProto device_info;
-    CHECK(tsl::protobuf::TextFormat::ParseFromString(gpu_device_info_,
-                                                     &device_info));
-    se::DeviceDescription device_description(device_info);
+  explicit LowerTensorsPass(const se::DeviceDescription& device_description)
+      : device_description_(device_description) {}
 
+  void runOnOperation() override {
+    if (!gpu_device_info_.empty()) {
+      se::GpuDeviceInfoProto device_info;
+      CHECK(tsl::protobuf::TextFormat::ParseFromString(gpu_device_info_,
+                                                       &device_info));
+      device_description_ = se::DeviceDescription(device_info);
+    }
     MLIRContext* mlir_context = &getContext();
     mlir::RewritePatternSet tensor_patterns(mlir_context);
 
-    tensor_patterns.add<RewriteAtomicRMW>(mlir_context, &device_description);
+    tensor_patterns.add<RewriteAtomicRMW>(mlir_context, &device_description_);
     tensor_patterns
         .add<RewriteAllocateShared, RewriteNonScalarConstants,
              RewriteSyncThreads, RewriteTensorExtract, RewriteTransferRead,
@@ -1115,6 +1119,7 @@ class LowerTensorsPass : public impl::LowerTensorsPassBase<LowerTensorsPass> {
       signalPassFailure();
     });
   }
+  se::DeviceDescription device_description_;
 };
 
 }  // namespace
@@ -1128,10 +1133,7 @@ std::unique_ptr<::mlir::Pass> CreateLowerTensorsPass(
 
 std::unique_ptr<::mlir::Pass> CreateLowerTensorsPass(
     const se::DeviceDescription& device_description) {
-  std::string ascii_proto;
-  CHECK(tsl::protobuf::TextFormat::PrintToString(
-      device_description.ToGpuProto(), &ascii_proto));
-  return CreateLowerTensorsPass(ascii_proto);
+  return std::make_unique<LowerTensorsPass>(device_description);
 }
 
 }  // namespace gpu
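
For orientation, a minimal sketch (not part of this commit) of the two construction paths this leaves for LowerTensorsPass. The helper function and its arguments are hypothetical; only the two CreateLowerTensorsPass overloads come from the diff above:

    #include <string>

    #include "mlir/Pass/PassManager.h"
    #include "xla/service/gpu/fusions/transforms/passes.h"
    #include "xla/stream_executor/device_description.h"

    namespace xla::gpu {

    // Hypothetical helper. The DeviceDescription overload forwards the
    // description straight into the pass; the string overload remains for
    // textual pass pipelines, which can only carry string options, and is
    // parsed lazily in runOnOperation().
    void AddLowerTensors(mlir::OpPassManager& pm,
                         const se::DeviceDescription& device,
                         const std::string& device_info_textproto) {
      pm.addPass(CreateLowerTensorsPass(device));                 // direct
      pm.addPass(CreateLowerTensorsPass(device_info_textproto));  // via proto
    }

    }  // namespace xla::gpu
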
xla/service/gpu/fusions/transforms/lower_to_llvm.cc (49 changes: 35 additions & 14 deletions)

@@ -15,6 +15,7 @@ limitations under the License.
 
 #include <memory>
 #include <utility>
+#include <variant>
 
 #include "llvm/Support/LogicalResult.h"
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
@@ -41,21 +42,32 @@ limitations under the License.
 #include "mlir/Interfaces/DataLayoutInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "xla/service/gpu/fusions/transforms/passes.h"
 #include "xla/stream_executor/device_description.h"
+#include "tsl/platform/protobuf.h"  // IWYU pragma: keep
 
 namespace xla {
 namespace gpu {
-namespace {
 
 #define GEN_PASS_DEF_LOWERTOLLVMPASS
+#define GEN_PASS_DECL_LOWERTOLLVMPASS
 #include "xla/service/gpu/fusions/transforms/passes.h.inc"
 
+namespace {
+
 class LowerToLLVMPass : public impl::LowerToLLVMPassBase<LowerToLLVMPass> {
  public:
-  using LowerToLLVMPassBase::LowerToLLVMPassBase;
+  explicit LowerToLLVMPass(const LowerToLLVMPassOptions& options)
+      : LowerToLLVMPassBase(options) {}
+
+  explicit LowerToLLVMPass(const se::DeviceDescription& device_description)
+      : device_description_(device_description) {}
 
   void runOnOperation() override {
+    if (!gpu_device_info_.empty()) {
+      se::GpuDeviceInfoProto device_info;
+      CHECK(tsl::protobuf::TextFormat::ParseFromString(gpu_device_info_,
+                                                       &device_info));
+      device_description_ = se::DeviceDescription(device_info);
+    }
     // Populate type conversions.
     mlir::LowerToLLVMOptions llvm_opts(&getContext(),
                                        mlir::DataLayout(getOperation()));
@@ -68,11 +80,14 @@ class LowerToLLVMPass : public impl::LowerToLLVMPassBase<LowerToLLVMPass> {
     mlir::arith::populateArithExpandOpsPatterns(patterns);
     mlir::arith::populateArithToLLVMConversionPatterns(type_converter,
                                                        patterns);
-    if (!this->is_amd_gpu_) {
-      mlir::populateGpuToNVVMConversionPatterns(type_converter, patterns);
-    } else {
+    if (std::holds_alternative<se::RocmComputeCapability>(
+            device_description_.gpu_compute_capability())) {
       mlir::populateGpuToROCDLConversionPatterns(
           type_converter, patterns, mlir::gpu::amd::Runtime::Unknown);
+      mlir::configureGpuToROCDLConversionLegality(target);
+    } else {
+      mlir::populateGpuToNVVMConversionPatterns(type_converter, patterns);
+      mlir::configureGpuToNVVMConversionLegality(target);
     }
     mlir::populateFuncToLLVMConversionPatterns(type_converter, patterns);
     mlir::populateVectorToLLVMConversionPatterns(type_converter, patterns);
@@ -81,11 +96,6 @@ class LowerToLLVMPass : public impl::LowerToLLVMPassBase<LowerToLLVMPass> {
     mlir::populateComplexToLLVMConversionPatterns(type_converter, patterns);
 
     // Setup target.
-    if (!this->is_amd_gpu_) {
-      mlir::configureGpuToNVVMConversionLegality(target);
-    } else {
-      mlir::configureGpuToROCDLConversionLegality(target);
-    }
     target.addIllegalDialect<mlir::arith::ArithDialect, mlir::func::FuncDialect,
                              mlir::complex::ComplexDialect>();
     target.addLegalOp<mlir::ModuleOp>();
@@ -107,12 +117,23 @@ class LowerToLLVMPass : public impl::LowerToLLVMPassBase<LowerToLLVMPass> {
       signalPassFailure();
     }
   }
+
+ private:
+  se::DeviceDescription device_description_;
 };
 
 }  // namespace
 
-std::unique_ptr<mlir::Pass> CreateLowerToLLVMPass(bool is_amd_gpu) {
-  return createLowerToLLVMPass(LowerToLLVMPassOptions{is_amd_gpu});
+std::unique_ptr<::mlir::Pass> CreateLowerToLLVMPass(
+    const std::string& gpu_device_info) {
+  LowerToLLVMPassOptions options;
+  options.gpu_device_info_ = gpu_device_info;
+  return std::make_unique<LowerToLLVMPass>(options);
 }
 
+std::unique_ptr<::mlir::Pass> CreateLowerToLLVMPass(
+    const se::DeviceDescription& device_description) {
+  return std::make_unique<LowerToLLVMPass>(device_description);
+}
+
 }  // namespace gpu
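
Callers that only reach the string-based entry point can serialize a DeviceDescription themselves, mirroring what the removed CreateLowerTensorsPass overload did internally. A sketch under that assumption (the helper name is hypothetical; ToGpuProto() and the TextFormat round-trip appear in the deleted code above):

    #include <string>

    #include "mlir/Pass/PassManager.h"
    #include "tsl/platform/logging.h"
    #include "tsl/platform/protobuf.h"
    #include "xla/service/gpu/fusions/transforms/passes.h"
    #include "xla/stream_executor/device_description.h"

    namespace xla::gpu {

    // Hypothetical helper: round-trip a DeviceDescription through the
    // text-format GpuDeviceInfoProto accepted by the gpu_device_info option.
    void AddLowerToLLVMViaTextProto(mlir::OpPassManager& pm,
                                    const se::DeviceDescription& device) {
      std::string ascii_proto;
      CHECK(tsl::protobuf::TextFormat::PrintToString(device.ToGpuProto(),
                                                     &ascii_proto));
      // The pass parses this back into a DeviceDescription in runOnOperation().
      pm.addPass(CreateLowerToLLVMPass(ascii_proto));
    }

    }  // namespace xla::gpu
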
xla/service/gpu/fusions/transforms/passes.h (8 changes: 5 additions & 3 deletions)

@@ -46,11 +46,13 @@ std::unique_ptr<mlir::Pass> CreateEraseDeadFunctionsPass();
 std::unique_ptr<mlir::Pass> CreateExpandFloatOpsPass();
 std::unique_ptr<mlir::Pass> CreateFlattenTensorsPass();
 std::unique_ptr<mlir::Pass> CreateLowerTensorsPass(
-    const std::string& gpu_device_info =
-        "cuda_compute_capability { major: 6 }");
+    const std::string& gpu_device_info = "");
 std::unique_ptr<mlir::Pass> CreateLowerTensorsPass(
     const se::DeviceDescription& device_description);
-std::unique_ptr<mlir::Pass> CreateLowerToLLVMPass(bool use_rocdl);
+std::unique_ptr<mlir::Pass> CreateLowerToLLVMPass(
+    const std::string& gpu_device_info = "");
+std::unique_ptr<mlir::Pass> CreateLowerToLLVMPass(
+    const se::DeviceDescription& device_description);
 std::unique_ptr<mlir::Pass> CreateLowerXlaGpuToScfPass(int64_t warp_size = 32);
 std::unique_ptr<mlir::Pass> CreateLowerXlaGpuLoopsToScfPass();
 std::unique_ptr<mlir::Pass> CreateMergePointersToSameSlicePass();
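
Note: the old default baked a specific device ("cuda_compute_capability { major: 6 }") into the string overload. The empty default now means "no serialized device info", which both passes interpret as "use the DeviceDescription supplied at construction" (see the runOnOperation() guards above).
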
xla/service/gpu/fusions/transforms/passes.td (5 changes: 3 additions & 2 deletions)

@@ -230,9 +230,10 @@ def LowerToLLVMPass :
   ];
 
   let options = [
-    Option<"is_amd_gpu_", "is_amd_gpu", "bool", /*default=*/"false",
-           "True if AMD GPU.">,
+    Option<"gpu_device_info_", "gpu_device_info", "std::string", /*default=*/"",
+           "Serialized stream_executor::GPUDeviceInfo proto.">,
   ];
+  let constructor = "CreateLowerToLLVMPass()";
 }
 
 def VectorizeLoadsAndStoresPass :
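
Note: the added "let constructor" field points the generated pass registration at the hand-written CreateLowerToLLVMPass() factory, so Tablegen stops emitting the default createLowerToLLVMPass() that the old code called. The option's empty default matches the parse-only-if-non-empty logic in runOnOperation().
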
xla/service/gpu/fusions/triton/xla_triton_sparse_passes.cc (114 changes: 10 additions & 104 deletions)

@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include <algorithm>
+#include <array>
 #include <cassert>
 #include <cstdint>
 #include <map>
@@ -76,6 +77,8 @@ using ::mlir::triton::gpu::getShapePerCTATile;
 using ::mlir::triton::gpu::SharedEncodingAttr;
 using ttn::OperandsAndConstraints;
 
+// TODO: b/350928208 - Declare these functions in the header files of the
+// corresponding C++ files and include them here instead of forward-declaring.
 // The functions below are defined in AccelerateMatmul.cpp.
 namespace mlir::triton::gpu {
 SmallVector<unsigned, 3> getWarpsPerTile(
@@ -94,6 +97,13 @@ int64_t getSwizzlingFromLayout(const triton::gpu::SharedEncodingAttr &layout,
 ttn::WGMMAEltType getMmaRetType(Value);
 ttn::WGMMAEltType getMmaOperandType(Value, bool);
 
+// The functions below are defined in MMAv2.cpp.
+using ValueTableV2 = std::map<std::array<int, 3>, Value>;
+ValueTableV2 getValuesFromDotOperandLayoutStruct(
+    const LLVMTypeConverter *typeConverter, Location loc,
+    ConversionPatternRewriter &rewriter, Value value, int batch, int repOuter,
+    int repK, RankedTensorType type);
+
 namespace mlir::triton::xla {
 namespace {
 
@@ -540,115 +550,11 @@ struct SparseLocalLoadToLLVMPass
   }
 };
 
-using ValueTableV2 = std::map<std::array<int, 3>, Value>;
-
 constexpr int kContractingFactor = 2;  // implied by N:M (2:4)
 constexpr int kCore = 2;               // number of core matrices per batch
 constexpr int kCoreTile = kCore * kContractingFactor;
 
 // ----- Ampere implementation.
-// This replicates the logic in the MMAv2 implementation.
-ValueTableV2 getValuesFromDotOperandLayoutStruct(
-    const LLVMTypeConverter *typeConverter, Location loc,
-    ConversionPatternRewriter &rewriter, Value value, int batch, int repOuter,
-    int repK, RankedTensorType type) {
-  auto elems = unpackLLElements(loc, value, rewriter);
-  auto eltTy = typeConverter->convertType(type.getElementType());
-  int offset{};
-  ValueTableV2 vals;
-  auto bitwidth = eltTy.getIntOrFloatBitWidth();
-  auto numElemsPerVec = 32 / bitwidth;
-  auto vecTy = vec_ty(eltTy, numElemsPerVec);
-
-  auto packVec = [&](std::array<int, 3> dstIdx) {
-    Value vec = undef(vecTy);
-    for (auto i = 0; i < numElemsPerVec; ++i) {
-      vec = insert_element(vec, bitcast(elems[offset + i], eltTy), i32_val(i));
-    }
-    vals[dstIdx] = bitcast(vec, i32_ty);
-    offset += numElemsPerVec;
-  };
-
-  auto dot = cast<DotOperandEncodingAttr>(type.getEncoding());
-  auto kWidth = dot.getKWidth();
-  auto largeK = bitwidth * kWidth > 32;
-  if (largeK) {
-    // For layouts with a large K dimension, the original register layout needs
-    // to be divided into multiple MMAs, where each MMA has contiguous 32 bits
-    // along the K dimension per thread.
-    // Using kWidth = 8 and bitwidth = 2 as an example,
-    // we split the MMA into 4 sub-MMAs, each with a stride 4 x 32-bit along the
-    // K dimension.
-    llvm::SmallVector<unsigned> si;
-
-    if (dot.getOpIdx() == 0) {
-      // Original register layout:
-      //
-      //   [0, 1, 2, 3, 4, 5, 6, 7], [16, 17, 18, 19, 20, 21, 22, 23]
-      //   [8, 9, 10, 11, 12, 13, 14, 15], [24, 25, 26, 27, 28, 29, 30, 31]
-      //
-      // Each element in the layout is a single bf16.
-      //
-      // To derive four independent MMA operations, a stride of 4 is applied to
-      // the original register layout:
-      //
-      //   1st MMA: [[0, 1], [8, 9], [16, 17], [24, 25]]
-      //   2nd MMA: [[2, 3], [10, 11], [18, 19], [26, 27]]
-      //   3rd MMA: [[4, 5], [12, 13], [20, 21], [28, 29]]
-      //   4th MMA: [[6, 7], [14, 15], [22, 23], [30, 31]]
-      for (size_t kRep = 0; kRep < kWidth / numElemsPerVec; ++kRep)
-        for (size_t tile = 0; tile < 4; ++tile)
-          for (size_t e = 0; e < numElemsPerVec; ++e) {
-            si.push_back(kRep * numElemsPerVec + tile * kWidth + e);
-          }
-    } else {
-      // Original register layout:
-      //
-      //   [0, 1, 2, 3, 4, 5, 6, 7]^T, [8, 9, 10, 11, 12, 13, 14, 15]^T
-      //
-      // A stride of 4 is applied to derive four independent MMA operations:
-      //
-      //   1st MMA: [[0, 1], [8, 9]]
-      //   2nd MMA: [[2, 3], [10, 11]]
-      //   3rd MMA: [[4, 5], [12, 13]]
-      //   4th MMA: [[6, 7], [14, 15]]
-      for (size_t kRep = 0; kRep < kWidth / numElemsPerVec; ++kRep)
-        for (size_t tile = 0; tile < 2; ++tile)
-          for (size_t e = 0; e < numElemsPerVec; ++e) {
-            si.push_back(kRep * numElemsPerVec + tile * kWidth + e);
-          }
-    }
-
-    auto step = si.size();
-    SmallVector<Value> perm(step);
-    for (auto i = 0; i < elems.size() / step; ++i) {
-      for (auto j = 0; j < step; ++j) {
-        perm[j] = elems[i * step + si[j]];
-      }
-      std::copy(perm.begin(), perm.end(), elems.begin() + i * step);
-    }
-  }
-
-  if (dot.getOpIdx() == 0) {
-    for (auto b = 0; b < batch; ++b)
-      for (auto m = 0; m < repOuter; ++m)
-        for (auto k = 0; k < repK; ++k) {
-          packVec({b, 2 * m, 2 * k});
-          packVec({b, 2 * m + 1, 2 * k});
-          packVec({b, 2 * m, 2 * k + 1});
-          packVec({b, 2 * m + 1, 2 * k + 1});
-        }
-  } else {
-    for (auto b = 0; b < batch; ++b)
-      for (auto n = 0; n < repOuter; ++n)
-        for (auto k = 0; k < repK; ++k) {
-          packVec({b, n, 2 * k});
-          packVec({b, n, 2 * k + 1});
-        }
-  }
-  return vals;
-}
-
 std::string getMmaSpPtxInstruction(Type type) {
   if (type.isF16()) {
     return "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32";
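
With this change the ~100-line helper deleted above has a single definition, in Triton's MMAv2.cpp, reached through the forward declaration added near the top of this file. Purely to illustrate the shared signature (the real call sites are unchanged and outside this diff; every variable name below is hypothetical):

    // Hypothetical call: unpack one LLVM-struct operand of a sparse MMA into
    // the (batch, outer tile, k tile) -> packed-i32 table used to emit mma.sp.
    ValueTableV2 ha = getValuesFromDotOperandLayoutStruct(
        typeConverter, loc, rewriter, loadedA,
        /*batch=*/1, /*repOuter=*/repM, /*repK=*/repK, aTensorTy);
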
xla/service/gpu/tests/sparse_dot_to_llvm_ampere.mlir (9 changes: 8 additions & 1 deletion)

@@ -8,7 +8,14 @@
 module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} {
   // CHECK-LABEL: sparse_dot_to_llvm_ampere
   tt.func @sparse_dot_to_llvm_ampere(%A_dot: tensor<32x32xf16, #dot_operand_a>, %B_dot: tensor<64x32xf16, #dot_operand_b>, %meta_reg: tensor<32x4xi16, #dot_meta_enc>) {
-    // CHECK-COUNT-4: mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32
+    // CHECK: mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32
+    // CHECK-SAME: (f32, f32, f32, f32, i32, i32, i32, i32, i32, i32, i32, i32, i32) -> !llvm.struct<(f32, f32, f32, f32)>
+    // CHECK: mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32
+    // CHECK-SAME: (f32, f32, f32, f32, i32, i32, i32, i32, i32, i32, i32, i32, i32) -> !llvm.struct<(f32, f32, f32, f32)>
+    // CHECK: mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32
+    // CHECK-SAME: (f32, f32, f32, f32, i32, i32, i32, i32, i32, i32, i32, i32, i32) -> !llvm.struct<(f32, f32, f32, f32)>
+    // CHECK: mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32
+    // CHECK-SAME: (f32, f32, f32, f32, i32, i32, i32, i32, i32, i32, i32, i32, i32) -> !llvm.struct<(f32, f32, f32, f32)>
     %acc = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
     %D = triton_xla.sparse_dot %A_dot, %B_dot, %acc, %meta_reg : tensor<32x32xf16, #dot_operand_a> meta tensor<32x4xi16, #dot_meta_enc> * tensor<64x32xf16, #dot_operand_b> -> tensor<32x32xf32, #mma>
     tt.return
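
Note: CHECK-COUNT-4 is unrolled into four explicit CHECK lines because a CHECK-SAME only extends the single preceding match; pairing it with CHECK-COUNT-4 would verify the signature once rather than on each instruction. Each of the four mma.sp instructions is now also checked to take four f32 accumulator inputs plus nine i32 operands and to return !llvm.struct<(f32, f32, f32, f32)>.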
