diff --git a/xla/service/gpu/fusions/transforms/lower_tensors.cc b/xla/service/gpu/fusions/transforms/lower_tensors.cc index 8602630302c3c9..87238edacec1d3 100644 --- a/xla/service/gpu/fusions/transforms/lower_tensors.cc +++ b/xla/service/gpu/fusions/transforms/lower_tensors.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include @@ -31,13 +30,11 @@ limitations under the License. #include "llvm/Support/LogicalResult.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" -#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -74,16 +71,12 @@ namespace { #define GEN_PASS_DEF_LOWERTENSORSPASS #include "xla/service/gpu/fusions/transforms/passes.h.inc" -using llvm::APFloat; -using llvm::ArrayRef; using mlir::failure; -using mlir::ImplicitLocOpBuilder; using mlir::Location; using mlir::LogicalResult; using mlir::MLIRContext; using mlir::OpBuilder; using mlir::Operation; -using mlir::OpRewritePattern; using mlir::success; using mlir::Type; using mlir::TypedValue; @@ -91,9 +84,6 @@ using mlir::TypeRange; using mlir::Value; using mlir::ValueRange; -namespace ma = ::mlir::arith; -namespace mc = ::mlir::complex; -namespace mm = ::mlir::math; namespace arith = ::mlir::arith; namespace scf = ::mlir::scf; namespace ml = ::mlir::LLVM; @@ -249,7 +239,7 @@ struct RewriteTensorExtract : mlir::OpRewritePattern { load, b.create(4, load.getType())); load = b.create( op.getType(), - b.create(is_low_nibble, load, high_value)); + b.create(is_low_nibble, load, high_value)); } rewriter.replaceOpWithNewOp( @@ -378,7 +368,7 @@ struct RewriteTensorInsert : mlir::OpRewritePattern { body_builder.create( scalar_value, body_builder.create(4, ty))); - Value new_value = body_builder.create( + Value new_value = body_builder.create( is_low_nibble, low_updated, high_updated); body_builder.create(new_value); Value casted_result = b.create( @@ -1052,106 +1042,6 @@ class RewriteAtomicRMW : public mlir::OpRewritePattern { std::string gpu_arch_; }; -template -Value EvaluatePolynomial(ImplicitLocOpBuilder& b, Value arg, - ArrayRef coefficients) { - auto arg_type = mlir::cast(arg.getType()); - Value poly = - b.create(b.getFloatAttr(arg_type, coefficients[0])); - for (int i = 1; i < coefficients.size(); ++i) { - poly = b.create( - poly, arg, - b.create(b.getFloatAttr(arg_type, coefficients[i]))); - } - return poly; -}; - -struct RewriterExpm1Op : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i - // [handle inaccuracies when a and/or b are small] - // = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i - // = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i - mlir::LogicalResult matchAndRewrite( - mc::Expm1Op op, mlir::PatternRewriter& rewriter) const override { - auto type = op.getType(); - auto element_type = mlir::cast(type.getElementType()); - - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - - Value real = b.create(op.getComplex()); - Value imag = b.create(op.getComplex()); - - Value zero = b.create(b.getFloatAttr(element_type, 0.0)); - Value one = b.create(b.getFloatAttr(element_type, 1.0)); - - Value expm1_real = b.create(real); - Value exp_real = b.create(expm1_real, one); - - Value sin_imag = b.create(imag); - Value cosm1_imag = EmitCosm1(imag, b); - Value cos_imag = b.create(cosm1_imag, one); - - Value real_result = b.create( - b.create(expm1_real, cos_imag), cosm1_imag); - - Value imag_is_zero = - b.create(ma::CmpFPredicate::OEQ, imag, zero); - Value imag_result = b.create( - imag_is_zero, zero, b.create(exp_real, sin_imag)); - - rewriter.replaceOpWithNewOp(op, type, real_result, - imag_result); - return mlir::success(); - } - - private: - Value EmitCosm1(Value arg, ImplicitLocOpBuilder& b) const { - auto arg_type = mlir::cast(arg.getType()); - auto negative_half = - b.create(b.getFloatAttr(arg_type, -0.5)); - auto negative_one = - b.create(b.getFloatAttr(arg_type, -1.0)); - - // Algorithm copied from cephes cosm1: - // cosm1(x) = -0.5 * x^2 + x^4 * P(x^2); - // that is suitable when abs(x) < pi/4, otherwise we'll use cos(x)-1. - // - // This is an alternative algorithm - // cosm1(x) = -2 * sin(x/2)^2 - // that is only slightly less accurate around abs(x) == 0.1 but - // otherwise equivalent accuracy-wise compared to cephes cosm1. - // However, we are not using it because it is notably less - // performant than cephes cosm1. - - // TODO: define cosm1(x) as cosm1(x mod (2*pi)) to increase accuracy - // for large x values that are close to 2*pi*n where n is some integer. - static const std::array kCoeffs{ - 4.7377507964246204691685E-14, -1.1470284843425359765671E-11, - 2.0876754287081521758361E-9, -2.7557319214999787979814E-7, - 2.4801587301570552304991E-5, -1.3888888888888872993737E-3, - 4.1666666666666666609054E-2, - }; - Value cos = b.create(arg); - Value for_large_x = b.create(cos, negative_one); - - Value arg_pow_2 = b.create(arg, arg); - Value arg_pow_4 = b.create(arg_pow_2, arg_pow_2); - Value poly = EvaluatePolynomial(b, arg_pow_2, ArrayRef(kCoeffs)); - - auto for_small_x = - b.create(b.create(arg_pow_4, poly), - b.create(negative_half, arg_pow_2)); - - // (pi/4)^2 is approximately 0.61685 - Value cond = b.create( - ma::CmpFPredicate::OGE, arg_pow_2, - b.create(b.getFloatAttr(arg_type, 0.61685))); - return b.create(cond, for_large_x, for_small_x); - } -}; - class LowerTensorsPass : public impl::LowerTensorsPassBase { public: explicit LowerTensorsPass(const LowerTensorsPassOptions& options) @@ -1162,7 +1052,7 @@ class LowerTensorsPass : public impl::LowerTensorsPassBase { mlir::RewritePatternSet tensor_patterns(mlir_context); tensor_patterns.add(mlir_context, is_amd_gpu_, gpu_arch_); tensor_patterns - .add(mlir_context); if (mlir::failed(mlir::applyPatternsAndFoldGreedily( diff --git a/xla/service/gpu/fusions/transforms/passes.td b/xla/service/gpu/fusions/transforms/passes.td index 4d08552c0d2681..59e68e54ac9bf5 100644 --- a/xla/service/gpu/fusions/transforms/passes.td +++ b/xla/service/gpu/fusions/transforms/passes.td @@ -74,10 +74,8 @@ def LowerTensorsPass : Pass<"xla-gpu-lower-tensors", "mlir::ModuleOp"> { let dependentDialects = [ "mlir::LLVM::LLVMDialect", - "mlir::complex::ComplexDialect", "mlir::func::FuncDialect", "mlir::gpu::GPUDialect", - "mlir::math::MathDialect", "mlir::scf::SCFDialect", "mlir::tensor::TensorDialect", "xla::gpu::XlaGpuDialect", diff --git a/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir b/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir index 69377549340b73..a894c13dce1293 100644 --- a/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir +++ b/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir @@ -732,15 +732,3 @@ func.func @int4_constant(%arg0: tensor<3xi4>, %arg1: index) -> i4 { // CHECK: llvm.mlir.global private constant // CHECK-SAME: dense<[18, 48]> // CHECK-LABEL: @int4_constant - -// ----- - -func.func @complex_expm1_approx(%arg0: tensor<3xcomplex>, %i: index) - -> complex { - %extracted = tensor.extract %arg0[%i] : tensor<3xcomplex> - %expm1 = complex.expm1 %extracted : complex - return %expm1 : complex -} -// CHECK-LABEL: @complex_expm1_approx -// CHECK: math.expm1 -// CHECK-COUNT-6: math.fma \ No newline at end of file