aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTres Popp <tpopp@google.com>2021-04-13 10:18:34 +0200
committerTres Popp <tpopp@google.com>2021-04-20 11:38:55 +0200
commit34810e1b9c4554976d9d8249b18f48ff083b55fa (patch)
tree303914f4da5b8e677c4f1face06b4e7f03b09742
parent[Support] BinaryStreamReader.h - remove unnecessary <string> include. NFCI. (diff)
downloadllvm-project-34810e1b9c4554976d9d8249b18f48ff083b55fa.tar.gz
llvm-project-34810e1b9c4554976d9d8249b18f48ff083b55fa.tar.bz2
llvm-project-34810e1b9c4554976d9d8249b18f48ff083b55fa.zip
[mlir] Add patterns to lower Math operations to LLVM based libm calls.
Some Math operations do not have an equivalent in LLVM. In these cases, allow a low priority fallback of calling the libm functions. This is to give functionality and is not a performant option. Differential Revision: https://reviews.llvm.org/D100367
-rw-r--r--mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h26
-rw-r--r--mlir/include/mlir/Conversion/Passes.h1
-rw-r--r--mlir/include/mlir/Conversion/Passes.td13
-rw-r--r--mlir/lib/Conversion/CMakeLists.txt1
-rw-r--r--mlir/lib/Conversion/MathToLibm/CMakeLists.txt16
-rw-r--r--mlir/lib/Conversion/MathToLibm/MathToLibm.cpp147
-rw-r--r--mlir/test/Conversion/MathToLLVM/convert-to-libm.mlir73
7 files changed, 277 insertions, 0 deletions
diff --git a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
new file mode 100644
index 000000000000..9e7aa1a0f52a
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
@@ -0,0 +1,26 @@
+//===- MathToLibm.h - Utils to convert from the complex dialect --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MATHTOLIBM_MATHTOLIBM_H_
+#define MLIR_CONVERSION_MATHTOLIBM_MATHTOLIBM_H_
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+template <typename T>
+class OperationPass;
+
+/// Populate the given list with patterns that convert from Math to Libm calls.
+void populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit);
+
+/// Create a pass to convert Math operations to libm calls.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertMathToLibmPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOLIBM_MATHTOLIBM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 21e604eabecd..64de7c962bee 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -20,6 +20,7 @@
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
#include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h"
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
+#include "mlir/Conversion/MathToLibm/MathToLibm.h"
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 6eb5abdefe55..eb940d341404 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -229,6 +229,19 @@ def ConvertLinalgToSPIRV : Pass<"convert-linalg-to-spirv", "ModuleOp"> {
}
//===----------------------------------------------------------------------===//
+// MathToLibm
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToLibm : Pass<"convert-math-to-libm", "ModuleOp"> {
+ let summary = "Convert Math dialect to libm calls";
+ let description = [{
+ This pass converts supported Math ops to libm calls.
+ }];
+ let constructor = "mlir::createConvertMathToLibmPass()";
+ let dependentDialects = ["StandardOpsDialect", "vector::VectorDialect"];
+}
+
+//===----------------------------------------------------------------------===//
// OpenMPToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 4f6d4a27ecca..60dbab0a0443 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -9,6 +9,7 @@ add_subdirectory(GPUToVulkan)
add_subdirectory(LinalgToLLVM)
add_subdirectory(LinalgToSPIRV)
add_subdirectory(LinalgToStandard)
+add_subdirectory(MathToLibm)
add_subdirectory(OpenMPToLLVM)
add_subdirectory(PDLToPDLInterp)
add_subdirectory(SCFToGPU)
diff --git a/mlir/lib/Conversion/MathToLibm/CMakeLists.txt b/mlir/lib/Conversion/MathToLibm/CMakeLists.txt
new file mode 100644
index 000000000000..cd43a11d30d5
--- /dev/null
+++ b/mlir/lib/Conversion/MathToLibm/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_conversion_library(MLIRMathToLibm
+ MathToLibm.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToLibm
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRMath
+ MLIRStandardOpsTransforms
+ )
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
new file mode 100644
index 000000000000..8512432681c2
--- /dev/null
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -0,0 +1,147 @@
+//===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToLibm/MathToLibm.h"
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+
+namespace {
+// Pattern to convert vector operations to scalar operations. This is needed as
+// libm calls require scalars.
+template <typename Op>
+struct VecOpToScalarOp : public OpRewritePattern<Op> {
+public:
+ using OpRewritePattern<Op>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
+};
+// Pattern to convert scalar math operations to calls to libm functions.
+// Additionally the libm function signatures are declared.
+template <typename Op>
+struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
+public:
+ using OpRewritePattern<Op>::OpRewritePattern;
+ ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc,
+ StringRef doubleFunc, PatternBenefit benefit)
+ : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
+ doubleFunc(doubleFunc){};
+
+ LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
+
+private:
+ std::string floatFunc, doubleFunc;
+};
+} // namespace
+
+template <typename Op>
+LogicalResult
+VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
+ auto opType = op.getType();
+ auto loc = op.getLoc();
+ auto vecType = opType.template dyn_cast<VectorType>();
+
+ if (!vecType)
+ return failure();
+ if (!vecType.hasRank())
+ return failure();
+ auto shape = vecType.getShape();
+ // TODO: support multidimensional vectors
+ if (shape.size() != 1)
+ return failure();
+
+ Value result = rewriter.create<ConstantOp>(
+ loc, DenseElementsAttr::get(
+ vecType, FloatAttr::get(vecType.getElementType(), 0.0)));
+ for (auto i = 0; i < shape.front(); ++i) {
+ SmallVector<Value> operands;
+ for (auto input : op->getOperands())
+ operands.push_back(
+ rewriter.create<vector::ExtractElementOp>(loc, input, i));
+ Value scalarOp =
+ rewriter.create<Op>(loc, vecType.getElementType(), operands);
+ result = rewriter.create<vector::InsertElementOp>(loc, scalarOp, result, i);
+ }
+ rewriter.replaceOp(op, {result});
+ return success();
+}
+
+template <typename Op>
+LogicalResult
+ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
+ PatternRewriter &rewriter) const {
+ auto module = op->template getParentOfType<ModuleOp>();
+ auto type = op.getType();
+ // TODO: Support Float16 by upcasting to Float32
+ if (!type.template isa<Float32Type, Float64Type>())
+ return failure();
+
+ auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
+ auto opFunc = module.template lookupSymbol<FuncOp>(name);
+ // Forward declare function if it hasn't already been
+ if (!opFunc) {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(module.getBody());
+ auto opFunctionTy = FunctionType::get(
+ rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
+ opFunc =
+ rewriter.create<FuncOp>(rewriter.getUnknownLoc(), name, opFunctionTy);
+ opFunc.setPrivate();
+ }
+ assert(opFunc.getType().template cast<FunctionType>().getResults() ==
+ op->getResultTypes());
+ assert(opFunc.getType().template cast<FunctionType>().getInputs() ==
+ op->getOperandTypes());
+
+ rewriter.replaceOpWithNewOp<CallOp>(op, opFunc, op->getOperands());
+
+ return success();
+}
+
+void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit) {
+ patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
+ VecOpToScalarOp<math::TanhOp>>(patterns.getContext(), benefit);
+ patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
+ "atan2f", "atan2", benefit);
+ patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(),
+ "expm1f", "expm1", benefit);
+ patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf",
+ "tanh", benefit);
+}
+
+namespace {
+struct ConvertMathToLibmPass
+ : public ConvertMathToLibmBase<ConvertMathToLibmPass> {
+ void runOnOperation() override;
+};
+} // namespace
+
+void ConvertMathToLibmPass::runOnOperation() {
+ auto module = getOperation();
+
+ RewritePatternSet patterns(&getContext());
+ populateMathToLibmConversionPatterns(patterns, /*benefit=*/1);
+
+ ConversionTarget target(getContext());
+ target.addLegalDialect<BuiltinDialect, StandardOpsDialect,
+ vector::VectorDialect>();
+ target.addIllegalDialect<math::MathDialect>();
+ if (failed(applyPartialConversion(module, target, std::move(patterns))))
+ signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToLibmPass() {
+ return std::make_unique<ConvertMathToLibmPass>();
+}
diff --git a/mlir/test/Conversion/MathToLLVM/convert-to-libm.mlir b/mlir/test/Conversion/MathToLLVM/convert-to-libm.mlir
new file mode 100644
index 000000000000..7c8d8e7136bb
--- /dev/null
+++ b/mlir/test/Conversion/MathToLLVM/convert-to-libm.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt %s -convert-math-to-libm -canonicalize | FileCheck %s
+
+// CHECK-DAG: @expm1(f64) -> f64
+// CHECK-DAG: @expm1f(f32) -> f32
+// CHECK-DAG: @atan2(f64, f64) -> f64
+// CHECK-DAG: @atan2f(f32, f32) -> f32
+// CHECK-DAG: @tanh(f64) -> f64
+// CHECK-DAG: @tanhf(f32) -> f32
+
+// CHECK-LABEL: func @tanh_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func @tanh_caller(%float: f32, %double: f64) -> (f32, f64) {
+ // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @tanhf(%[[FLOAT]]) : (f32) -> f32
+ %float_result = math.tanh %float : f32
+ // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @tanh(%[[DOUBLE]]) : (f64) -> f64
+ %double_result = math.tanh %double : f64
+ // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+ return %float_result, %double_result : f32, f64
+}
+
+
+// CHECK-LABEL: func @atan2_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func @atan2_caller(%float: f32, %double: f64) -> (f32, f64) {
+ // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @atan2f(%[[FLOAT]], %[[FLOAT]]) : (f32, f32) -> f32
+ %float_result = math.atan2 %float, %float : f32
+ // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @atan2(%[[DOUBLE]], %[[DOUBLE]]) : (f64, f64) -> f64
+ %double_result = math.atan2 %double, %double : f64
+ // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+ return %float_result, %double_result : f32, f64
+}
+
+// CHECK-LABEL: func @expm1_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func @expm1_caller(%float: f32, %double: f64) -> (f32, f64) {
+ // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @expm1f(%[[FLOAT]]) : (f32) -> f32
+ %float_result = math.expm1 %float : f32
+ // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @expm1(%[[DOUBLE]]) : (f64) -> f64
+ %double_result = math.expm1 %double : f64
+ // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+ return %float_result, %double_result : f32, f64
+}
+
+func @expm1_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
+ %float_result = math.expm1 %float : vector<2xf32>
+ %double_result = math.expm1 %double : vector<2xf64>
+ return %float_result, %double_result : vector<2xf32>, vector<2xf64>
+}
+// CHECK-LABEL: func @expm1_vec_caller(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
+// CHECK: %[[CVF:.*]] = constant dense<0.000000e+00> : vector<2xf32>
+// CHECK: %[[CVD:.*]] = constant dense<0.000000e+00> : vector<2xf64>
+// CHECK: %[[C0:.*]] = constant 0 : i32
+// CHECK: %[[C1:.*]] = constant 1 : i32
+// CHECK: %[[IN0_F32:.*]] = vector.extractelement %[[VAL_0]]{{\[}}%[[C0]] : i32] : vector<2xf32>
+// CHECK: %[[OUT0_F32:.*]] = call @expm1f(%[[IN0_F32]]) : (f32) -> f32
+// CHECK: %[[VAL_8:.*]] = vector.insertelement %[[OUT0_F32]], %[[CVF]]{{\[}}%[[C0]] : i32] : vector<2xf32>
+// CHECK: %[[IN1_F32:.*]] = vector.extractelement %[[VAL_0]]{{\[}}%[[C1]] : i32] : vector<2xf32>
+// CHECK: %[[OUT1_F32:.*]] = call @expm1f(%[[IN1_F32]]) : (f32) -> f32
+// CHECK: %[[VAL_11:.*]] = vector.insertelement %[[OUT1_F32]], %[[VAL_8]]{{\[}}%[[C1]] : i32] : vector<2xf32>
+// CHECK: %[[IN0_F64:.*]] = vector.extractelement %[[VAL_1]]{{\[}}%[[C0]] : i32] : vector<2xf64>
+// CHECK: %[[OUT0_F64:.*]] = call @expm1(%[[IN0_F64]]) : (f64) -> f64
+// CHECK: %[[VAL_14:.*]] = vector.insertelement %[[OUT0_F64]], %[[CVD]]{{\[}}%[[C0]] : i32] : vector<2xf64>
+// CHECK: %[[IN1_F64:.*]] = vector.extractelement %[[VAL_1]]{{\[}}%[[C1]] : i32] : vector<2xf64>
+// CHECK: %[[OUT1_F64:.*]] = call @expm1(%[[IN1_F64]]) : (f64) -> f64
+// CHECK: %[[VAL_17:.*]] = vector.insertelement %[[OUT1_F64]], %[[VAL_14]]{{\[}}%[[C1]] : i32] : vector<2xf64>
+// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK: }
+