diff --git a/include/xten/Dialect/XTenNN/IR/XTenNNOps.td b/include/xten/Dialect/XTenNN/IR/XTenNNOps.td index fca1f3fb..db7c72f2 100644 --- a/include/xten/Dialect/XTenNN/IR/XTenNNOps.td +++ b/include/xten/Dialect/XTenNN/IR/XTenNNOps.td @@ -261,6 +261,11 @@ def XTenNN_KernelOp : XTenNN_Op<"kernel", []> { let summary = "An opaque kernel"; let description = [{ The `xten_nn.kernel` operation defines an opaque computation with a name. + Example: + ``` + %c = xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>) {attr = 4 : i32} -> tensor<2xi64> + %d:2 = xten_nn.kernel "frob" (%arg0 : tensor<2xi64>, %arg1 : tensor<4xi64>) -> tensor<2xi64>, tensor<1xi64> + ``` }]; let arguments = (ins @@ -269,7 +274,7 @@ def XTenNN_KernelOp : XTenNN_Op<"kernel", []> { ); let results = (outs Variadic:$results); - let assemblyFormat = [{ $name custom(type($arguments), $arguments) attr-dict `->` type($results) }]; + let hasCustomAssemblyFormat = 1; } //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp index b8e375dd..f16ec614 100644 --- a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp +++ b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp @@ -181,45 +181,70 @@ static void printEnclaveOp(OpAsmPrinter &p, EnclaveOp op) { /// `(` (ssa-id `:` type (`,` ssa-id `:` type)*)? `)` /// /// This method is used by the tablegen assembly format for the kernel op. -static ParseResult parseKernelArgumentList( - OpAsmParser &parser, SmallVectorImpl &types, - SmallVectorImpl &arguments) { - if (parser.parseLParen()) - return failure(); - - if (succeeded(parser.parseOptionalRParen())) - return success(); - - while(true) { - OpAsmParser::UnresolvedOperand argument; - Type type; - if (parser.parseOperand(argument) || - parser.parseColon() || - parser.parseType(type)) - return failure(); - - types.push_back(type); - arguments.push_back(argument); - - if (succeeded(parser.parseOptionalRParen())) - return success(); +static ParseResult parseKernelArgumentList(OpAsmParser &p, + SmallVectorImpl &operands) { + return p.parseCommaSeparatedList( + OpAsmParser::Delimiter::Paren, + [&]() -> ParseResult { + OpAsmParser::UnresolvedOperand operand; + Type type; + if (p.parseOperand(operand)) + return failure(); + if (p.parseOptionalColon()) + return p.emitError(p.getCurrentLocation(), + "expected ':`, (argument format is val : type)"); - if (parser.parseComma()) - return failure(); - } + if (p.parseType(type) || p.resolveOperand(operand, type, operands)) + return failure(); + return success(); + }, + " in argument list"); } /// Prints a list of ssa values with their types. /// `(` (ssa-id `:` type (`,` ssa-id `:` type)*)? `)` /// /// This method is used by the tablegen assembly format for the kernel op. -static void printKernelArgumentList(OpAsmPrinter &printer, Operation *op, - TypeRange types, - OperandRange arguments) { - printer << "("; - llvm::interleaveComma(llvm::zip(arguments, types), printer, - [&](const auto &a) { printer << get<0>(a) << " : " << get<1>(a); }); - printer << ")"; +static void printKernelArgumentList(OpAsmPrinter &p, TypeRange types, + OperandRange arguments) { + p << "("; + llvm::interleaveComma(llvm::zip(arguments, types), p, [&](const auto &a) { + p << get<0>(a) << " : " << get<1>(a); + }); + p << ")"; +} + +// Parse +// $name custom(type($arguments), $arguments) attr-dict +// `->` type($results) +ParseResult KernelOp::parse(OpAsmParser &p, OperationState &result) { + StringAttr name; + if (p.parseAttribute(name, "name", result.attributes)) + return failure(); + + if (parseKernelArgumentList(p, result.operands) || + p.parseOptionalAttrDict(result.attributes) || p.parseArrow() || + p.parseTypeList(result.types)) + return failure(); + + return success(); +} + +// Parse +// $name custom(type($arguments), $arguments) attr-dict +// `->` type($results) +void KernelOp::print(OpAsmPrinter &p) { + p << ' '; + p << getNameAttr(); + p << ' '; + printKernelArgumentList(p, getOperandTypes(), getOperands()); + p << ' '; + SmallVector elidedAttrs = {"name"}; + p.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs); + if (getOperation()->getAttrs().size() > elidedAttrs.size()) + p << ' '; + p << "-> "; + p << getResultTypes(); } #define GET_OP_CLASSES diff --git a/test/Dialect/XTenNN/ops_invalid.mlir b/test/Dialect/XTenNN/ops_invalid.mlir index 3599fb40..78d9bab4 100644 --- a/test/Dialect/XTenNN/ops_invalid.mlir +++ b/test/Dialect/XTenNN/ops_invalid.mlir @@ -34,13 +34,20 @@ func.func @kernel_missing_parenthesis() { // ----- -func.func @kernel_missing_type(%arg0: i8, %arg1: i8) { - // expected-error@+1 {{expected ':'}} +func.func @kernel_missing_colon(%arg0: i8, %arg1: i8) { + // expected-error@+1 {{expected ':`, (argument format is val : type)}} %a = xten_nn.kernel "myKernel" (%arg0, %arg1) -> tensor<2xi64> } // ----- +func.func @kernel_missing_type(%arg0: i8, %arg1: i8) { + // expected-error@+1 {{expected non-function type}} + %a = xten_nn.kernel "myKernel" (%arg0 : ) -> tensor<2xi64> +} + +// ----- + func.func @kernel_trailing_comma(%arg0: i8) { // expected-error@+1 {{expected SSA operand}} %a = xten_nn.kernel "myKernel" (%arg0 :i8, ) -> tensor<2xi64> @@ -49,7 +56,7 @@ func.func @kernel_trailing_comma(%arg0: i8) { // ----- func.func @kernel_missing_name() { - // expected-error@+1 {{custom op 'xten_nn.kernel' invalid kind of attribute specified}} + // expected-error@+1 {{'xten_nn.kernel' invalid kind of attribute specified}} %b = xten_nn.kernel () -> tensor<2xi64> return } diff --git a/tools/aten-opt/aten-opt.cpp b/tools/aten-opt/aten-opt.cpp index b0b91048..e9e45e27 100644 --- a/tools/aten-opt/aten-opt.cpp +++ b/tools/aten-opt/aten-opt.cpp @@ -52,7 +52,9 @@ int main(int argc, char **argv) { DialectRegistry registry; registerAllDialects(registry); mlir::registerAllDialects(registry); - registry.insert(); + registry.insert(); return failed(MlirOptMain(argc, argv, "MLIR modular optimizer driver\n", registry));