aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/IR/SymbolTable.cpp')
-rw-r--r--mlir/lib/IR/SymbolTable.cpp21
1 files changed, 15 insertions, 6 deletions
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 4620a5bcb381..8d5ba2e16224 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -161,11 +161,17 @@ void SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
// TODO: consider if SymbolTable's constructor should behave the same.
if (!symbol->getParentOp()) {
auto &body = symbolTableOp->getRegion(0).front();
- if (insertPt == Block::iterator() || insertPt == body.end())
- insertPt = Block::iterator(body.getTerminator());
-
- assert(insertPt->getParentOp() == symbolTableOp &&
- "expected insertPt to be in the associated module operation");
+ if (insertPt == Block::iterator()) {
+ insertPt = Block::iterator(body.end());
+ } else {
+ assert((insertPt == body.end() ||
+ insertPt->getParentOp() == symbolTableOp) &&
+ "expected insertPt to be in the associated module operation");
+ }
+ // Insert before the terminator, if any.
+ if (insertPt == Block::iterator(body.end()) && !body.empty() &&
+ std::prev(body.end())->hasTrait<OpTrait::IsTerminator>())
+ insertPt = std::prev(body.end());
body.getOperations().insert(insertPt, symbol);
}
@@ -291,11 +297,14 @@ void SymbolTable::walkSymbolTables(
Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
StringRef symbol) {
assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
+ Region &region = symbolTableOp->getRegion(0);
+ if (region.empty())
+ return nullptr;
// Look for a symbol with the given name.
Identifier symbolNameId = Identifier::get(SymbolTable::getSymbolAttrName(),
symbolTableOp->getContext());
- for (auto &op : symbolTableOp->getRegion(0).front().without_terminator())
+ for (auto &op : region.front())
if (getNameIfSymbol(&op, symbolNameId) == symbol)
return &op;
return nullptr;