Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix line breaking in the case of dictionary attributes placed in an array #79

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 58 additions & 20 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -386,10 +386,12 @@ class AsmPrinter::Impl {

/// Print the given attribute or an alias.
void printAttribute(Attribute attr,
AttrTypeElision typeElision = AttrTypeElision::Never);
AttrTypeElision typeElision = AttrTypeElision::Never,
SmallString<16> separator = StringRef(", "));
/// Print the given attribute without considering an alias.
void printAttributeImpl(Attribute attr,
AttrTypeElision typeElision = AttrTypeElision::Never);
AttrTypeElision typeElision = AttrTypeElision::Never,
SmallString<16> separator = StringRef(", "));

/// Print the alias for the given attribute, return failure if no alias could
/// be printed.
Expand Down Expand Up @@ -422,8 +424,10 @@ class AsmPrinter::Impl {
protected:
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<StringRef> elidedAttrs = {},
unsigned currentIndent = 0, bool withKeyword = false);
void printNamedAttribute(NamedAttribute attr);
unsigned currentIndent = 0,
bool withKeyword = false);
void printNamedAttribute(NamedAttribute attr,
SmallString<16> separator = StringRef(", "));
void printTrailingLocation(Location loc, bool allowAlias = true);
void printLocationInternal(LocationAttr loc, bool pretty = false,
bool isTopLevel = false);
Expand Down Expand Up @@ -2110,7 +2114,8 @@ LogicalResult AsmPrinter::Impl::printAlias(Type type) {
}

void AsmPrinter::Impl::printAttribute(Attribute attr,
AttrTypeElision typeElision) {
AttrTypeElision typeElision,
SmallString<16> separator) {
if (!attr) {
os << "<<NULL ATTRIBUTE>>";
return;
Expand All @@ -2119,11 +2124,12 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
// Try to print an alias for this attribute.
if (succeeded(printAlias(attr)))
return;
return printAttributeImpl(attr, typeElision);
return printAttributeImpl(attr, typeElision, separator);
}

void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
AttrTypeElision typeElision) {
AttrTypeElision typeElision,
SmallString<16> separator) {
if (!isa<BuiltinDialect>(attr.getDialect())) {
printDialectAttribute(attr);
} else if (auto opaqueAttr = llvm::dyn_cast<OpaqueAttr>(attr)) {
Expand All @@ -2134,8 +2140,15 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
return;
} else if (auto dictAttr = llvm::dyn_cast<DictionaryAttr>(attr)) {
os << '{';
interleaveComma(dictAttr.getValue(),
[&](NamedAttribute attr) { printNamedAttribute(attr); });
if (printerFlags.getNewlineAfterAttrLimit() &&
dictAttr.size() > *printerFlags.getNewlineAfterAttrLimit() &&
separator.size() > 2) {
separator.reserve(separator.capacity() + 2);
separator.push_back(' ');
}
interleave(
dictAttr.getValue(),
[&](NamedAttribute attr) { printNamedAttribute(attr); }, separator);
os << '}';

} else if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr)) {
Expand All @@ -2147,9 +2160,9 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
return;
}

// Only print attributes as unsigned if they are explicitly unsigned or are
// signless 1-bit values. Indexes, signed values, and multi-bit signless
// values print as signed.
// Only print attributes as unsigned if they are explicitly unsigned or
// are signless 1-bit values. Indexes, signed values, and multi-bit
// signless values print as signed.
bool isUnsigned =
intType.isUnsignedInteger() || intType.isSignlessInteger(1);
intAttr.getValue().print(os, !isUnsigned);
Expand All @@ -2170,9 +2183,30 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr,

} else if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(attr)) {
os << '[';
interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
printAttribute(attr, AttrTypeElision::May);
});
if (printerFlags.getNewlineAfterAttrLimit() &&
arrayAttr.size() > *printerFlags.getNewlineAfterAttrLimit() &&
separator.size() > 2) {
separator.reserve(separator.capacity() + 2);
separator.push_back(' ');
}
bool isDictAttrPresent = false;
for (auto attr : arrayAttr.getValue()) {
if (auto dictAttr = llvm::dyn_cast<DictionaryAttr>(attr)) {
isDictAttrPresent = true;
}
}
if (isDictAttrPresent) {
interleave(
arrayAttr.getValue(),
[&](Attribute attr) {
printAttribute(attr, AttrTypeElision::May, separator);
},
separator);
} else {
interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
printAttribute(attr, AttrTypeElision::May, separator);
});
}
os << ']';

} else if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(attr)) {
Expand Down Expand Up @@ -2590,7 +2624,7 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
attrs.size() > *printerFlags.getNewlineAfterAttrLimit()) {

// Increase indent to match the visually match the "{ " below.
//currentIndent += 2;
// currentIndent += 2;

separator.clear();
separator.reserve(currentIndent + 2);
Expand All @@ -2605,7 +2639,8 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,

// Otherwise, print them all out in braces.
interleave(
filteredAttrs, [&](NamedAttribute attr) { printNamedAttribute(attr); },
filteredAttrs,
[&](NamedAttribute attr) { printNamedAttribute(attr, separator); },
separator);
os << '}';
};
Expand All @@ -2623,7 +2658,8 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
if (!filteredAttrs.empty())
printFilteredAttributesFn(filteredAttrs);
}
void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr,
SmallString<16> separator) {
// Print the name without quotes if possible.
::printKeywordOrString(attr.getName().strref(), os);

Expand All @@ -2632,7 +2668,8 @@ void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
return;

os << " = ";
printAttribute(attr.getValue());
printAttribute(attr.getValue(), /*typeElision*/ AttrTypeElision::Never,
separator);
}

