aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRiver Riddle <riddleriver@gmail.com>2021-03-22 18:07:09 -0700
committerRiver Riddle <riddleriver@gmail.com>2021-03-22 18:19:23 -0700
commit6d6fe9ccc43d23286b764016bc8b5a4a3ab8f675 (patch)
tree1bb9a5320f871e37164186a43142ad8ae69ee092
parent[mlir] Tune error message for assertion. (diff)
downloadllvm-project-6d6fe9ccc43d23286b764016bc8b5a4a3ab8f675.tar.gz
llvm-project-6d6fe9ccc43d23286b764016bc8b5a4a3ab8f675.tar.bz2
llvm-project-6d6fe9ccc43d23286b764016bc8b5a4a3ab8f675.zip
[mlir][OpAsmFormat] Add support for an "else" group on optional elements
The "else" group of an optional element is a collection of elements that get parsed/printed when the anchor of the main element group is *not* present. This is useful when there is a special syntax when an element is not present. The new syntax for an optional element is shown below: ``` optional-group: `(` elements `)` (`:` `(` else-elements `)`)? `?` ``` An example of how this might be used is shown below: ```tablegen def FooOp : ... { let arguments = (ins UnitAttr:$foo); let assemblyFormat = "attr-dict (`foo_is_present` $foo^):(`foo_is_absent`)?"; } ``` would be formatted as such: ```mlir // When the `foo` attribute is present: foo.op foo_is_present // When the `foo` attribute is not present: foo.op foo_is_absent ``` Differential Revision: https://reviews.llvm.org/D99129
-rw-r--r--mlir/docs/OpDefinitions.md35
-rw-r--r--mlir/test/lib/Dialect/Test/TestOps.td5
-rw-r--r--mlir/test/mlir-tblgen/op-format-spec.td12
-rw-r--r--mlir/test/mlir-tblgen/op-format.mlir10
-rw-r--r--mlir/tools/mlir-tblgen/OpFormatGen.cpp105
5 files changed, 143 insertions, 24 deletions
diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 63b727ae428b..5f413582c698 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -772,8 +772,13 @@ When a variable is optional, the provided value may be null.
In certain situations operations may have "optional" information, e.g.
attributes or an empty set of variadic operands. In these situations a section
of the assembly format can be marked as `optional` based on the presence of this
-information. An optional group is defined by wrapping a set of elements within
-`()` followed by a `?` and has the following requirements:
+information. An optional group is defined as follows:
+
+```
+optional-group: `(` elements `)` (`:` `(` else-elements `)`)? `?`
+```
+
+The `elements` of an optional group have the following requirements:
* The first element of the group must either be a attribute, literal, operand,
or region.
@@ -837,6 +842,32 @@ foo.op is_read_only
foo.op
```
+##### Optional "else" Group
+
+Optional groups also have support for an "else" group of elements. These are
+elements that are parsed/printed if the `anchor` element of the optional group
+is *not* present. Unlike the main element group, the "else" group has no
+restriction on the first element and none of the elements may act as the
+`anchor` for the optional. An example is shown below:
+
+```tablegen
+def FooOp : ... {
+ let arguments = (ins UnitAttr:$foo);
+
+ let assemblyFormat = "attr-dict (`foo_is_present` $foo^):(`foo_is_absent`)?";
+}
+```
+
+would be formatted as such:
+
+```mlir
+// When the `foo` attribute is present:
+foo.op foo_is_present
+
+// When the `foo` attribute is not present:
+foo.op foo_is_absent
+```
+
#### Requirements
The format specification has a certain set of requirements that must be adhered
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 7d48f8d4547a..8be84f2aacbc 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1651,6 +1651,11 @@ def FormatOptionalEnumAttr : TEST_Op<"format_optional_enum_attr"> {
let assemblyFormat = "($attr^)? attr-dict";
}
+def FormatOptionalWithElse : TEST_Op<"format_optional_else"> {
+ let arguments = (ins UnitAttr:$isFirstBranchPresent);
+ let assemblyFormat = "(`then` $isFirstBranchPresent^):(`else`)? attr-dict";
+}
+
//===----------------------------------------------------------------------===//
// Custom Directives
diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
index 4f5ca63c4e72..8c6bb09f34a3 100644
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -390,6 +390,18 @@ def OptionalInvalidL : TestFormat_Op<[{
def OptionalInvalidM : TestFormat_Op<[{
(` `^)?
}]>, Arguments<(ins)>;
+// CHECK: error: expected '(' to start else branch of optional group
+def OptionalInvalidN : TestFormat_Op<[{
+ ($arg^):
+}]>, Arguments<(ins Variadic<I64>:$arg)>;
+// CHECK: error: expected directive, literal, variable, or optional group
+def OptionalInvalidO : TestFormat_Op<[{
+ ($arg^):(`test`
+}]>, Arguments<(ins Variadic<I64>:$arg)>;
+// CHECK: error: expected '?' after optional group
+def OptionalInvalidP : TestFormat_Op<[{
+ ($arg^):(`test`)
+}]>, Arguments<(ins Variadic<I64>:$arg)>;
// CHECK-NOT: error
def OptionalValidA : TestFormat_Op<[{
diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index 8043786faf08..e6f998fa4ac3 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -240,6 +240,16 @@ test.format_optional_result_b_op : i64 -> i64, i64
test.format_optional_result_c_op : (i64) -> (i64, i64)
//===----------------------------------------------------------------------===//
+// Format optional with else
+//===----------------------------------------------------------------------===//
+
+// CHECK: test.format_optional_else then
+test.format_optional_else then
+
+// CHECK: test.format_optional_else else
+test.format_optional_else else
+
+//===----------------------------------------------------------------------===//
// Format custom directives
//===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index f474bbfb4f20..abf77a55004e 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -348,29 +348,41 @@ private:
namespace {
/// This class represents a group of elements that are optionally emitted based
-/// upon an optional variable of the operation.
+/// upon an optional variable of the operation, and a group of elements that are
+/// emotted when the anchor element is not present.
class OptionalElement : public Element {
public:
- OptionalElement(std::vector<std::unique_ptr<Element>> &&elements,
+ OptionalElement(std::vector<std::unique_ptr<Element>> &&thenElements,
+ std::vector<std::unique_ptr<Element>> &&elseElements,
unsigned anchor, unsigned parseStart)
- : Element{Kind::Optional}, elements(std::move(elements)), anchor(anchor),
+ : Element{Kind::Optional}, thenElements(std::move(thenElements)),
+ elseElements(std::move(elseElements)), anchor(anchor),
parseStart(parseStart) {}
static bool classof(const Element *element) {
return element->getKind() == Kind::Optional;
}
- /// Return the nested elements of this grouping.
- auto getElements() const { return llvm::make_pointee_range(elements); }
+ /// Return the `then` elements of this grouping.
+ auto getThenElements() const {
+ return llvm::make_pointee_range(thenElements);
+ }
+
+ /// Return the `else` elements of this grouping.
+ auto getElseElements() const {
+ return llvm::make_pointee_range(elseElements);
+ }
/// Return the anchor of this optional group.
- Element *getAnchor() const { return elements[anchor].get(); }
+ Element *getAnchor() const { return thenElements[anchor].get(); }
/// Return the index of the first element that needs to be parsed.
unsigned getParseStart() const { return parseStart; }
private:
- /// The child elements of this optional.
- std::vector<std::unique_ptr<Element>> elements;
+ /// The child elements of `then` branch of this optional.
+ std::vector<std::unique_ptr<Element>> thenElements;
+ /// The child elements of `else` branch of this optional.
+ std::vector<std::unique_ptr<Element>> elseElements;
/// The index of the element that acts as the anchor for the optional group.
unsigned anchor;
/// The index of the first element that is parsed (is not a
@@ -792,7 +804,7 @@ static void genLiteralParser(StringRef value, OpMethodBody &body) {
/// Generate the storage code required for parsing the given element.
static void genElementParserStorage(Element *element, OpMethodBody &body) {
if (auto *optional = dyn_cast<OptionalElement>(element)) {
- auto elements = optional->getElements();
+ auto elements = optional->getThenElements();
// If the anchor is a unit attribute, it won't be parsed directly so elide
// it.
@@ -803,6 +815,8 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
for (auto &childElement : elements)
if (&childElement != elidedAnchorElement)
genElementParserStorage(&childElement, body);
+ for (auto &childElement : optional->getElseElements())
+ genElementParserStorage(&childElement, body);
} else if (auto *custom = dyn_cast<CustomDirective>(element)) {
for (auto &paramElement : custom->getArguments())
@@ -1094,8 +1108,8 @@ void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
FmtContext &attrTypeCtx) {
/// Optional Group.
if (auto *optional = dyn_cast<OptionalElement>(element)) {
- auto elements =
- llvm::drop_begin(optional->getElements(), optional->getParseStart());
+ auto elements = llvm::drop_begin(optional->getThenElements(),
+ optional->getParseStart());
// Generate a special optional parser for the first element to gate the
// parsing of the rest of the elements.
@@ -1140,7 +1154,17 @@ void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
if (&childElement != elidedAnchorElement)
genElementParser(&childElement, body, attrTypeCtx);
}
- body << " }\n";
+ body << " }";
+
+ // Generate the else elements.
+ auto elseElements = optional->getElseElements();
+ if (!elseElements.empty()) {
+ body << " else {\n";
+ for (Element &childElement : elseElements)
+ genElementParser(&childElement, body, attrTypeCtx);
+ body << " }";
+ }
+ body << "\n";
/// Literals.
} else if (LiteralElement *literal = dyn_cast<LiteralElement>(element)) {
@@ -1778,7 +1802,7 @@ void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
// If the anchor is a unit attribute, we don't need to print it. When
// parsing, we will add this attribute if this group is present.
- auto elements = optional->getElements();
+ auto elements = optional->getThenElements();
Element *elidedAnchorElement = nullptr;
auto *anchorAttr = dyn_cast<AttributeVariable>(anchor);
if (anchorAttr && anchorAttr != &*elements.begin() &&
@@ -1793,7 +1817,20 @@ void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
lastWasPunctuation);
}
}
- body << " }\n";
+ body << " }";
+
+ // Emit each of the else elements.
+ auto elseElements = optional->getElseElements();
+ if (!elseElements.empty()) {
+ body << " else {\n";
+ for (Element &childElement : elseElements) {
+ genElementPrinter(&childElement, body, op, shouldEmitSpace,
+ lastWasPunctuation);
+ }
+ body << " }";
+ }
+
+ body << "\n";
return;
}
@@ -1911,6 +1948,7 @@ public:
l_paren,
r_paren,
caret,
+ colon,
comma,
equal,
less,
@@ -2065,6 +2103,8 @@ Token FormatLexer::lexToken() {
// Lex punctuation.
case '^':
return formToken(Token::caret, tokStart);
+ case ':':
+ return formToken(Token::colon, tokStart);
case ',':
return formToken(Token::comma, tokStart);
case '=':
@@ -2393,8 +2433,11 @@ LogicalResult FormatParser::verifyAttributes(
// Traverse into optional groups.
if (auto *optional = dyn_cast<OptionalElement>(element)) {
- auto elements = optional->getElements();
- iteratorStack.emplace_back(elements.begin(), elements.end());
+ auto thenElements = optional->getThenElements();
+ iteratorStack.emplace_back(thenElements.begin(), thenElements.end());
+
+ auto elseElements = optional->getElseElements();
+ iteratorStack.emplace_back(elseElements.begin(), elseElements.end());
return ::mlir::success();
}
@@ -2795,13 +2838,31 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
consumeToken();
// Parse the child elements for this optional group.
- std::vector<std::unique_ptr<Element>> elements;
+ std::vector<std::unique_ptr<Element>> thenElements, elseElements;
Optional<unsigned> anchorIdx;
do {
- if (failed(parseOptionalChildElement(elements, anchorIdx)))
+ if (failed(parseOptionalChildElement(thenElements, anchorIdx)))
return ::mlir::failure();
} while (curToken.getKind() != Token::r_paren);
consumeToken();
+
+ // Parse the `else` elements of this optional group.
+ if (curToken.getKind() == Token::colon) {
+ consumeToken();
+ if (failed(parseToken(Token::l_paren, "expected '(' to start else branch "
+ "of optional group")))
+ return failure();
+ do {
+ llvm::SMLoc childLoc = curToken.getLoc();
+ elseElements.push_back({});
+ if (failed(parseElement(elseElements.back(), TopLevelContext)) ||
+ failed(verifyOptionalChildElement(elseElements.back().get(), childLoc,
+ /*isAnchor=*/false)))
+ return failure();
+ } while (curToken.getKind() != Token::r_paren);
+ consumeToken();
+ }
+
if (failed(parseToken(Token::question, "expected '?' after optional group")))
return ::mlir::failure();
@@ -2811,7 +2872,7 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
// The first parsable element of the group must be able to be parsed in an
// optional fashion.
- auto parseBegin = llvm::find_if_not(elements, [](auto &element) {
+ auto parseBegin = llvm::find_if_not(thenElements, [](auto &element) {
return isa<WhitespaceElement>(element.get());
});
Element *firstElement = parseBegin->get();
@@ -2822,9 +2883,9 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
"first parsable element of an operand group must be "
"an attribute, literal, operand, or region");
- auto parseStart = parseBegin - elements.begin();
- element = std::make_unique<OptionalElement>(std::move(elements), *anchorIdx,
- parseStart);
+ auto parseStart = parseBegin - thenElements.begin();
+ element = std::make_unique<OptionalElement>(
+ std::move(thenElements), std::move(elseElements), *anchorIdx, parseStart);
return ::mlir::success();
}