diff options
Diffstat (limited to 'llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp')
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp | 112 |
1 files changed, 82 insertions, 30 deletions
diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp index 0143f4f4b62a..6da420b8e0dd 100644 --- a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp @@ -99,6 +99,8 @@ #include "llvm/IR/Type.h" #include "llvm/Pass.h" +#define DEBUG_TYPE "nvptx-lower-args" + using namespace llvm; namespace llvm { @@ -166,40 +168,60 @@ static void convertToParamAS(Value *OldUser, Value *Param) { Value *NewParam; }; SmallVector<IP> ItemsToConvert = {{I, Param}}; - SmallVector<GetElementPtrInst *> GEPsToDelete; - while (!ItemsToConvert.empty()) { - IP I = ItemsToConvert.pop_back_val(); - if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction)) + SmallVector<Instruction *> InstructionsToDelete; + + auto CloneInstInParamAS = [](const IP &I) -> Value * { + if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction)) { LI->setOperand(0, I.NewParam); - else if (auto *GEP = dyn_cast<GetElementPtrInst>(I.OldInstruction)) { + return LI; + } + if (auto *GEP = dyn_cast<GetElementPtrInst>(I.OldInstruction)) { SmallVector<Value *, 4> Indices(GEP->indices()); auto *NewGEP = GetElementPtrInst::Create(nullptr, I.NewParam, Indices, GEP->getName(), GEP); NewGEP->setIsInBounds(GEP->isInBounds()); - llvm::for_each(GEP->users(), [NewGEP, &ItemsToConvert](Value *V) { - ItemsToConvert.push_back({cast<Instruction>(V), NewGEP}); - }); - GEPsToDelete.push_back(GEP); - } else - llvm_unreachable("Only Load and GEP can be converted to param AS."); - } - llvm::for_each(GEPsToDelete, - [](GetElementPtrInst *GEP) { GEP->eraseFromParent(); }); -} + return NewGEP; + } + if (auto *BC = dyn_cast<BitCastInst>(I.OldInstruction)) { + auto *NewBCType = BC->getType()->getPointerElementType()->getPointerTo( + ADDRESS_SPACE_PARAM); + return BitCastInst::Create(BC->getOpcode(), I.NewParam, NewBCType, + BC->getName(), BC); + } + if (auto *ASC = dyn_cast<AddrSpaceCastInst>(I.OldInstruction)) { + assert(ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM); + // Just pass through the argument, the old ASC is no longer needed. + return I.NewParam; + } + llvm_unreachable("Unsupported instruction"); + }; -static bool isALoadChain(Value *Start) { - SmallVector<Value *, 16> ValuesToCheck = {Start}; - while (!ValuesToCheck.empty()) { - Value *V = ValuesToCheck.pop_back_val(); - Instruction *I = dyn_cast<Instruction>(V); - if (!I) - return false; - if (isa<GetElementPtrInst>(I)) - ValuesToCheck.append(I->user_begin(), I->user_end()); - else if (!isa<LoadInst>(I)) - return false; + while (!ItemsToConvert.empty()) { + IP I = ItemsToConvert.pop_back_val(); + Value *NewInst = CloneInstInParamAS(I); + + if (NewInst && NewInst != I.OldInstruction) { + // We've created a new instruction. Queue users of the old instruction to + // be converted and the instruction itself to be deleted. We can't delete + // the old instruction yet, because it's still in use by a load somewhere. + llvm::for_each( + I.OldInstruction->users(), [NewInst, &ItemsToConvert](Value *V) { + ItemsToConvert.push_back({cast<Instruction>(V), NewInst}); + }); + + InstructionsToDelete.push_back(I.OldInstruction); + } } - return true; + + // Now we know that all argument loads are using addresses in parameter space + // and we can finally remove the old instructions in generic AS. Instructions + // scheduled for removal should be processed in reverse order so the ones + // closest to the load are deleted first. Otherwise they may still be in use. + // E.g if we have Value = Load(BitCast(GEP(arg))), InstructionsToDelete will + // have {GEP,BitCast}. GEP can't be deleted first, because it's still used by + // the BitCast. + llvm::for_each(reverse(InstructionsToDelete), + [](Instruction *I) { I->eraseFromParent(); }); } void NVPTXLowerArgs::handleByValParam(Argument *Arg) { @@ -211,9 +233,36 @@ void NVPTXLowerArgs::handleByValParam(Argument *Arg) { Type *StructType = PType->getElementType(); - if (llvm::all_of(Arg->users(), isALoadChain)) { - // Replace all loads with the loads in param AS. This allows loading the Arg - // directly from parameter AS, without making a temporary copy. + auto IsALoadChain = [&](Value *Start) { + SmallVector<Value *, 16> ValuesToCheck = {Start}; + auto IsALoadChainInstr = [](Value *V) -> bool { + if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V)) + return true; + // ASC to param space are OK, too -- we'll just strip them. + if (auto *ASC = dyn_cast<AddrSpaceCastInst>(V)) { + if (ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM) + return true; + } + return false; + }; + + while (!ValuesToCheck.empty()) { + Value *V = ValuesToCheck.pop_back_val(); + if (!IsALoadChainInstr(V)) { + LLVM_DEBUG(dbgs() << "Need a copy of " << *Arg << " because of " << *V + << "\n"); + (void)Arg; + return false; + } + if (!isa<LoadInst>(V)) + llvm::append_range(ValuesToCheck, V->users()); + } + return true; + }; + + if (llvm::all_of(Arg->users(), IsALoadChain)) { + // Convert all loads and intermediate operations to use parameter AS and + // skip creation of a local copy of the argument. SmallVector<User *, 16> UsersToUpdate(Arg->users()); Value *ArgInParamAS = new AddrSpaceCastInst( Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(), @@ -221,6 +270,7 @@ void NVPTXLowerArgs::handleByValParam(Argument *Arg) { llvm::for_each(UsersToUpdate, [ArgInParamAS](Value *V) { convertToParamAS(V, ArgInParamAS); }); + LLVM_DEBUG(dbgs() << "No need to copy " << *Arg << "\n"); return; } @@ -297,6 +347,7 @@ bool NVPTXLowerArgs::runOnKernelFunction(Function &F) { } } + LLVM_DEBUG(dbgs() << "Lowering kernel args of " << F.getName() << "\n"); for (Argument &Arg : F.args()) { if (Arg.getType()->isPointerTy()) { if (Arg.hasByValAttr()) @@ -310,6 +361,7 @@ bool NVPTXLowerArgs::runOnKernelFunction(Function &F) { // Device functions only need to copy byval args into local memory. bool NVPTXLowerArgs::runOnDeviceFunction(Function &F) { + LLVM_DEBUG(dbgs() << "Lowering function args of " << F.getName() << "\n"); for (Argument &Arg : F.args()) if (Arg.getType()->isPointerTy() && Arg.hasByValAttr()) handleByValParam(&Arg); |