void AsmPrinter::Impl::printDialectAttribute(Attribute attr) {
Expand Down Expand Up @@ -3026,7 +3063,8 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
/// Print an optional attribute dictionary with a given set of elided values.
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<StringRef> elidedAttrs = {}) override {
Impl::printOptionalAttrDict(attrs, elidedAttrs, currentIndent + indentWidth);
Impl::printOptionalAttrDict(attrs, elidedAttrs,
currentIndent + indentWidth);
}
void printOptionalAttrDictWithKeyword(
ArrayRef<NamedAttribute> attrs,
Expand Down
15 changes: 15 additions & 0 deletions mlir/test/IR/mlir-newline-after-attr.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,18 @@
// CHECK: foo.dense_attr =
// CHECK: foo.second_attr =
"test.op"() {foo.dense_attr = dense<1> : tensor<3xi32>, foo.second_attr = dense<2> : tensor<3xi32>} : () -> ()

// CHECK: Operands = [{foo.vect_attr_1_count = dense<1> : vector<3xindex>,
// CHECK-NEXT: foo.vect_attr_1_end = dense<0> : vector<3xindex>,
// CHECK-NEXT: foo.vect_attr_1_start = dense<0> : vector<3xindex>,
// CHECK-NEXT: foo.vect_attr_2_count = dense<1> : vector<3xindex>,
// CHECK-NEXT: foo.vect_attr_2_end = dense<0> : vector<3xindex>,
// CHECK-NEXT: foo.vect_attr_2_start = dense<0> : vector<3xindex>},
// CHECK-NEXT: {foo.vect_attr_1_count = dense<1> : vector<3xindex>,
// CHECK-NEXT: foo.vect_attr_1_end = dense<0> : vector<3xindex>,
// CHECK-NEXT: foo.vect_attr_1_start = dense<0> : vector<3xindex>,
// CHECK-NEXT: foo.vect_attr_2_count = dense<1> : vector<3xindex>,
// CHECK-NEXT: foo.vect_attr_2_end = dense<0> : vector<3xindex>,
// CHECK-NEXT: foo.vect_attr_2_start = dense<0> : vector<3xindex>}],
"test.op"() {foo.dense_attr = dense<1> : tensor<3xi32>, foo.second_attr = dense<2> : tensor<3xi32>, Operands = [{foo.vect_attr_1_start = dense<0> : vector<3xindex>, foo.vect_attr_1_end = dense<0> : vector<3xindex>, foo.vect_attr_1_count = dense<1> : vector<3xindex>, foo.vect_attr_2_start = dense<0> : vector<3xindex>, foo.vect_attr_2_end = dense<0> : vector<3xindex>, foo.vect_attr_2_count = dense<1> : vector<3xindex>}, {foo.vect_attr_1_start = dense<0> : vector<3xindex>, foo.vect_attr_1_end = dense<0> : vector<3xindex>, foo.vect_attr_1_count = dense<1> : vector<3xindex>, foo.vect_attr_2_start = dense<0> : vector<3xindex>, foo.vect_attr_2_end = dense<0> : vector<3xindex>, foo.vect_attr_2_count = dense<1> : vector<3xindex>}]} : () -> ()