diff options
Diffstat (limited to 'mlir/lib/Conversion/MathToLibm/MathToLibm.cpp')
-rw-r--r-- | mlir/lib/Conversion/MathToLibm/MathToLibm.cpp | 147 |
1 files changed, 147 insertions, 0 deletions
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp new file mode 100644 index 000000000000..8512432681c2 --- /dev/null +++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -0,0 +1,147 @@ +//===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/MathToLibm/MathToLibm.h" + +#include "../PassDetail.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/PatternMatch.h" + +using namespace mlir; + +namespace { +// Pattern to convert vector operations to scalar operations. This is needed as +// libm calls require scalars. +template <typename Op> +struct VecOpToScalarOp : public OpRewritePattern<Op> { +public: + using OpRewritePattern<Op>::OpRewritePattern; + + LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; +}; +// Pattern to convert scalar math operations to calls to libm functions. +// Additionally the libm function signatures are declared. +template <typename Op> +struct ScalarOpToLibmCall : public OpRewritePattern<Op> { +public: + using OpRewritePattern<Op>::OpRewritePattern; + ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc, + StringRef doubleFunc, PatternBenefit benefit) + : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc), + doubleFunc(doubleFunc){}; + + LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; + +private: + std::string floatFunc, doubleFunc; +}; +} // namespace + +template <typename Op> +LogicalResult +VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const { + auto opType = op.getType(); + auto loc = op.getLoc(); + auto vecType = opType.template dyn_cast<VectorType>(); + + if (!vecType) + return failure(); + if (!vecType.hasRank()) + return failure(); + auto shape = vecType.getShape(); + // TODO: support multidimensional vectors + if (shape.size() != 1) + return failure(); + + Value result = rewriter.create<ConstantOp>( + loc, DenseElementsAttr::get( + vecType, FloatAttr::get(vecType.getElementType(), 0.0))); + for (auto i = 0; i < shape.front(); ++i) { + SmallVector<Value> operands; + for (auto input : op->getOperands()) + operands.push_back( + rewriter.create<vector::ExtractElementOp>(loc, input, i)); + Value scalarOp = + rewriter.create<Op>(loc, vecType.getElementType(), operands); + result = rewriter.create<vector::InsertElementOp>(loc, scalarOp, result, i); + } + rewriter.replaceOp(op, {result}); + return success(); +} + +template <typename Op> +LogicalResult +ScalarOpToLibmCall<Op>::matchAndRewrite(Op op, + PatternRewriter &rewriter) const { + auto module = op->template getParentOfType<ModuleOp>(); + auto type = op.getType(); + // TODO: Support Float16 by upcasting to Float32 + if (!type.template isa<Float32Type, Float64Type>()) + return failure(); + + auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc; + auto opFunc = module.template lookupSymbol<FuncOp>(name); + // Forward declare function if it hasn't already been + if (!opFunc) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + auto opFunctionTy = FunctionType::get( + rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); + opFunc = + rewriter.create<FuncOp>(rewriter.getUnknownLoc(), name, opFunctionTy); + opFunc.setPrivate(); + } + assert(opFunc.getType().template cast<FunctionType>().getResults() == + op->getResultTypes()); + assert(opFunc.getType().template cast<FunctionType>().getInputs() == + op->getOperandTypes()); + + rewriter.replaceOpWithNewOp<CallOp>(op, opFunc, op->getOperands()); + + return success(); +} + +void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>, + VecOpToScalarOp<math::TanhOp>>(patterns.getContext(), benefit); + patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(), + "atan2f", "atan2", benefit); + patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(), + "expm1f", "expm1", benefit); + patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf", + "tanh", benefit); +} + +namespace { +struct ConvertMathToLibmPass + : public ConvertMathToLibmBase<ConvertMathToLibmPass> { + void runOnOperation() override; +}; +} // namespace + +void ConvertMathToLibmPass::runOnOperation() { + auto module = getOperation(); + + RewritePatternSet patterns(&getContext()); + populateMathToLibmConversionPatterns(patterns, /*benefit=*/1); + + ConversionTarget target(getContext()); + target.addLegalDialect<BuiltinDialect, StandardOpsDialect, + vector::VectorDialect>(); + target.addIllegalDialect<math::MathDialect>(); + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} + +std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToLibmPass() { + return std::make_unique<ConvertMathToLibmPass>(); +} |