diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h index bc82f58a7ee95c..d9cca43081c98c 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h @@ -52,6 +52,10 @@ bool isPointerWideType(mlir::Type type); /// Give the name of the EmitC reference attribute. StringRef getReferenceAttributeName(); +// Either a literal string, or an placeholder for the fmtArgs. +struct Placeholder {}; +using ReplacementItem = std::variant; + } // namespace emitc } // namespace mlir diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 0de8787ba1dc8f..1a1b58e3cbf386 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -1168,11 +1168,7 @@ def EmitC_VerbatimOp : EmitC_Op<"verbatim"> { }]; let extraClassDeclaration = [{ - // Either a literal string, or an placeholder for the fmtArgs. - struct Placeholder {}; - using ReplacementItem = std::variant; - - FailureOr> parseFormatString(); + FailureOr> parseFormatString(); }]; let arguments = (ins StrAttr:$value, diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td index 79f6d34fc91b13..0fbacd440a91de 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td @@ -99,9 +99,16 @@ def EmitC_OpaqueType : EmitC_Type<"Opaque", "opaque"> { ``` }]; - let parameters = (ins StringRefParameter<"the opaque value">:$value); - let assemblyFormat = "`<` $value `>`"; + let parameters = (ins StringRefParameter<"the opaque value">:$value, + OptionalArrayRefParameter<"Type">:$fmtArgs); + let assemblyFormat = "`<` $value (`,` custom($fmtArgs)^)? `>`"; let genVerifyDecl = 1; + + let builders = [TypeBuilder<(ins "::llvm::StringRef":$value), [{ return $_get($_ctxt, value, SmallVector{}); }] >]; + + let extraClassDeclaration = [{ + FailureOr> parseFormatString(); + }]; } def EmitC_PointerType : EmitC_Type<"Pointer", "ptr"> { diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 7bc40b4f555cc6..ee44f524c91428 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -20,6 +20,7 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/FormatVariadic.h" using namespace mlir; using namespace mlir::emitc; @@ -154,6 +155,64 @@ static LogicalResult verifyInitializationAttribute(Operation *op, return success(); } +/// Parse a format string and return a list of its parts. +/// A part is either a StringRef that has to be printed as-is, or +/// a Placeholder which requires printing the next operand of the VerbatimOp. +/// In the format string, all `{}` are replaced by Placeholders, except if the +/// `{` is escaped by `{{` - then it doesn't start a placeholder. +template +FailureOr> +parseFormatString(StringRef toParse, ArgType fmtArgs, + std::optional> + emitError = {}) { + SmallVector items; + + // If there are not operands, the format string is not interpreted. + if (fmtArgs.empty()) { + items.push_back(toParse); + return items; + } + + while (!toParse.empty()) { + size_t idx = toParse.find('{'); + if (idx == StringRef::npos) { + // No '{' + items.push_back(toParse); + break; + } + if (idx > 0) { + // Take all chars excluding the '{'. + items.push_back(toParse.take_front(idx)); + toParse = toParse.drop_front(idx); + continue; + } + if (toParse.size() < 2) { + // '{' is last character + items.push_back(toParse); + break; + } + // toParse contains at least two characters and starts with `{`. + char nextChar = toParse[1]; + if (nextChar == '{') { + // Double '{{' -> '{' (escaping). + items.push_back(toParse.take_front(1)); + toParse = toParse.drop_front(2); + continue; + } + if (nextChar == '}') { + items.push_back(Placeholder{}); + toParse = toParse.drop_front(2); + continue; + } + + if (emitError.has_value()) { + return (*emitError)() << "expected '}' after unescaped '{'"; + } + return failure(); + } + return items; +} + //===----------------------------------------------------------------------===// // AddOp //===----------------------------------------------------------------------===// @@ -914,7 +973,11 @@ LogicalResult emitc::SubscriptOp::verify() { //===----------------------------------------------------------------------===// LogicalResult emitc::VerbatimOp::verify() { - FailureOr> fmt = parseFormatString(); + auto errorCallback = [&]() -> InFlightDiagnostic { + return this->emitOpError(); + }; + FailureOr> fmt = + ::parseFormatString(getValue(), getFmtArgs(), errorCallback); if (failed(fmt)) return failure(); @@ -929,56 +992,29 @@ LogicalResult emitc::VerbatimOp::verify() { return success(); } -/// Parse a format string and return a list of its parts. -/// A part is either a StringRef that has to be printed as-is, or -/// a Placeholder which requires printing the next operand of the VerbatimOp. -/// In the format string, all `{}` are replaced by Placeholders, except if the -/// `{` is escaped by `{{` - then it doesn't start a placeholder. -FailureOr> -emitc::VerbatimOp::parseFormatString() { - SmallVector items; +static ParseResult parseVariadicTypeFmtArgs(AsmParser &p, + SmallVector ¶ms) { + Type type; + if (p.parseType(type)) + return failure(); - // If there are not operands, the format string is not interpreted. - if (getFmtArgs().empty()) { - items.push_back(getValue()); - return items; + params.push_back(type); + while (succeeded(p.parseOptionalComma())) { + if (p.parseType(type)) + return failure(); + params.push_back(type); } - StringRef toParse = getValue(); - while (!toParse.empty()) { - size_t idx = toParse.find('{'); - if (idx == StringRef::npos) { - // No '{' - items.push_back(toParse); - break; - } - if (idx > 0) { - // Take all chars excluding the '{'. - items.push_back(toParse.take_front(idx)); - toParse = toParse.drop_front(idx); - continue; - } - if (toParse.size() < 2) { - // '{' is last character - items.push_back(toParse); - break; - } - // toParse contains at least two characters and starts with `{`. - char nextChar = toParse[1]; - if (nextChar == '{') { - // Double '{{' -> '{' (escaping). - items.push_back(toParse.take_front(1)); - toParse = toParse.drop_front(2); - continue; - } - if (nextChar == '}') { - items.push_back(Placeholder{}); - toParse = toParse.drop_front(2); - continue; - } - return emitOpError() << "expected '}' after unescaped '{'"; - } - return items; + return success(); +} + +static void printVariadicTypeFmtArgs(AsmPrinter &p, ArrayRef params) { + llvm::interleaveComma(params, p, [&](Type type) { p.printType(type); }); +} + +FailureOr> emitc::VerbatimOp::parseFormatString() { + // Error checking is done in verify. + return ::parseFormatString(getValue(), getFmtArgs()); } //===----------------------------------------------------------------------===// @@ -1072,7 +1108,7 @@ emitc::ArrayType::cloneWith(std::optional> shape, LogicalResult mlir::emitc::OpaqueType::verify( llvm::function_ref emitError, - llvm::StringRef value) { + llvm::StringRef value, ArrayRef fmtArgs) { if (value.empty()) { return emitError() << "expected non empty string in !emitc.opaque type"; } @@ -1080,9 +1116,29 @@ LogicalResult mlir::emitc::OpaqueType::verify( return emitError() << "pointer not allowed as outer type with " "!emitc.opaque, use !emitc.ptr instead"; } + + FailureOr> fmt = + ::parseFormatString(value, fmtArgs, emitError); + if (failed(fmt)) + return failure(); + + size_t numPlaceholders = llvm::count_if(*fmt, [](ReplacementItem &item) { + return std::holds_alternative(item); + }); + + if (numPlaceholders != fmtArgs.size()) { + return emitError() + << "requires operands for each placeholder in the format string"; + } + return success(); } +FailureOr> emitc::OpaqueType::parseFormatString() { + // Error checking is done in verify. + return ::parseFormatString(getValue(), getFmtArgs()); +} + //===----------------------------------------------------------------------===// // GlobalOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 73256451ef1487..bc73d415e6e8c8 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -512,14 +512,14 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::VerbatimOp verbatimOp) { raw_ostream &os = emitter.ostream(); - FailureOr> items = + FailureOr> items = verbatimOp.parseFormatString(); if (failed(items)) return failure(); auto fmtArg = verbatimOp.getFmtArgs().begin(); - for (emitc::VerbatimOp::ReplacementItem &item : *items) { + for (ReplacementItem &item : *items) { if (auto *str = std::get_if(&item)) { os << *str; } else { @@ -1728,6 +1728,23 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) { if (auto tType = dyn_cast(type)) return emitTupleType(loc, tType.getTypes()); if (auto oType = dyn_cast(type)) { + FailureOr> items = oType.parseFormatString(); + if (failed(items)) + return failure(); + + auto fmtArg = oType.getFmtArgs().begin(); + for (ReplacementItem &item : *items) { + if (auto *str = std::get_if(&item)) { + os << *str; + } else { + if (failed(emitType(loc, *fmtArg++))) { + return failure(); + } + } + } + + return success(); + os << oType.getValue(); return success(); } diff --git a/mlir/test/Dialect/EmitC/invalid_types.mlir b/mlir/test/Dialect/EmitC/invalid_types.mlir index ee59d90bf7f617..616f0480a19d91 100644 --- a/mlir/test/Dialect/EmitC/invalid_types.mlir +++ b/mlir/test/Dialect/EmitC/invalid_types.mlir @@ -14,6 +14,34 @@ func.func @illegal_opaque_type_2() { // ----- +// expected-error @+1 {{expected non-function type}} +func.func @illegal_opaque_type(%arg0: !emitc.opaque<"{}, {}", "string">) { + return +} + +// ----- + +// expected-error @+1 {{requires operands for each placeholder in the format string}} +func.func @illegal_opaque_type(%arg0: !emitc.opaque<"a", f32>) { + return +} + +// ----- + + // expected-error @+1 {{requires operands for each placeholder in the format string}} +func.func @illegal_opaque_type(%arg0: !emitc.opaque<"{}, {}", f32>) { + return +} + +// ----- + +// expected-error @+1 {{expected '}' after unescaped '{'}} +func.func @illegal_opaque_type(%arg0: !emitc.opaque<"{ ", i32>) { + return +} + +// ----- + func.func @illegal_array_missing_spec( // expected-error @+1 {{expected non-function type}} %arg0: !emitc.array<>) { diff --git a/mlir/test/Dialect/EmitC/types.mlir b/mlir/test/Dialect/EmitC/types.mlir index b53976eff84cad..eca23c75263ee1 100644 --- a/mlir/test/Dialect/EmitC/types.mlir +++ b/mlir/test/Dialect/EmitC/types.mlir @@ -38,6 +38,12 @@ func.func @opaque_types() { emitc.call_opaque "f"() {template_args = [!emitc.opaque<"std::vector">]} : () -> () // CHECK-NEXT: !emitc.opaque<"SmallVector"> emitc.call_opaque "f"() {template_args = [!emitc.opaque<"SmallVector">]} : () -> () + // CHECK-NEXT: !emitc.opaque<"{}", i32> + emitc.call_opaque "f"() {template_args = [!emitc>]} : () -> () + // CHECK-NEXT: !emitc.opaque<"{}, {}", i32, f32>] + emitc.call_opaque "f"() {template_args = [!emitc>]} : () -> () + // CHECK-NEXT: !emitc.opaque<"{}" + emitc.call_opaque "f"() {template_args = [!emitc>]} : () -> () return } diff --git a/mlir/test/Target/Cpp/types.mlir b/mlir/test/Target/Cpp/types.mlir index deda383b3b0a72..336dfacaa183a4 100644 --- a/mlir/test/Target/Cpp/types.mlir +++ b/mlir/test/Target/Cpp/types.mlir @@ -12,6 +12,16 @@ func.func @opaque_types() { emitc.call_opaque "f"() {template_args = [!emitc>]} : () -> () // CHECK-NEXT: f>(); emitc.call_opaque "f"() {template_args = [!emitc.opaque<"std::vector">]} : () -> () + // CHECK: f() + emitc.call_opaque "f"() {template_args = [!emitc>]} : () -> () + // CHECK: f(); + emitc.call_opaque "f"() {template_args = [!emitc>]} : () -> () + // CHECK: f(); + emitc.call_opaque "f"() {template_args = [!emitc>]} : () -> () + // CHECK: f(); + emitc.call_opaque "f"() {template_args = [!emitc> >>]} : () -> () + // CHECK: f,int32_t>>(); + emitc.call_opaque "f"() {template_args = [!emitc", !emitc", f32>>, i32>>]} : () -> () return }