[XLA:GPU][Emitters] Remove the complex.expm1 approximation. #19519

Merged · 1 commit · Dec 3, 2024
116 changes: 3 additions & 113 deletions xla/service/gpu/fusions/transforms/lower_tensors.cc
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <array>
#include <cassert>
#include <cstdint>
#include <memory>
@@ -31,13 +30,11 @@ limitations under the License.
#include "llvm/Support/LogicalResult.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -74,26 +71,19 @@ namespace {
#define GEN_PASS_DEF_LOWERTENSORSPASS
#include "xla/service/gpu/fusions/transforms/passes.h.inc"

using llvm::APFloat;
using llvm::ArrayRef;
using mlir::failure;
using mlir::ImplicitLocOpBuilder;
using mlir::Location;
using mlir::LogicalResult;
using mlir::MLIRContext;
using mlir::OpBuilder;
using mlir::Operation;
using mlir::OpRewritePattern;
using mlir::success;
using mlir::Type;
using mlir::TypedValue;
using mlir::TypeRange;
using mlir::Value;
using mlir::ValueRange;

namespace ma = ::mlir::arith;
namespace mc = ::mlir::complex;
namespace mm = ::mlir::math;
namespace arith = ::mlir::arith;
namespace scf = ::mlir::scf;
namespace ml = ::mlir::LLVM;
@@ -249,7 +239,7 @@ struct RewriteTensorExtract : mlir::OpRewritePattern<mlir::tensor::ExtractOp> {
load, b.create<mlir::arith::ConstantIntOp>(4, load.getType()));
load = b.create<mlir::arith::TruncIOp>(
op.getType(),
b.create<ma::SelectOp>(is_low_nibble, load, high_value));
b.create<mlir::arith::SelectOp>(is_low_nibble, load, high_value));
}

rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
@@ -378,7 +368,7 @@ struct RewriteTensorInsert : mlir::OpRewritePattern<mlir::tensor::InsertOp> {
body_builder.create<mlir::arith::ShLIOp>(
scalar_value,
body_builder.create<mlir::arith::ConstantIntOp>(4, ty)));
Value new_value = body_builder.create<ma::SelectOp>(
Value new_value = body_builder.create<mlir::arith::SelectOp>(
is_low_nibble, low_updated, high_updated);
body_builder.create<mlir::scf::YieldOp>(new_value);
Value casted_result = b.create<mlir::UnrealizedConversionCastOp>(
@@ -1052,106 +1042,6 @@ class RewriteAtomicRMW : public mlir::OpRewritePattern<AtomicRMWOp> {
std::string gpu_arch_;
};

template <typename FType>
Value EvaluatePolynomial(ImplicitLocOpBuilder& b, Value arg,
ArrayRef<FType> coefficients) {
auto arg_type = mlir::cast<mlir::FloatType>(arg.getType());
Value poly =
b.create<ma::ConstantOp>(b.getFloatAttr(arg_type, coefficients[0]));
for (int i = 1; i < coefficients.size(); ++i) {
poly = b.create<mm::FmaOp>(
poly, arg,
b.create<ma::ConstantOp>(b.getFloatAttr(arg_type, coefficients[i])));
}
return poly;
};

struct RewriterExpm1Op : public OpRewritePattern<mc::Expm1Op> {
using OpRewritePattern<mc::Expm1Op>::OpRewritePattern;

// e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
// [handle inaccuracies when a and/or b are small]
// = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i
// = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i
mlir::LogicalResult matchAndRewrite(
mc::Expm1Op op, mlir::PatternRewriter& rewriter) const override {
auto type = op.getType();
auto element_type = mlir::cast<mlir::FloatType>(type.getElementType());

ImplicitLocOpBuilder b(op.getLoc(), rewriter);

Value real = b.create<mc::ReOp>(op.getComplex());
Value imag = b.create<mc::ImOp>(op.getComplex());

Value zero = b.create<ma::ConstantOp>(b.getFloatAttr(element_type, 0.0));
Value one = b.create<ma::ConstantOp>(b.getFloatAttr(element_type, 1.0));

Value expm1_real = b.create<mm::ExpM1Op>(real);
Value exp_real = b.create<ma::AddFOp>(expm1_real, one);

Value sin_imag = b.create<mm::SinOp>(imag);
Value cosm1_imag = EmitCosm1(imag, b);
Value cos_imag = b.create<ma::AddFOp>(cosm1_imag, one);

Value real_result = b.create<ma::AddFOp>(
b.create<ma::MulFOp>(expm1_real, cos_imag), cosm1_imag);

Value imag_is_zero =
b.create<ma::CmpFOp>(ma::CmpFPredicate::OEQ, imag, zero);
Value imag_result = b.create<ma::SelectOp>(
imag_is_zero, zero, b.create<ma::MulFOp>(exp_real, sin_imag));

rewriter.replaceOpWithNewOp<mc::CreateOp>(op, type, real_result,
imag_result);
return mlir::success();
}

private:
Value EmitCosm1(Value arg, ImplicitLocOpBuilder& b) const {
auto arg_type = mlir::cast<mlir::FloatType>(arg.getType());
auto negative_half =
b.create<ma::ConstantOp>(b.getFloatAttr(arg_type, -0.5));
auto negative_one =
b.create<ma::ConstantOp>(b.getFloatAttr(arg_type, -1.0));

// Algorithm copied from cephes cosm1:
// cosm1(x) = -0.5 * x^2 + x^4 * P(x^2);
// that is suitable when abs(x) < pi/4, otherwise we'll use cos(x)-1.
//
// This is an alternative algorithm
// cosm1(x) = -2 * sin(x/2)^2
// that is only slightly less accurate around abs(x) == 0.1 but
// otherwise equivalent accuracy-wise compared to cephes cosm1.
// However, we are not using it because it is notably less
// performant than cephes cosm1.

// TODO: define cosm1(x) as cosm1(x mod (2*pi)) to increase accuracy
// for large x values that are close to 2*pi*n where n is some integer.
static const std::array<double, 7> kCoeffs{
4.7377507964246204691685E-14, -1.1470284843425359765671E-11,
2.0876754287081521758361E-9, -2.7557319214999787979814E-7,
2.4801587301570552304991E-5, -1.3888888888888872993737E-3,
4.1666666666666666609054E-2,
};
Value cos = b.create<mm::CosOp>(arg);
Value for_large_x = b.create<ma::AddFOp>(cos, negative_one);

Value arg_pow_2 = b.create<ma::MulFOp>(arg, arg);
Value arg_pow_4 = b.create<ma::MulFOp>(arg_pow_2, arg_pow_2);
Value poly = EvaluatePolynomial(b, arg_pow_2, ArrayRef<double>(kCoeffs));

auto for_small_x =
b.create<ma::AddFOp>(b.create<ma::MulFOp>(arg_pow_4, poly),
b.create<ma::MulFOp>(negative_half, arg_pow_2));

// (pi/4)^2 is approximately 0.61685
Value cond = b.create<ma::CmpFOp>(
ma::CmpFPredicate::OGE, arg_pow_2,
b.create<ma::ConstantOp>(b.getFloatAttr(arg_type, 0.61685)));
return b.create<ma::SelectOp>(cond, for_large_x, for_small_x);
}
};

class LowerTensorsPass : public impl::LowerTensorsPassBase<LowerTensorsPass> {
public:
explicit LowerTensorsPass(const LowerTensorsPassOptions& options)
@@ -1162,7 +1052,7 @@ class LowerTensorsPass : public impl::LowerTensorsPassBase<LowerTensorsPass> {
mlir::RewritePatternSet tensor_patterns(mlir_context);
tensor_patterns.add<RewriteAtomicRMW>(mlir_context, is_amd_gpu_, gpu_arch_);
tensor_patterns
.add<RewriteAllocateShared, RewriterExpm1Op, RewriteNonScalarConstants,
.add<RewriteAllocateShared, RewriteNonScalarConstants,
RewriteSyncThreads, RewriteTensorExtract, RewriteTransferRead,
RewriteTensorInsert, RewriteTransferWrite>(mlir_context);
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
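For reference, the approximation this PR deletes can be sketched as a standalone C++ program (a minimal sketch: the coefficients, the (pi/4)^2 threshold, and the rewrite identity are copied from the diff above, but names such as ComplexExpm1 are illustrative, the sketch is fixed to double precision where the MLIR pattern was element-type generic, and it is not XLA API):

#include <array>
#include <cmath>
#include <complex>
#include <cstdio>

// Horner-scheme polynomial evaluation, mirroring the deleted
// EvaluatePolynomial helper: one fused multiply-add per coefficient
// after the first.
double EvaluatePolynomial(double arg, const std::array<double, 7>& coeffs) {
  double poly = coeffs[0];
  for (size_t i = 1; i < coeffs.size(); ++i) {
    poly = std::fma(poly, arg, coeffs[i]);
  }
  return poly;
}

// cos(x) - 1 without catastrophic cancellation near x == 0, following
// the cephes cosm1 algorithm from the deleted EmitCosm1:
//   cosm1(x) = -0.5*x^2 + x^4 * P(x^2)  when x^2 < (pi/4)^2 ~= 0.61685,
//   cosm1(x) = cos(x) - 1               otherwise.
double Cosm1(double x) {
  static const std::array<double, 7> kCoeffs{
      4.7377507964246204691685E-14, -1.1470284843425359765671E-11,
      2.0876754287081521758361E-9,  -2.7557319214999787979814E-7,
      2.4801587301570552304991E-5,  -1.3888888888888872993737E-3,
      4.1666666666666666609054E-2,
  };
  double x2 = x * x;
  if (x2 >= 0.61685) return std::cos(x) - 1.0;
  double x4 = x2 * x2;
  return x4 * EvaluatePolynomial(x2, kCoeffs) - 0.5 * x2;
}

// The rewrite from the deleted RewriterExpm1Op:
//   e^(a+bi) - 1 = (expm1(a)*cos(b) + cosm1(b)) + e^a*sin(b)*i,
// with the imaginary part pinned to zero when b == 0.
std::complex<double> ComplexExpm1(std::complex<double> z) {
  double expm1_re = std::expm1(z.real());
  double cosm1_im = Cosm1(z.imag());
  double re = expm1_re * (cosm1_im + 1.0) + cosm1_im;
  double im = z.imag() == 0.0 ? 0.0 : (expm1_re + 1.0) * std::sin(z.imag());
  return {re, im};
}

int main() {
  // For tiny inputs the naive exp(z) - 1 cancels badly; the rewritten
  // form keeps the leading digits.
  std::complex<double> z{1e-9, 1e-9};
  std::complex<double> stable = ComplexExpm1(z);
  std::complex<double> naive = std::exp(z) - 1.0;
  std::printf("stable: %.17g + %.17gi\n", stable.real(), stable.imag());
  std::printf("naive:  %.17g + %.17gi\n", naive.real(), naive.imag());
}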
2 changes: 0 additions & 2 deletions xla/service/gpu/fusions/transforms/passes.td
@@ -74,10 +74,8 @@ def LowerTensorsPass : Pass<"xla-gpu-lower-tensors", "mlir::ModuleOp"> {

let dependentDialects = [
"mlir::LLVM::LLVMDialect",
"mlir::complex::ComplexDialect",
"mlir::func::FuncDialect",
"mlir::gpu::GPUDialect",
"mlir::math::MathDialect",
"mlir::scf::SCFDialect",
"mlir::tensor::TensorDialect",
"xla::gpu::XlaGpuDialect",
12 changes: 0 additions & 12 deletions xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir
@@ -732,15 +732,3 @@ func.func @int4_constant(%arg0: tensor<3xi4>, %arg1: index) -> i4 {
// CHECK: llvm.mlir.global private constant
// CHECK-SAME: dense<[18, 48]>
// CHECK-LABEL: @int4_constant

// -----

func.func @complex_expm1_approx(%arg0: tensor<3xcomplex<f32>>, %i: index)
-> complex<f32> {
%extracted = tensor.extract %arg0[%i] : tensor<3xcomplex<f32>>
%expm1 = complex.expm1 %extracted : complex<f32>
return %expm1 : complex<f32>
}
// CHECK-LABEL: @complex_expm1_approx
// CHECK: math.expm1
// CHECK-COUNT-6: math.fma
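The CHECK-COUNT-6 directive in the deleted test follows from the shape of the lowering: Horner evaluation of the seven cosm1 coefficients issues one math.fma per coefficient after the first, i.e. exactly six. A small unrolled sketch of that arithmetic (illustrative function name; coefficients copied from the diff above):

#include <cmath>

// Unrolled Horner evaluation of the seven cosm1 coefficients: exactly
// six fused multiply-adds, matching the deleted CHECK-COUNT-6 line.
double Cosm1Poly(double x2) {
  double p = 4.7377507964246204691685E-14;
  p = std::fma(p, x2, -1.1470284843425359765671E-11);  // fma 1
  p = std::fma(p, x2, 2.0876754287081521758361E-9);    // fma 2
  p = std::fma(p, x2, -2.7557319214999787979814E-7);   // fma 3
  p = std::fma(p, x2, 2.4801587301570552304991E-5);    // fma 4
  p = std::fma(p, x2, -1.3888888888888872993737E-3);   // fma 5
  p = std::fma(p, x2, 4.1666666666666666609054E-2);    // fma 6
  return p;
}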