Skip to content

Commit

Permalink
[XLA:GPU][Emitters] Remove the complex.expm1 approximation.
Browse files Browse the repository at this point in the history
It was upstreamed in llvm/llvm-project#115082 (review).
Now we can use the complex-to-standard pass instead of the local approximation.

Reverts d2e313c

PiperOrigin-RevId: 698191660
  • Loading branch information
pifon2a authored and Google-ML-Automation committed Dec 3, 2024
1 parent e9947dd commit ce6d35e
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 127 deletions.
116 changes: 3 additions & 113 deletions xla/service/gpu/fusions/transforms/lower_tensors.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <array>
#include <cassert>
#include <cstdint>
#include <memory>
Expand All @@ -31,13 +30,11 @@ limitations under the License.
#include "llvm/Support/LogicalResult.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
Expand Down Expand Up @@ -74,26 +71,19 @@ namespace {
#define GEN_PASS_DEF_LOWERTENSORSPASS
#include "xla/service/gpu/fusions/transforms/passes.h.inc"

using llvm::APFloat;
using llvm::ArrayRef;
using mlir::failure;
using mlir::ImplicitLocOpBuilder;
using mlir::Location;
using mlir::LogicalResult;
using mlir::MLIRContext;
using mlir::OpBuilder;
using mlir::Operation;
using mlir::OpRewritePattern;
using mlir::success;
using mlir::Type;
using mlir::TypedValue;
using mlir::TypeRange;
using mlir::Value;
using mlir::ValueRange;

namespace ma = ::mlir::arith;
namespace mc = ::mlir::complex;
namespace mm = ::mlir::math;
namespace arith = ::mlir::arith;
namespace scf = ::mlir::scf;
namespace ml = ::mlir::LLVM;
Expand Down Expand Up @@ -249,7 +239,7 @@ struct RewriteTensorExtract : mlir::OpRewritePattern<mlir::tensor::ExtractOp> {
load, b.create<mlir::arith::ConstantIntOp>(4, load.getType()));
load = b.create<mlir::arith::TruncIOp>(
op.getType(),
b.create<ma::SelectOp>(is_low_nibble, load, high_value));
b.create<mlir::arith::SelectOp>(is_low_nibble, load, high_value));
}

rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
Expand Down Expand Up @@ -378,7 +368,7 @@ struct RewriteTensorInsert : mlir::OpRewritePattern<mlir::tensor::InsertOp> {
body_builder.create<mlir::arith::ShLIOp>(
scalar_value,
body_builder.create<mlir::arith::ConstantIntOp>(4, ty)));
Value new_value = body_builder.create<ma::SelectOp>(
Value new_value = body_builder.create<mlir::arith::SelectOp>(
is_low_nibble, low_updated, high_updated);
body_builder.create<mlir::scf::YieldOp>(new_value);
Value casted_result = b.create<mlir::UnrealizedConversionCastOp>(
Expand Down Expand Up @@ -1052,106 +1042,6 @@ class RewriteAtomicRMW : public mlir::OpRewritePattern<AtomicRMWOp> {
std::string gpu_arch_;
};

// Evaluates the polynomial given by `coefficients` at `arg` using Horner's
// scheme, fusing each multiply-add pair into a single `math.fma` op.
// Coefficients are ordered from the highest-degree term down to the constant
// term, i.e. the result is c[0]*arg^(n-1) + c[1]*arg^(n-2) + ... + c[n-1].
//
// `arg` must have a float type; `coefficients` must be non-empty.
template <typename FType>
Value EvaluatePolynomial(ImplicitLocOpBuilder& b, Value arg,
                         ArrayRef<FType> coefficients) {
  assert(!coefficients.empty() && "polynomial needs at least one coefficient");
  auto arg_type = mlir::cast<mlir::FloatType>(arg.getType());
  Value poly =
      b.create<ma::ConstantOp>(b.getFloatAttr(arg_type, coefficients[0]));
  // size_t index avoids a signed/unsigned comparison against size().
  for (size_t i = 1; i < coefficients.size(); ++i) {
    poly = b.create<mm::FmaOp>(
        poly, arg,
        b.create<ma::ConstantOp>(b.getFloatAttr(arg_type, coefficients[i])));
  }
  return poly;
}

// Lowers `complex.expm1` to real-valued arithmetic with better accuracy than
// the naive exp(z)-1 expansion for inputs near zero. (This local pattern was
// later upstreamed to MLIR's complex-to-standard conversion.)
struct RewriterExpm1Op : public OpRewritePattern<mc::Expm1Op> {
  using OpRewritePattern<mc::Expm1Op>::OpRewritePattern;

  // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
  // [handle inaccuracies when a and/or b are small]
  // = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i
  // = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i
  mlir::LogicalResult matchAndRewrite(
      mc::Expm1Op op, mlir::PatternRewriter& rewriter) const override {
    auto type = op.getType();
    auto element_type = mlir::cast<mlir::FloatType>(type.getElementType());

    ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    // Split the complex operand z = a + bi into its real/imag parts.
    Value real = b.create<mc::ReOp>(op.getComplex());
    Value imag = b.create<mc::ImOp>(op.getComplex());

    Value zero = b.create<ma::ConstantOp>(b.getFloatAttr(element_type, 0.0));
    Value one = b.create<ma::ConstantOp>(b.getFloatAttr(element_type, 1.0));

    // expm1(a) keeps precision for small a; e^a is recovered by adding 1.
    Value expm1_real = b.create<mm::ExpM1Op>(real);
    Value exp_real = b.create<ma::AddFOp>(expm1_real, one);

    Value sin_imag = b.create<mm::SinOp>(imag);
    // cosm1(b) = cos(b) - 1, computed accurately for small b (see below).
    Value cosm1_imag = EmitCosm1(imag, b);
    Value cos_imag = b.create<ma::AddFOp>(cosm1_imag, one);

    // Real part: expm1(a) * cos(b) + cosm1(b), per the derivation above.
    Value real_result = b.create<ma::AddFOp>(
        b.create<ma::MulFOp>(expm1_real, cos_imag), cosm1_imag);

    // Imag part: e^a * sin(b). When b == 0 the result is forced to exactly
    // zero; NOTE(review): presumably this avoids inf * 0 = NaN when a is
    // large and preserves the sign of zero — confirm against the upstream
    // complex-to-standard lowering.
    Value imag_is_zero =
        b.create<ma::CmpFOp>(ma::CmpFPredicate::OEQ, imag, zero);
    Value imag_result = b.create<ma::SelectOp>(
        imag_is_zero, zero, b.create<ma::MulFOp>(exp_real, sin_imag));

    rewriter.replaceOpWithNewOp<mc::CreateOp>(op, type, real_result,
                                              imag_result);
    return mlir::success();
  }

 private:
  // Emits IR computing cos(arg) - 1 without catastrophic cancellation for
  // small |arg|. Returns a Value of the same float type as `arg`.
  Value EmitCosm1(Value arg, ImplicitLocOpBuilder& b) const {
    auto arg_type = mlir::cast<mlir::FloatType>(arg.getType());
    auto negative_half =
        b.create<ma::ConstantOp>(b.getFloatAttr(arg_type, -0.5));
    auto negative_one =
        b.create<ma::ConstantOp>(b.getFloatAttr(arg_type, -1.0));

    // Algorithm copied from cephes cosm1:
    // cosm1(x) = -0.5 * x^2 + x^4 * P(x^2);
    // that is suitable when abs(x) < pi/4, otherwise we'll use cos(x)-1.
    //
    // This is an alternative algorithm
    // cosm1(x) = -2 * sin(x/2)^2
    // that is only slightly less accurate around abs(x) == 0.1 but
    // otherwise equivalent accuracy-wise compared to cephes cosm1.
    // However, we are not using it because it is notably less
    // performant than cephes cosm1.

    // TODO: define cosm1(x) as cosm1(x mod (2*pi)) to increase accuracy
    // for large x values that are close to 2*pi*n where n is some integer.

    // Coefficients of P(x^2), highest degree first (cephes cosm1 table).
    static const std::array<double, 7> kCoeffs{
        4.7377507964246204691685E-14, -1.1470284843425359765671E-11,
        2.0876754287081521758361E-9,  -2.7557319214999787979814E-7,
        2.4801587301570552304991E-5,  -1.3888888888888872993737E-3,
        4.1666666666666666609054E-2,
    };
    // Fallback for |arg| >= pi/4: plain cos(arg) - 1.
    Value cos = b.create<mm::CosOp>(arg);
    Value for_large_x = b.create<ma::AddFOp>(cos, negative_one);

    Value arg_pow_2 = b.create<ma::MulFOp>(arg, arg);
    Value arg_pow_4 = b.create<ma::MulFOp>(arg_pow_2, arg_pow_2);
    Value poly = EvaluatePolynomial(b, arg_pow_2, ArrayRef<double>(kCoeffs));

    // Small-|arg| path: x^4 * P(x^2) - 0.5 * x^2.
    auto for_small_x =
        b.create<ma::AddFOp>(b.create<ma::MulFOp>(arg_pow_4, poly),
                             b.create<ma::MulFOp>(negative_half, arg_pow_2));

    // (pi/4)^2 is approximately 0.61685
    Value cond = b.create<ma::CmpFOp>(
        ma::CmpFPredicate::OGE, arg_pow_2,
        b.create<ma::ConstantOp>(b.getFloatAttr(arg_type, 0.61685)));
    return b.create<ma::SelectOp>(cond, for_large_x, for_small_x);
  }
};

class LowerTensorsPass : public impl::LowerTensorsPassBase<LowerTensorsPass> {
public:
explicit LowerTensorsPass(const LowerTensorsPassOptions& options)
Expand All @@ -1162,7 +1052,7 @@ class LowerTensorsPass : public impl::LowerTensorsPassBase<LowerTensorsPass> {
mlir::RewritePatternSet tensor_patterns(mlir_context);
tensor_patterns.add<RewriteAtomicRMW>(mlir_context, is_amd_gpu_, gpu_arch_);
tensor_patterns
.add<RewriteAllocateShared, RewriterExpm1Op, RewriteNonScalarConstants,
.add<RewriteAllocateShared, RewriteNonScalarConstants,
RewriteSyncThreads, RewriteTensorExtract, RewriteTransferRead,
RewriteTensorInsert, RewriteTransferWrite>(mlir_context);
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
Expand Down
2 changes: 0 additions & 2 deletions xla/service/gpu/fusions/transforms/passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,8 @@ def LowerTensorsPass : Pass<"xla-gpu-lower-tensors", "mlir::ModuleOp"> {

let dependentDialects = [
"mlir::LLVM::LLVMDialect",
"mlir::complex::ComplexDialect",
"mlir::func::FuncDialect",
"mlir::gpu::GPUDialect",
"mlir::math::MathDialect",
"mlir::scf::SCFDialect",
"mlir::tensor::TensorDialect",
"xla::gpu::XlaGpuDialect",
Expand Down
12 changes: 0 additions & 12 deletions xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -732,15 +732,3 @@ func.func @int4_constant(%arg0: tensor<3xi4>, %arg1: index) -> i4 {
// CHECK: llvm.mlir.global private constant
// CHECK-SAME: dense<[18, 48]>
// CHECK-LABEL: @int4_constant

// -----

// Regression test (removed by this commit): verifies that complex.expm1 is
// expanded by the local approximation pattern into math.expm1 plus the
// polynomial cosm1 evaluation — the degree-7 coefficient table yields
// exactly 6 fused multiply-adds.
func.func @complex_expm1_approx(%arg0: tensor<3xcomplex<f32>>, %i: index)
    -> complex<f32> {
  %extracted = tensor.extract %arg0[%i] : tensor<3xcomplex<f32>>
  %expm1 = complex.expm1 %extracted : complex<f32>
  return %expm1 : complex<f32>
}
// CHECK-LABEL: @complex_expm1_approx
// CHECK: math.expm1
// CHECK-COUNT-6: math.fma

0 comments on commit ce6d35e

Please sign in to comment.