diff options
Diffstat (limited to 'mlir/lib/Analysis/Utils.cpp')
-rw-r--r-- | mlir/lib/Analysis/Utils.cpp | 139 |
1 files changed, 118 insertions, 21 deletions
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index c3c63f3a7f91..a4b8ccfc7ad1 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -61,6 +61,21 @@ void mlir::getEnclosingAffineForAndIfOps(Operation &op, std::reverse(ops->begin(), ops->end()); } +// Populates 'cst' with FlatAffineConstraints which represent original domain of +// the loop bounds that define 'ivs'. +LogicalResult +ComputationSliceState::getSourceAsConstraints(FlatAffineConstraints &cst) { + assert(!ivs.empty() && "Cannot have a slice without its IVs"); + cst.reset(/*numDims=*/ivs.size(), /*numSymbols=*/0, /*numLocals=*/0, ivs); + for (Value iv : ivs) { + AffineForOp loop = getForInductionVarOwner(iv); + assert(loop && "Expected affine for"); + if (failed(cst.addAffineForOpDomain(loop))) + return failure(); + } + return success(); +} + // Populates 'cst' with FlatAffineConstraints which represent slice bounds. LogicalResult ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) { @@ -75,9 +90,10 @@ ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) { values.append(lbOperands[0].begin(), lbOperands[0].end()); cst->reset(numDims, numSymbols, 0, values); - // Add loop bound constraints for values which are loop IVs and equality - // constraints for symbols which are constants. - for (const auto &value : values) { + // Add loop bound constraints for values which are loop IVs of the destination + // of fusion and equality constraints for symbols which are constants. + for (unsigned i = numDims, end = values.size(); i < end; ++i) { + Value value = values[i]; assert(cst->containsId(value) && "value expected to be present"); if (isValidSymbol(value)) { // Check if the symbol is a constant. @@ -196,6 +212,76 @@ Optional<bool> ComputationSliceState::isSliceMaximalFastCheck() const { return true; } +/// Returns true if it is deterministically verified that the original iteration +/// space of the slice is contained within the new iteration space that is +/// created after fusing 'this' slice into its destination. +Optional<bool> ComputationSliceState::isSliceValid() { + // Fast check to determine if the slice is valid. If the following conditions + // are verified to be true, slice is declared valid by the fast check: + // 1. Each slice loop is a single iteration loop bound in terms of a single + // destination loop IV. + // 2. Loop bounds of the destination loop IV (from above) and those of the + // source loop IV are exactly the same. + // If the fast check is inconclusive or false, we proceed with a more + // expensive analysis. + // TODO: Store the result of the fast check, as it might be used again in + // `canRemoveSrcNodeAfterFusion`. + Optional<bool> isValidFastCheck = isSliceMaximalFastCheck(); + if (isValidFastCheck.hasValue() && isValidFastCheck.getValue()) + return true; + + // Create constraints for the source loop nest using which slice is computed. + FlatAffineConstraints srcConstraints; + // TODO: Store the source's domain to avoid computation at each depth. + if (failed(getSourceAsConstraints(srcConstraints))) { + LLVM_DEBUG(llvm::dbgs() << "Unable to compute source's domain\n"); + return llvm::None; + } + // As the set difference utility currently cannot handle symbols in its + // operands, validity of the slice cannot be determined. + if (srcConstraints.getNumSymbolIds() > 0) { + LLVM_DEBUG(llvm::dbgs() << "Cannot handle symbols in source domain\n"); + return llvm::None; + } + // TODO: Handle local ids in the source domains while using the 'projectOut' + // utility below. Currently, aligning is not done assuming that there will be + // no local ids in the source domain. + if (srcConstraints.getNumLocalIds() != 0) { + LLVM_DEBUG(llvm::dbgs() << "Cannot handle locals in source domain\n"); + return llvm::None; + } + + // Create constraints for the slice loop nest that would be created if the + // fusion succeeds. + FlatAffineConstraints sliceConstraints; + if (failed(getAsConstraints(&sliceConstraints))) { + LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice's domain\n"); + return llvm::None; + } + + // Projecting out every dimension other than the 'ivs' to express slice's + // domain completely in terms of source's IVs. + sliceConstraints.projectOut(ivs.size(), + sliceConstraints.getNumIds() - ivs.size()); + + LLVM_DEBUG(llvm::dbgs() << "Domain of the source of the slice:\n"); + LLVM_DEBUG(srcConstraints.dump()); + LLVM_DEBUG(llvm::dbgs() << "Domain of the slice if this fusion succeeds " + "(expressed in terms of its source's IVs):\n"); + LLVM_DEBUG(sliceConstraints.dump()); + + // TODO: Store 'srcSet' to avoid recalculating for each depth. + PresburgerSet srcSet(srcConstraints); + PresburgerSet sliceSet(sliceConstraints); + PresburgerSet diffSet = sliceSet.subtract(srcSet); + + if (!diffSet.isIntegerEmpty()) { + LLVM_DEBUG(llvm::dbgs() << "Incorrect slice\n"); + return false; + } + return true; +} + /// Returns true if the computation slice encloses all the iterations of the /// sliced loop nest. Returns false if it does not. Returns llvm::None if it /// cannot determine if the slice is maximal or not. @@ -715,14 +801,14 @@ unsigned mlir::getInnermostCommonLoopDepth( } /// Computes in 'sliceUnion' the union of all slice bounds computed at -/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB'. -/// Returns 'Success' if union was computed, 'failure' otherwise. -LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA, - ArrayRef<Operation *> opsB, - unsigned loopDepth, - unsigned numCommonLoops, - bool isBackwardSlice, - ComputationSliceState *sliceUnion) { +/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB', and +/// then verifies if it is valid. Returns 'SliceComputationResult::Success' if +/// union was computed correctly, an appropriate failure otherwise. +SliceComputationResult +mlir::computeSliceUnion(ArrayRef<Operation *> opsA, ArrayRef<Operation *> opsB, + unsigned loopDepth, unsigned numCommonLoops, + bool isBackwardSlice, + ComputationSliceState *sliceUnion) { // Compute the union of slice bounds between all pairs in 'opsA' and // 'opsB' in 'sliceUnionCst'. FlatAffineConstraints sliceUnionCst; @@ -738,7 +824,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA, if ((!isBackwardSlice && loopDepth > getNestingDepth(opsA[i])) || (isBackwardSlice && loopDepth > getNestingDepth(opsB[j]))) { LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n"); - return failure(); + return SliceComputationResult::GenericFailure; } bool readReadAccesses = isa<AffineReadOpInterface>(srcAccess.opInst) && @@ -751,7 +837,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA, /*allowRAR=*/readReadAccesses); if (result.value == DependenceResult::Failure) { LLVM_DEBUG(llvm::dbgs() << "Dependence check failed\n"); - return failure(); + return SliceComputationResult::GenericFailure; } if (result.value == DependenceResult::NoDependence) continue; @@ -768,7 +854,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA, if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) { LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice bound constraints\n"); - return failure(); + return SliceComputationResult::GenericFailure; } assert(sliceUnionCst.getNumDimAndSymbolIds() > 0); continue; @@ -779,7 +865,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA, if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) { LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice bound constraints\n"); - return failure(); + return SliceComputationResult::GenericFailure; } // Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if needed. @@ -802,9 +888,9 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA, // to unionBoundingBox below expects constraints for each Loop IV, even // if they are the unsliced full loop bounds added here. if (failed(addMissingLoopIVBounds(sliceUnionIVs, &sliceUnionCst))) - return failure(); + return SliceComputationResult::GenericFailure; if (failed(addMissingLoopIVBounds(tmpSliceIVs, &tmpSliceCst))) - return failure(); + return SliceComputationResult::GenericFailure; } // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'. if (sliceUnionCst.getNumLocalIds() > 0 || @@ -812,14 +898,14 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA, failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) { LLVM_DEBUG(llvm::dbgs() << "Unable to compute union bounding box of slice bounds\n"); - return failure(); + return SliceComputationResult::GenericFailure; } } } // Empty union. if (sliceUnionCst.getNumDimAndSymbolIds() == 0) - return failure(); + return SliceComputationResult::GenericFailure; // Gather loops surrounding ops from loop nest where slice will be inserted. SmallVector<Operation *, 4> ops; @@ -831,7 +917,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA, getInnermostCommonLoopDepth(ops, &surroundingLoops); if (loopDepth > innermostCommonLoopDepth) { LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n"); - return failure(); + return SliceComputationResult::GenericFailure; } // Store 'numSliceLoopIVs' before converting dst loop IVs to dims. @@ -868,7 +954,18 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA, // canonicalization. sliceUnion->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands); sliceUnion->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands); - return success(); + + // Check if the slice computed is valid. Return success only if it is verified + // that the slice is valid, otherwise return appropriate failure status. + Optional<bool> isSliceValid = sliceUnion->isSliceValid(); + if (!isSliceValid.hasValue()) { + LLVM_DEBUG(llvm::dbgs() << "Cannot determine if the slice is valid\n"); + return SliceComputationResult::GenericFailure; + } + if (!isSliceValid.getValue()) + return SliceComputationResult::IncorrectSliceFailure; + + return SliceComputationResult::Success; } const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier"; |