aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/MathToLibm/MathToLibm.cpp')
-rw-r--r--mlir/lib/Conversion/MathToLibm/MathToLibm.cpp147
1 files changed, 147 insertions, 0 deletions
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
new file mode 100644
index 000000000000..8512432681c2
--- /dev/null
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -0,0 +1,147 @@
+//===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToLibm/MathToLibm.h"
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+
+namespace {
+// Pattern to convert vector operations to scalar operations. This is needed as
+// libm calls require scalars.
+template <typename Op>
+struct VecOpToScalarOp : public OpRewritePattern<Op> {
+public:
+ using OpRewritePattern<Op>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
+};
+// Pattern to convert scalar math operations to calls to libm functions.
+// Additionally the libm function signatures are declared.
+template <typename Op>
+struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
+public:
+ using OpRewritePattern<Op>::OpRewritePattern;
+ ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc,
+ StringRef doubleFunc, PatternBenefit benefit)
+ : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
+ doubleFunc(doubleFunc){};
+
+ LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
+
+private:
+ std::string floatFunc, doubleFunc;
+};
+} // namespace
+
+template <typename Op>
+LogicalResult
+VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
+ auto opType = op.getType();
+ auto loc = op.getLoc();
+ auto vecType = opType.template dyn_cast<VectorType>();
+
+ if (!vecType)
+ return failure();
+ if (!vecType.hasRank())
+ return failure();
+ auto shape = vecType.getShape();
+ // TODO: support multidimensional vectors
+ if (shape.size() != 1)
+ return failure();
+
+ Value result = rewriter.create<ConstantOp>(
+ loc, DenseElementsAttr::get(
+ vecType, FloatAttr::get(vecType.getElementType(), 0.0)));
+ for (auto i = 0; i < shape.front(); ++i) {
+ SmallVector<Value> operands;
+ for (auto input : op->getOperands())
+ operands.push_back(
+ rewriter.create<vector::ExtractElementOp>(loc, input, i));
+ Value scalarOp =
+ rewriter.create<Op>(loc, vecType.getElementType(), operands);
+ result = rewriter.create<vector::InsertElementOp>(loc, scalarOp, result, i);
+ }
+ rewriter.replaceOp(op, {result});
+ return success();
+}
+
+template <typename Op>
+LogicalResult
+ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
+ PatternRewriter &rewriter) const {
+ auto module = op->template getParentOfType<ModuleOp>();
+ auto type = op.getType();
+ // TODO: Support Float16 by upcasting to Float32
+ if (!type.template isa<Float32Type, Float64Type>())
+ return failure();
+
+ auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
+ auto opFunc = module.template lookupSymbol<FuncOp>(name);
+ // Forward declare function if it hasn't already been
+ if (!opFunc) {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(module.getBody());
+ auto opFunctionTy = FunctionType::get(
+ rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
+ opFunc =
+ rewriter.create<FuncOp>(rewriter.getUnknownLoc(), name, opFunctionTy);
+ opFunc.setPrivate();
+ }
+ assert(opFunc.getType().template cast<FunctionType>().getResults() ==
+ op->getResultTypes());
+ assert(opFunc.getType().template cast<FunctionType>().getInputs() ==
+ op->getOperandTypes());
+
+ rewriter.replaceOpWithNewOp<CallOp>(op, opFunc, op->getOperands());
+
+ return success();
+}
+
+void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit) {
+ patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
+ VecOpToScalarOp<math::TanhOp>>(patterns.getContext(), benefit);
+ patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
+ "atan2f", "atan2", benefit);
+ patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(),
+ "expm1f", "expm1", benefit);
+ patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf",
+ "tanh", benefit);
+}
+
+namespace {
+struct ConvertMathToLibmPass
+ : public ConvertMathToLibmBase<ConvertMathToLibmPass> {
+ void runOnOperation() override;
+};
+} // namespace
+
+void ConvertMathToLibmPass::runOnOperation() {
+ auto module = getOperation();
+
+ RewritePatternSet patterns(&getContext());
+ populateMathToLibmConversionPatterns(patterns, /*benefit=*/1);
+
+ ConversionTarget target(getContext());
+ target.addLegalDialect<BuiltinDialect, StandardOpsDialect,
+ vector::VectorDialect>();
+ target.addIllegalDialect<math::MathDialect>();
+ if (failed(applyPartialConversion(module, target, std::move(patterns))))
+ signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToLibmPass() {
+ return std::make_unique<ConvertMathToLibmPass>();
+}