commit d12fa33d736d60d419f86b4ec5f3e77e602d4b1e
Author:    Nicolas Vasilache <nicolas.vasilache@gmail.com>  2021-02-19 09:33:56 +0000
Committer: Nicolas Vasilache <nicolas.vasilache@gmail.com>  2021-02-19 09:38:33 +0000
Tree:      429c4e891e59fd050727d0409f8707e4cbd07455
Parent:    [docs] Fix the GlobalISel/GenericOpcode.rst
[mlir] Add a TensorLoadToMemref canonicalization
A folder of `tensor_load + tensor_to_memref` exists, but it only applies when the source and destination memref types are the same. This revision adds a canonicalization of `tensor_load + tensor_to_memref` to `memref_cast` when a type mismatch prevents the fold from kicking in.

Differential Revision: https://reviews.llvm.org/D97038
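An illustrative before/after sketch, mirroring the new test case added below (`%m` stands in for an arbitrary memref value):

  %0 = tensor_load %m : memref<?xf32, offset: 3, strides: [1]>
  %1 = tensor_to_memref %0 : memref<?xf32, offset: ?, strides: [1]>

  // The two memref types differ only in their offset, so they are
  // cast-compatible and the pair canonicalizes to a single cast:
  %1 = memref_cast %m : memref<?xf32, offset: 3, strides: [1]> to memref<?xf32, offset: ?, strides: [1]>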
 mlir/lib/Dialect/StandardOps/IR/Ops.cpp      | 25
 mlir/test/Dialect/Standard/canonicalize.mlir | 52
 2 files changed, 75 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 084d3fdfb2bf..046033cc7f9d 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3838,11 +3838,34 @@ struct TensorCastToMemref : public OpRewritePattern<TensorToMemrefOp> {
    return success();
  }
};
+
+/// Canonicalize tensor_load + tensor_to_memref to memref_cast when type
+/// mismatches prevent `TensorToMemrefOp::fold` from kicking in.
+struct TensorLoadToMemref : public OpRewritePattern<TensorToMemrefOp> {
+  using OpRewritePattern<TensorToMemrefOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TensorToMemrefOp tensorToMemRef,
+                                PatternRewriter &rewriter) const final {
+    auto tensorLoad = tensorToMemRef.tensor().getDefiningOp<TensorLoadOp>();
+    // Bail unless we have a tensor_load + tensor_to_memref with different
+    // types. `TensorToMemrefOp::fold` handles the same-type case.
+    if (!tensorLoad ||
+        tensorLoad.memref().getType() == tensorToMemRef.getType())
+      return failure();
+    // If types are not cast-compatible, bail.
+    if (!MemRefCastOp::areCastCompatible(tensorLoad.memref().getType(),
+                                         tensorToMemRef.getType()))
+      return failure();
+    rewriter.replaceOpWithNewOp<MemRefCastOp>(
+        tensorToMemRef, tensorToMemRef.getType(), tensorLoad.memref());
+    return success();
+  }
+};
} // namespace
void TensorToMemrefOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<TensorCastToMemref>(context);
+  results.insert<TensorCastToMemref, TensorLoadToMemref>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index 5c437ae3dda4..ff5ca24f7587 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-opt %s -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -canonicalize --split-input-file | FileCheck %s
+
+// -----
// Test case: Basic folding of tensor_load(tensor_to_memref(t)) -> t
// CHECK-LABEL: func @tensor_load_of_tensor_to_memref(
@@ -10,6 +12,8 @@ func @tensor_load_of_tensor_to_memref(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  return %1 : tensor<?xf32>
}
+// -----
+
// Test case: Basic folding of tensor_to_memref(tensor_load(m)) -> m
// CHECK-LABEL: func @tensor_to_memref_of_tensor_load(
// CHECK-SAME: %[[MEMREF:.*]]: memref<?xf32>) -> memref<?xf32> {
@@ -20,7 +24,11 @@ func @tensor_to_memref_of_tensor_load(%arg0: memref<?xf32>) -> memref<?xf32> {
  return %1 : memref<?xf32>
}
+// -----
+
// Test case: If the memrefs are not the same type, don't fold them.
+// Test case: If the memrefs are not cast-compatible (e.g. different address space),
+// don't canonicalize them either.
// CHECK-LABEL: func @no_fold_tensor_to_memref_of_tensor_load(
// CHECK-SAME: %[[MEMREF_ADDRSPACE2:.*]]: memref<?xf32, 2>) -> memref<?xf32, 7> {
// CHECK: %[[TENSOR:.*]] = tensor_load %[[MEMREF_ADDRSPACE2]] : memref<?xf32, 2>
@@ -32,6 +40,28 @@ func @no_fold_tensor_to_memref_of_tensor_load(%arg0: memref<?xf32, 2>) -> memref
  return %1 : memref<?xf32, 7>
}
+// -----
+
+// CHECK-DAG: #[[$OFF_3:[a-z0-9]+]] = affine_map<(d0) -> (d0 + 3)>
+// CHECK-DAG: #[[$OFF_UNK:[a-z0-9]+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+
+// Test case: If the memrefs are cast-compatible, canonicalize.
+// CHECK-LABEL: func @canonicalize_tensor_to_memref_of_tensor_load(
+// CHECK-SAME: %[[M:.*]]: memref<?xf32, #[[$OFF_3]]>) -> memref<?xf32, #[[$OFF_UNK]]> {
+// CHECK-NOT: tensor_load
+// CHECK-NOT: tensor_to_memref
+// CHECK: %[[R:.*]] = memref_cast %[[M]] : memref<?xf32, #[[$OFF_3]]> to memref<?xf32, #[[$OFF_UNK]]>
+// CHECK: return %[[R]]
+func @canonicalize_tensor_to_memref_of_tensor_load(%arg0: memref<?xf32, offset: 3, strides: [1]>)
+  -> memref<?xf32, offset: ?, strides: [1]>
+{
+  %0 = tensor_load %arg0 : memref<?xf32, offset: 3, strides: [1]>
+  %1 = tensor_to_memref %0 : memref<?xf32, offset: ?, strides: [1]>
+  return %1 : memref<?xf32, offset: ?, strides: [1]>
+}
+
+// -----
+
// Test case: Basic folding of dim(tensor_load(m)) -> dim(m).
// CHECK-LABEL: func @dim_of_tensor_load(
// CHECK-SAME: %[[MEMREF:[0-9a-z]*]]: memref<?xf32>
@@ -45,6 +75,8 @@ func @dim_of_tensor_load(%arg0: memref<?xf32>) -> index {
  return %1 : index
}
+// -----
+
// Test case: Folding of load(tensor_to_memref(%v, %idxs))
// -> tensor.extract(%v, %idx)
// CHECK-LABEL: func @load_from_tensor_to_memref(
@@ -59,6 +91,8 @@ func @load_from_tensor_to_memref(%arg0: index, %arg1: index, %arg2: tensor<?x?xf
  return %1 : f32
}
+// -----
+
// Test case: Folding of dim(tensor.generate %idx) -> %idx
// CHECK-LABEL: func @dim_of_tensor.generate(
// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
@@ -74,6 +108,8 @@ func @dim_of_tensor.generate(%arg0: index, %arg1: index) -> index {
  return %1 : index
}
+// -----
+
// Test case: Folding of comparisons with equal operands.
// CHECK-LABEL: @cmpi_equal_operands
// CHECK-DAG: %[[T:.*]] = constant true
@@ -96,6 +132,8 @@ func @cmpi_equal_operands(%arg0: i64)
: i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
}
+// -----
+
// Test case: Folding of dim(memref_reshape %v %shp, %idx) -> load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
@@ -116,6 +154,8 @@ func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
  return %1 : index
}
+// -----
+
// Test case: Folding dim(tensor.cast %0, %idx) -> dim %0, %idx
// CHECK-LABEL: func @fold_dim_of_tensor.cast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>
@@ -132,6 +172,8 @@ func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
  return %1, %2: index, index
}
+// -----
+
// CHECK-LABEL: func @tensor_cast_to_memref
// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x16x32xi8>
// CHECK: %[[M:.+]] = tensor_to_memref %[[ARG0]] : memref<4x6x16x32xi8>
@@ -144,6 +186,8 @@ func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
  return %1 : memref<?x?x16x32xi8>
}
+// -----
+
// CHECK-LABEL: func @subview_of_memcast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
// CHECK: %[[S:.+]] = subview %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
@@ -158,6 +202,8 @@ func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
  return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
}
+// -----
+
// CHECK-LABEL: func @trivial_subtensor
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
// CHECK-NOT: subtensor
@@ -167,6 +213,8 @@ func @trivial_subtensor(%arg0 : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
  return %0 : tensor<4x6x16x32xi8>
}
+// -----
+
// CHECK-LABEL: func @trivial_subtensor_insert
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
// CHECK-NOT: subtensor
@@ -176,6 +224,8 @@ func @trivial_subtensor_insert(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<4x6x
  return %0 : tensor<4x6x16x32xi8>
}
+// -----
+
// CHECK-LABEL: func @rank_reducing_tensor_of_cast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
// CHECK: %[[S:.+]] = subtensor %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<4x6x16x32xi8> to tensor<16x32xi8>