aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHanhan Wang <hanchung@google.com>2021-04-20 07:34:32 -0700
committerHanhan Wang <hanchung@google.com>2021-04-20 07:35:20 -0700
commit7b7df8e85eec445389e4b07915f16aa18332719d (patch)
tree7f28e4dbeae6d44b94fa436ac3e2221622728617
parent[gn build] reformat all gn files (diff)
downloadllvm-project-main.tar.gz
llvm-project-main.tar.bz2
llvm-project-main.zip
[mlir][StandardToSPIRV] Add support for lowering std.xor on bool to SPIR-Vmain
std.xor ops on bool are lowered to spv.LogicalNotEqual. For Boolean values, xor and not-equal are the same thing. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D100817
-rw-r--r--mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp29
-rw-r--r--mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir4
2 files changed, 32 insertions, 1 deletions
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index 0196a21f4a69..2a6e7f281860 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -663,6 +663,17 @@ public:
ConversionPatternRewriter &rewriter) const override;
};
+/// Converts std.xor to SPIR-V operations if the type of source is i1 or vector
+/// of i1.
+class BoolXOrOpPattern final : public OpConversionPattern<XOrOp> {
+public:
+ using OpConversionPattern<XOrOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -1250,6 +1261,22 @@ XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
return success();
}
+LogicalResult
+BoolXOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ assert(operands.size() == 2);
+
+ if (!isBoolScalarOrVector(operands.front().getType()))
+ return failure();
+
+ auto dstType = getTypeConverter()->convertType(xorOp.getType());
+ if (!dstType)
+ return failure();
+ rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(xorOp, dstType,
+ operands);
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//
@@ -1293,7 +1320,7 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
UnaryAndBinaryOpPattern<UnsignedDivIOp, spirv::UDivOp>,
UnaryAndBinaryOpPattern<UnsignedRemIOp, spirv::UModOp>,
UnaryAndBinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>,
- SignedRemIOpPattern, XOrOpPattern,
+ SignedRemIOpPattern, XOrOpPattern, BoolXOrOpPattern,
// Comparison patterns
BoolCmpIOpPattern, CmpFOpPattern, CmpFOpNanNonePattern, CmpIOpPattern,
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index 0148a0731dc9..fe769482c787 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -224,6 +224,8 @@ func @logical_scalar(%arg0 : i1, %arg1 : i1) {
%0 = and %arg0, %arg1 : i1
// CHECK: spv.LogicalOr
%1 = or %arg0, %arg1 : i1
+ // CHECK: spv.LogicalNotEqual
+ %2 = xor %arg0, %arg1 : i1
return
}
@@ -233,6 +235,8 @@ func @logical_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
%0 = and %arg0, %arg1 : vector<4xi1>
// CHECK: spv.LogicalOr
%1 = or %arg0, %arg1 : vector<4xi1>
+ // CHECK: spv.LogicalNotEqual
+ %2 = xor %arg0, %arg1 : vector<4xi1>
return
}