Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OpaqueType with format strings #391

Merged
merged 5 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ bool isPointerWideType(mlir::Type type);
/// Give the name of the EmitC reference attribute.
StringRef getReferenceAttributeName();

// Either a literal string, or an placeholder for the fmtArgs.
struct Placeholder {};
using ReplacementItem = std::variant<StringRef, Placeholder>;

} // namespace emitc
} // namespace mlir

Expand Down
6 changes: 1 addition & 5 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -1168,11 +1168,7 @@ def EmitC_VerbatimOp : EmitC_Op<"verbatim"> {
}];

let extraClassDeclaration = [{
// Either a literal string, or an placeholder for the fmtArgs.
struct Placeholder {};
using ReplacementItem = std::variant<StringRef, Placeholder>;

FailureOr<SmallVector<ReplacementItem>> parseFormatString();
FailureOr<SmallVector<::mlir::emitc::ReplacementItem>> parseFormatString();
}];

let arguments = (ins StrAttr:$value,
Expand Down
11 changes: 9 additions & 2 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,16 @@ def EmitC_OpaqueType : EmitC_Type<"Opaque", "opaque"> {
```
}];

let parameters = (ins StringRefParameter<"the opaque value">:$value);
let assemblyFormat = "`<` $value `>`";
let parameters = (ins StringRefParameter<"the opaque value">:$value,
OptionalArrayRefParameter<"Type">:$fmtArgs);
let assemblyFormat = "`<` $value (`,` custom<VariadicTypeFmtArgs>($fmtArgs)^)? `>`";
let genVerifyDecl = 1;

let builders = [TypeBuilder<(ins "::llvm::StringRef":$value), [{ return $_get($_ctxt, value, SmallVector<Type>{}); }] >];

let extraClassDeclaration = [{
FailureOr<SmallVector<::mlir::emitc::ReplacementItem>> parseFormatString();
}];
}

def EmitC_PointerType : EmitC_Type<"Pointer", "ptr"> {
Expand Down
154 changes: 105 additions & 49 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace mlir::emitc;
Expand Down Expand Up @@ -154,6 +155,64 @@ static LogicalResult verifyInitializationAttribute(Operation *op,
return success();
}

/// Parse a format string and return a list of its parts.
/// A part is either a StringRef that has to be printed as-is, or
/// a Placeholder which requires printing the next operand of the VerbatimOp.
/// In the format string, all `{}` are replaced by Placeholders, except if the
/// `{` is escaped by `{{` - then it doesn't start a placeholder.
template <class ArgType>
FailureOr<SmallVector<ReplacementItem>>
parseFormatString(StringRef toParse, ArgType fmtArgs,
std::optional<llvm::function_ref<mlir::InFlightDiagnostic()>>
emitError = {}) {
SmallVector<ReplacementItem> items;

// If there are not operands, the format string is not interpreted.
if (fmtArgs.empty()) {
items.push_back(toParse);
return items;
}

while (!toParse.empty()) {
size_t idx = toParse.find('{');
if (idx == StringRef::npos) {
// No '{'
items.push_back(toParse);
break;
}
if (idx > 0) {
// Take all chars excluding the '{'.
items.push_back(toParse.take_front(idx));
toParse = toParse.drop_front(idx);
continue;
}
if (toParse.size() < 2) {
// '{' is last character
items.push_back(toParse);
break;
}
// toParse contains at least two characters and starts with `{`.
char nextChar = toParse[1];
if (nextChar == '{') {
// Double '{{' -> '{' (escaping).
items.push_back(toParse.take_front(1));
toParse = toParse.drop_front(2);
continue;
}
if (nextChar == '}') {
items.push_back(Placeholder{});
toParse = toParse.drop_front(2);
continue;
}

if (emitError.has_value()) {
return (*emitError)() << "expected '}' after unescaped '{'";
}
return failure();
}
return items;
}

//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -914,7 +973,11 @@ LogicalResult emitc::SubscriptOp::verify() {
//===----------------------------------------------------------------------===//

LogicalResult emitc::VerbatimOp::verify() {
FailureOr<SmallVector<ReplacementItem>> fmt = parseFormatString();
auto errorCallback = [&]() -> InFlightDiagnostic {
return this->emitOpError();
};
FailureOr<SmallVector<ReplacementItem>> fmt =
::parseFormatString(getValue(), getFmtArgs(), errorCallback);
if (failed(fmt))
return failure();

Expand All @@ -929,56 +992,29 @@ LogicalResult emitc::VerbatimOp::verify() {
return success();
}

/// Parse a format string and return a list of its parts.
/// A part is either a StringRef that has to be printed as-is, or
/// a Placeholder which requires printing the next operand of the VerbatimOp.
/// In the format string, all `{}` are replaced by Placeholders, except if the
/// `{` is escaped by `{{` - then it doesn't start a placeholder.
FailureOr<SmallVector<emitc::VerbatimOp::ReplacementItem>>
emitc::VerbatimOp::parseFormatString() {
SmallVector<ReplacementItem> items;
static ParseResult parseVariadicTypeFmtArgs(AsmParser &p,
SmallVector<Type> &params) {
Type type;
if (p.parseType(type))
return failure();

// If there are not operands, the format string is not interpreted.
if (getFmtArgs().empty()) {
items.push_back(getValue());
return items;
params.push_back(type);
while (succeeded(p.parseOptionalComma())) {
if (p.parseType(type))
return failure();
params.push_back(type);
}

StringRef toParse = getValue();
while (!toParse.empty()) {
size_t idx = toParse.find('{');
if (idx == StringRef::npos) {
// No '{'
items.push_back(toParse);
break;
}
if (idx > 0) {
// Take all chars excluding the '{'.
items.push_back(toParse.take_front(idx));
toParse = toParse.drop_front(idx);
continue;
}
if (toParse.size() < 2) {
// '{' is last character
items.push_back(toParse);
break;
}
// toParse contains at least two characters and starts with `{`.
char nextChar = toParse[1];
if (nextChar == '{') {
// Double '{{' -> '{' (escaping).
items.push_back(toParse.take_front(1));
toParse = toParse.drop_front(2);
continue;
}
if (nextChar == '}') {
items.push_back(Placeholder{});
toParse = toParse.drop_front(2);
continue;
}
return emitOpError() << "expected '}' after unescaped '{'";
}
return items;
return success();
}

static void printVariadicTypeFmtArgs(AsmPrinter &p, ArrayRef<Type> params) {
llvm::interleaveComma(params, p, [&](Type type) { p.printType(type); });
}

FailureOr<SmallVector<ReplacementItem>> emitc::VerbatimOp::parseFormatString() {
// Error checking is done in verify.
return ::parseFormatString(getValue(), getFmtArgs());
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1072,17 +1108,37 @@ emitc::ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape,

LogicalResult mlir::emitc::OpaqueType::verify(
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
llvm::StringRef value) {
llvm::StringRef value, ArrayRef<Type> fmtArgs) {
if (value.empty()) {
return emitError() << "expected non empty string in !emitc.opaque type";
}
if (value.back() == '*') {
return emitError() << "pointer not allowed as outer type with "
"!emitc.opaque, use !emitc.ptr instead";
}

FailureOr<SmallVector<ReplacementItem>> fmt =
::parseFormatString(value, fmtArgs, emitError);
if (failed(fmt))
return failure();

size_t numPlaceholders = llvm::count_if(*fmt, [](ReplacementItem &item) {
return std::holds_alternative<Placeholder>(item);
});

if (numPlaceholders != fmtArgs.size()) {
return emitError()
<< "requires operands for each placeholder in the format string";
}

return success();
}

FailureOr<SmallVector<ReplacementItem>> emitc::OpaqueType::parseFormatString() {
// Error checking is done in verify.
return ::parseFormatString(getValue(), getFmtArgs());
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
Expand Down
21 changes: 19 additions & 2 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -512,14 +512,14 @@ static LogicalResult printOperation(CppEmitter &emitter,
emitc::VerbatimOp verbatimOp) {
raw_ostream &os = emitter.ostream();

FailureOr<SmallVector<emitc::VerbatimOp::ReplacementItem>> items =
FailureOr<SmallVector<ReplacementItem>> items =
verbatimOp.parseFormatString();
if (failed(items))
return failure();

auto fmtArg = verbatimOp.getFmtArgs().begin();

for (emitc::VerbatimOp::ReplacementItem &item : *items) {
for (ReplacementItem &item : *items) {
if (auto *str = std::get_if<StringRef>(&item)) {
os << *str;
} else {
Expand Down Expand Up @@ -1728,6 +1728,23 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
if (auto tType = dyn_cast<TupleType>(type))
return emitTupleType(loc, tType.getTypes());
if (auto oType = dyn_cast<emitc::OpaqueType>(type)) {
FailureOr<SmallVector<ReplacementItem>> items = oType.parseFormatString();
if (failed(items))
return failure();

auto fmtArg = oType.getFmtArgs().begin();
for (ReplacementItem &item : *items) {
if (auto *str = std::get_if<StringRef>(&item)) {
os << *str;
} else {
if (failed(emitType(loc, *fmtArg++))) {
mgehre-amd marked this conversation as resolved.
Show resolved Hide resolved
return failure();
}
}
}

return success();

os << oType.getValue();
return success();
}
Expand Down
28 changes: 28 additions & 0 deletions mlir/test/Dialect/EmitC/invalid_types.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,34 @@ func.func @illegal_opaque_type_2() {

// -----

// expected-error @+1 {{expected non-function type}}
func.func @illegal_opaque_type(%arg0: !emitc.opaque<"{}, {}", "string">) {
return
}

// -----

// expected-error @+1 {{requires operands for each placeholder in the format string}}
func.func @illegal_opaque_type(%arg0: !emitc.opaque<"a", f32>) {
return
}

// -----

// expected-error @+1 {{requires operands for each placeholder in the format string}}
func.func @illegal_opaque_type(%arg0: !emitc.opaque<"{}, {}", f32>) {
return
}

// -----

// expected-error @+1 {{expected '}' after unescaped '{'}}
func.func @illegal_opaque_type(%arg0: !emitc.opaque<"{ ", i32>) {
return
}

// -----

func.func @illegal_array_missing_spec(
// expected-error @+1 {{expected non-function type}}
%arg0: !emitc.array<>) {
Expand Down
6 changes: 6 additions & 0 deletions mlir/test/Dialect/EmitC/types.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ func.func @opaque_types() {
emitc.call_opaque "f"() {template_args = [!emitc.opaque<"std::vector<std::string>">]} : () -> ()
// CHECK-NEXT: !emitc.opaque<"SmallVector<int*, 4>">
emitc.call_opaque "f"() {template_args = [!emitc.opaque<"SmallVector<int*, 4>">]} : () -> ()
// CHECK-NEXT: !emitc.opaque<"{}", i32>
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"{}", i32>>]} : () -> ()
// CHECK-NEXT: !emitc.opaque<"{}, {}", i32, f32>]
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"{}, {}", i32, f32>>]} : () -> ()
// CHECK-NEXT: !emitc.opaque<"{}"
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"{}">>]} : () -> ()

return
}
Expand Down
10 changes: 10 additions & 0 deletions mlir/test/Target/Cpp/types.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@ func.func @opaque_types() {
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"status_t">>]} : () -> ()
// CHECK-NEXT: f<std::vector<std::string>>();
emitc.call_opaque "f"() {template_args = [!emitc.opaque<"std::vector<std::string>">]} : () -> ()
// CHECK: f<float>()
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"{}", f32>>]} : () -> ()
// CHECK: f<int16_t {>();
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"{} {{", si16>>]} : () -> ()
// CHECK: f<int8_t {>();
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"{} {", i8>>]} : () -> ()
// CHECK: f<status_t>();
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"{}", !emitc<opaque<"status_t">> >>]} : () -> ()
// CHECK: f<top<nested<float>,int32_t>>();
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"top<{},{}>", !emitc<opaque<"nested<{}>", f32>>, i32>>]} : () -> ()

return
}
Expand Down