aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp')
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp112
1 files changed, 82 insertions, 30 deletions
diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
index 0143f4f4b62a..6da420b8e0dd 100644
--- a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
@@ -99,6 +99,8 @@
#include "llvm/IR/Type.h"
#include "llvm/Pass.h"
+#define DEBUG_TYPE "nvptx-lower-args"
+
using namespace llvm;
namespace llvm {
@@ -166,40 +168,60 @@ static void convertToParamAS(Value *OldUser, Value *Param) {
Value *NewParam;
};
SmallVector<IP> ItemsToConvert = {{I, Param}};
- SmallVector<GetElementPtrInst *> GEPsToDelete;
- while (!ItemsToConvert.empty()) {
- IP I = ItemsToConvert.pop_back_val();
- if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction))
+ SmallVector<Instruction *> InstructionsToDelete;
+
+ auto CloneInstInParamAS = [](const IP &I) -> Value * {
+ if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction)) {
LI->setOperand(0, I.NewParam);
- else if (auto *GEP = dyn_cast<GetElementPtrInst>(I.OldInstruction)) {
+ return LI;
+ }
+ if (auto *GEP = dyn_cast<GetElementPtrInst>(I.OldInstruction)) {
SmallVector<Value *, 4> Indices(GEP->indices());
auto *NewGEP = GetElementPtrInst::Create(nullptr, I.NewParam, Indices,
GEP->getName(), GEP);
NewGEP->setIsInBounds(GEP->isInBounds());
- llvm::for_each(GEP->users(), [NewGEP, &ItemsToConvert](Value *V) {
- ItemsToConvert.push_back({cast<Instruction>(V), NewGEP});
- });
- GEPsToDelete.push_back(GEP);
- } else
- llvm_unreachable("Only Load and GEP can be converted to param AS.");
- }
- llvm::for_each(GEPsToDelete,
- [](GetElementPtrInst *GEP) { GEP->eraseFromParent(); });
-}
+ return NewGEP;
+ }
+ if (auto *BC = dyn_cast<BitCastInst>(I.OldInstruction)) {
+ auto *NewBCType = BC->getType()->getPointerElementType()->getPointerTo(
+ ADDRESS_SPACE_PARAM);
+ return BitCastInst::Create(BC->getOpcode(), I.NewParam, NewBCType,
+ BC->getName(), BC);
+ }
+ if (auto *ASC = dyn_cast<AddrSpaceCastInst>(I.OldInstruction)) {
+ assert(ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM);
+ // Just pass through the argument, the old ASC is no longer needed.
+ return I.NewParam;
+ }
+ llvm_unreachable("Unsupported instruction");
+ };
-static bool isALoadChain(Value *Start) {
- SmallVector<Value *, 16> ValuesToCheck = {Start};
- while (!ValuesToCheck.empty()) {
- Value *V = ValuesToCheck.pop_back_val();
- Instruction *I = dyn_cast<Instruction>(V);
- if (!I)
- return false;
- if (isa<GetElementPtrInst>(I))
- ValuesToCheck.append(I->user_begin(), I->user_end());
- else if (!isa<LoadInst>(I))
- return false;
+ while (!ItemsToConvert.empty()) {
+ IP I = ItemsToConvert.pop_back_val();
+ Value *NewInst = CloneInstInParamAS(I);
+
+ if (NewInst && NewInst != I.OldInstruction) {
+ // We've created a new instruction. Queue users of the old instruction to
+ // be converted and the instruction itself to be deleted. We can't delete
+ // the old instruction yet, because it's still in use by a load somewhere.
+ llvm::for_each(
+ I.OldInstruction->users(), [NewInst, &ItemsToConvert](Value *V) {
+ ItemsToConvert.push_back({cast<Instruction>(V), NewInst});
+ });
+
+ InstructionsToDelete.push_back(I.OldInstruction);
+ }
}
- return true;
+
+ // Now we know that all argument loads are using addresses in parameter space
+ // and we can finally remove the old instructions in generic AS. Instructions
+ // scheduled for removal should be processed in reverse order so the ones
+ // closest to the load are deleted first. Otherwise they may still be in use.
+ // E.g if we have Value = Load(BitCast(GEP(arg))), InstructionsToDelete will
+ // have {GEP,BitCast}. GEP can't be deleted first, because it's still used by
+ // the BitCast.
+ llvm::for_each(reverse(InstructionsToDelete),
+ [](Instruction *I) { I->eraseFromParent(); });
}
void NVPTXLowerArgs::handleByValParam(Argument *Arg) {
@@ -211,9 +233,36 @@ void NVPTXLowerArgs::handleByValParam(Argument *Arg) {
Type *StructType = PType->getElementType();
- if (llvm::all_of(Arg->users(), isALoadChain)) {
- // Replace all loads with the loads in param AS. This allows loading the Arg
- // directly from parameter AS, without making a temporary copy.
+ auto IsALoadChain = [&](Value *Start) {
+ SmallVector<Value *, 16> ValuesToCheck = {Start};
+ auto IsALoadChainInstr = [](Value *V) -> bool {
+ if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V))
+ return true;
+ // ASC to param space are OK, too -- we'll just strip them.
+ if (auto *ASC = dyn_cast<AddrSpaceCastInst>(V)) {
+ if (ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM)
+ return true;
+ }
+ return false;
+ };
+
+ while (!ValuesToCheck.empty()) {
+ Value *V = ValuesToCheck.pop_back_val();
+ if (!IsALoadChainInstr(V)) {
+ LLVM_DEBUG(dbgs() << "Need a copy of " << *Arg << " because of " << *V
+ << "\n");
+ (void)Arg;
+ return false;
+ }
+ if (!isa<LoadInst>(V))
+ llvm::append_range(ValuesToCheck, V->users());
+ }
+ return true;
+ };
+
+ if (llvm::all_of(Arg->users(), IsALoadChain)) {
+ // Convert all loads and intermediate operations to use parameter AS and
+ // skip creation of a local copy of the argument.
SmallVector<User *, 16> UsersToUpdate(Arg->users());
Value *ArgInParamAS = new AddrSpaceCastInst(
Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
@@ -221,6 +270,7 @@ void NVPTXLowerArgs::handleByValParam(Argument *Arg) {
llvm::for_each(UsersToUpdate, [ArgInParamAS](Value *V) {
convertToParamAS(V, ArgInParamAS);
});
+ LLVM_DEBUG(dbgs() << "No need to copy " << *Arg << "\n");
return;
}
@@ -297,6 +347,7 @@ bool NVPTXLowerArgs::runOnKernelFunction(Function &F) {
}
}
+ LLVM_DEBUG(dbgs() << "Lowering kernel args of " << F.getName() << "\n");
for (Argument &Arg : F.args()) {
if (Arg.getType()->isPointerTy()) {
if (Arg.hasByValAttr())
@@ -310,6 +361,7 @@ bool NVPTXLowerArgs::runOnKernelFunction(Function &F) {
// Device functions only need to copy byval args into local memory.
bool NVPTXLowerArgs::runOnDeviceFunction(Function &F) {
+ LLVM_DEBUG(dbgs() << "Lowering function args of " << F.getName() << "\n");
for (Argument &Arg : F.args())
if (Arg.getType()->isPointerTy() && Arg.hasByValAttr())
handleByValParam(&Arg);