diff options
Diffstat (limited to 'mlir/lib/IR/SymbolTable.cpp')
-rw-r--r-- | mlir/lib/IR/SymbolTable.cpp | 21 |
1 files changed, 15 insertions, 6 deletions
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp index 4620a5bcb381..8d5ba2e16224 100644 --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -161,11 +161,17 @@ void SymbolTable::insert(Operation *symbol, Block::iterator insertPt) { // TODO: consider if SymbolTable's constructor should behave the same. if (!symbol->getParentOp()) { auto &body = symbolTableOp->getRegion(0).front(); - if (insertPt == Block::iterator() || insertPt == body.end()) - insertPt = Block::iterator(body.getTerminator()); - - assert(insertPt->getParentOp() == symbolTableOp && - "expected insertPt to be in the associated module operation"); + if (insertPt == Block::iterator()) { + insertPt = Block::iterator(body.end()); + } else { + assert((insertPt == body.end() || + insertPt->getParentOp() == symbolTableOp) && + "expected insertPt to be in the associated module operation"); + } + // Insert before the terminator, if any. + if (insertPt == Block::iterator(body.end()) && !body.empty() && + std::prev(body.end())->hasTrait<OpTrait::IsTerminator>()) + insertPt = std::prev(body.end()); body.getOperations().insert(insertPt, symbol); } @@ -291,11 +297,14 @@ void SymbolTable::walkSymbolTables( Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp, StringRef symbol) { assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>()); + Region ®ion = symbolTableOp->getRegion(0); + if (region.empty()) + return nullptr; // Look for a symbol with the given name. Identifier symbolNameId = Identifier::get(SymbolTable::getSymbolAttrName(), symbolTableOp->getContext()); - for (auto &op : symbolTableOp->getRegion(0).front().without_terminator()) + for (auto &op : region.front()) if (getNameIfSymbol(&op, symbolNameId) == symbol) return &op; return nullptr; |