aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Analysis/Utils.cpp')
-rw-r--r--mlir/lib/Analysis/Utils.cpp139
1 files changed, 118 insertions, 21 deletions
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index c3c63f3a7f91..a4b8ccfc7ad1 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -61,6 +61,21 @@ void mlir::getEnclosingAffineForAndIfOps(Operation &op,
std::reverse(ops->begin(), ops->end());
}
+// Populates 'cst' with FlatAffineConstraints which represent original domain of
+// the loop bounds that define 'ivs'.
+LogicalResult
+ComputationSliceState::getSourceAsConstraints(FlatAffineConstraints &cst) {
+ assert(!ivs.empty() && "Cannot have a slice without its IVs");
+ cst.reset(/*numDims=*/ivs.size(), /*numSymbols=*/0, /*numLocals=*/0, ivs);
+ for (Value iv : ivs) {
+ AffineForOp loop = getForInductionVarOwner(iv);
+ assert(loop && "Expected affine for");
+ if (failed(cst.addAffineForOpDomain(loop)))
+ return failure();
+ }
+ return success();
+}
+
// Populates 'cst' with FlatAffineConstraints which represent slice bounds.
LogicalResult
ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) {
@@ -75,9 +90,10 @@ ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) {
values.append(lbOperands[0].begin(), lbOperands[0].end());
cst->reset(numDims, numSymbols, 0, values);
- // Add loop bound constraints for values which are loop IVs and equality
- // constraints for symbols which are constants.
- for (const auto &value : values) {
+ // Add loop bound constraints for values which are loop IVs of the destination
+ // of fusion and equality constraints for symbols which are constants.
+ for (unsigned i = numDims, end = values.size(); i < end; ++i) {
+ Value value = values[i];
assert(cst->containsId(value) && "value expected to be present");
if (isValidSymbol(value)) {
// Check if the symbol is a constant.
@@ -196,6 +212,76 @@ Optional<bool> ComputationSliceState::isSliceMaximalFastCheck() const {
return true;
}
+/// Returns true if it is deterministically verified that the original iteration
+/// space of the slice is contained within the new iteration space that is
+/// created after fusing 'this' slice into its destination.
+Optional<bool> ComputationSliceState::isSliceValid() {
+ // Fast check to determine if the slice is valid. If the following conditions
+ // are verified to be true, slice is declared valid by the fast check:
+ // 1. Each slice loop is a single iteration loop bound in terms of a single
+ // destination loop IV.
+ // 2. Loop bounds of the destination loop IV (from above) and those of the
+ // source loop IV are exactly the same.
+ // If the fast check is inconclusive or false, we proceed with a more
+ // expensive analysis.
+ // TODO: Store the result of the fast check, as it might be used again in
+ // `canRemoveSrcNodeAfterFusion`.
+ Optional<bool> isValidFastCheck = isSliceMaximalFastCheck();
+ if (isValidFastCheck.hasValue() && isValidFastCheck.getValue())
+ return true;
+
+ // Create constraints for the source loop nest using which slice is computed.
+ FlatAffineConstraints srcConstraints;
+ // TODO: Store the source's domain to avoid computation at each depth.
+ if (failed(getSourceAsConstraints(srcConstraints))) {
+ LLVM_DEBUG(llvm::dbgs() << "Unable to compute source's domain\n");
+ return llvm::None;
+ }
+ // As the set difference utility currently cannot handle symbols in its
+ // operands, validity of the slice cannot be determined.
+ if (srcConstraints.getNumSymbolIds() > 0) {
+ LLVM_DEBUG(llvm::dbgs() << "Cannot handle symbols in source domain\n");
+ return llvm::None;
+ }
+ // TODO: Handle local ids in the source domains while using the 'projectOut'
+ // utility below. Currently, aligning is not done assuming that there will be
+ // no local ids in the source domain.
+ if (srcConstraints.getNumLocalIds() != 0) {
+ LLVM_DEBUG(llvm::dbgs() << "Cannot handle locals in source domain\n");
+ return llvm::None;
+ }
+
+ // Create constraints for the slice loop nest that would be created if the
+ // fusion succeeds.
+ FlatAffineConstraints sliceConstraints;
+ if (failed(getAsConstraints(&sliceConstraints))) {
+ LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice's domain\n");
+ return llvm::None;
+ }
+
+ // Projecting out every dimension other than the 'ivs' to express slice's
+ // domain completely in terms of source's IVs.
+ sliceConstraints.projectOut(ivs.size(),
+ sliceConstraints.getNumIds() - ivs.size());
+
+ LLVM_DEBUG(llvm::dbgs() << "Domain of the source of the slice:\n");
+ LLVM_DEBUG(srcConstraints.dump());
+ LLVM_DEBUG(llvm::dbgs() << "Domain of the slice if this fusion succeeds "
+ "(expressed in terms of its source's IVs):\n");
+ LLVM_DEBUG(sliceConstraints.dump());
+
+ // TODO: Store 'srcSet' to avoid recalculating for each depth.
+ PresburgerSet srcSet(srcConstraints);
+ PresburgerSet sliceSet(sliceConstraints);
+ PresburgerSet diffSet = sliceSet.subtract(srcSet);
+
+ if (!diffSet.isIntegerEmpty()) {
+ LLVM_DEBUG(llvm::dbgs() << "Incorrect slice\n");
+ return false;
+ }
+ return true;
+}
+
/// Returns true if the computation slice encloses all the iterations of the
/// sliced loop nest. Returns false if it does not. Returns llvm::None if it
/// cannot determine if the slice is maximal or not.
@@ -715,14 +801,14 @@ unsigned mlir::getInnermostCommonLoopDepth(
}
/// Computes in 'sliceUnion' the union of all slice bounds computed at
-/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB'.
-/// Returns 'Success' if union was computed, 'failure' otherwise.
-LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
- ArrayRef<Operation *> opsB,
- unsigned loopDepth,
- unsigned numCommonLoops,
- bool isBackwardSlice,
- ComputationSliceState *sliceUnion) {
+/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB', and
+/// then verifies if it is valid. Returns 'SliceComputationResult::Success' if
+/// union was computed correctly, an appropriate failure otherwise.
+SliceComputationResult
+mlir::computeSliceUnion(ArrayRef<Operation *> opsA, ArrayRef<Operation *> opsB,
+ unsigned loopDepth, unsigned numCommonLoops,
+ bool isBackwardSlice,
+ ComputationSliceState *sliceUnion) {
// Compute the union of slice bounds between all pairs in 'opsA' and
// 'opsB' in 'sliceUnionCst'.
FlatAffineConstraints sliceUnionCst;
@@ -738,7 +824,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
if ((!isBackwardSlice && loopDepth > getNestingDepth(opsA[i])) ||
(isBackwardSlice && loopDepth > getNestingDepth(opsB[j]))) {
LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n");
- return failure();
+ return SliceComputationResult::GenericFailure;
}
bool readReadAccesses = isa<AffineReadOpInterface>(srcAccess.opInst) &&
@@ -751,7 +837,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
/*allowRAR=*/readReadAccesses);
if (result.value == DependenceResult::Failure) {
LLVM_DEBUG(llvm::dbgs() << "Dependence check failed\n");
- return failure();
+ return SliceComputationResult::GenericFailure;
}
if (result.value == DependenceResult::NoDependence)
continue;
@@ -768,7 +854,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) {
LLVM_DEBUG(llvm::dbgs()
<< "Unable to compute slice bound constraints\n");
- return failure();
+ return SliceComputationResult::GenericFailure;
}
assert(sliceUnionCst.getNumDimAndSymbolIds() > 0);
continue;
@@ -779,7 +865,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) {
LLVM_DEBUG(llvm::dbgs()
<< "Unable to compute slice bound constraints\n");
- return failure();
+ return SliceComputationResult::GenericFailure;
}
// Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if needed.
@@ -802,9 +888,9 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
// to unionBoundingBox below expects constraints for each Loop IV, even
// if they are the unsliced full loop bounds added here.
if (failed(addMissingLoopIVBounds(sliceUnionIVs, &sliceUnionCst)))
- return failure();
+ return SliceComputationResult::GenericFailure;
if (failed(addMissingLoopIVBounds(tmpSliceIVs, &tmpSliceCst)))
- return failure();
+ return SliceComputationResult::GenericFailure;
}
// Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
if (sliceUnionCst.getNumLocalIds() > 0 ||
@@ -812,14 +898,14 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
LLVM_DEBUG(llvm::dbgs()
<< "Unable to compute union bounding box of slice bounds\n");
- return failure();
+ return SliceComputationResult::GenericFailure;
}
}
}
// Empty union.
if (sliceUnionCst.getNumDimAndSymbolIds() == 0)
- return failure();
+ return SliceComputationResult::GenericFailure;
// Gather loops surrounding ops from loop nest where slice will be inserted.
SmallVector<Operation *, 4> ops;
@@ -831,7 +917,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
getInnermostCommonLoopDepth(ops, &surroundingLoops);
if (loopDepth > innermostCommonLoopDepth) {
LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n");
- return failure();
+ return SliceComputationResult::GenericFailure;
}
// Store 'numSliceLoopIVs' before converting dst loop IVs to dims.
@@ -868,7 +954,18 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
// canonicalization.
sliceUnion->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
sliceUnion->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);
- return success();
+
+ // Check if the slice computed is valid. Return success only if it is verified
+ // that the slice is valid, otherwise return appropriate failure status.
+ Optional<bool> isSliceValid = sliceUnion->isSliceValid();
+ if (!isSliceValid.hasValue()) {
+ LLVM_DEBUG(llvm::dbgs() << "Cannot determine if the slice is valid\n");
+ return SliceComputationResult::GenericFailure;
+ }
+ if (!isSliceValid.getValue())
+ return SliceComputationResult::IncorrectSliceFailure;
+
+ return SliceComputationResult::Success;
}
const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier";