diff options
Diffstat (limited to 'mlir/lib')
-rw-r--r-- | mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp | 29 |
1 files changed, 28 insertions, 1 deletions
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp index 0196a21f4a69..2a6e7f281860 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -663,6 +663,17 @@ public: ConversionPatternRewriter &rewriter) const override; }; +/// Converts std.xor to SPIR-V operations if the type of source is i1 or vector +/// of i1. +class BoolXOrOpPattern final : public OpConversionPattern<XOrOp> { +public: + using OpConversionPattern<XOrOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override; +}; + } // namespace //===----------------------------------------------------------------------===// @@ -1250,6 +1261,22 @@ XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands, return success(); } +LogicalResult +BoolXOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const { + assert(operands.size() == 2); + + if (!isBoolScalarOrVector(operands.front().getType())) + return failure(); + + auto dstType = getTypeConverter()->convertType(xorOp.getType()); + if (!dstType) + return failure(); + rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(xorOp, dstType, + operands); + return success(); +} + //===----------------------------------------------------------------------===// // Pattern population //===----------------------------------------------------------------------===// @@ -1293,7 +1320,7 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter, UnaryAndBinaryOpPattern<UnsignedDivIOp, spirv::UDivOp>, UnaryAndBinaryOpPattern<UnsignedRemIOp, spirv::UModOp>, UnaryAndBinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>, - SignedRemIOpPattern, XOrOpPattern, + SignedRemIOpPattern, XOrOpPattern, BoolXOrOpPattern, // Comparison patterns BoolCmpIOpPattern, CmpFOpPattern, CmpFOpNanNonePattern, CmpIOpPattern, |