From 69b78201b79e96609d714ab1fcbc605fa0196d71 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Wed, 7 Aug 2024 16:31:02 +0200 Subject: [PATCH 1/4] Add xten_nn.kernel op --- include/xten/Dialect/XTenNN/IR/XTenNNOps.td | 16 +++++++ lib/Dialect/XTenNN/IR/XTenNNOps.cpp | 52 ++++++++++++++++++++- test/Dialect/XTenNN/ops.mlir | 16 +++++++ test/Dialect/XTenNN/ops_invalid.mlir | 15 ++++++ tools/aten-opt/aten-opt.cpp | 4 +- 5 files changed, 99 insertions(+), 4 deletions(-) diff --git a/include/xten/Dialect/XTenNN/IR/XTenNNOps.td b/include/xten/Dialect/XTenNN/IR/XTenNNOps.td index 840e0e11..944d90ae 100644 --- a/include/xten/Dialect/XTenNN/IR/XTenNNOps.td +++ b/include/xten/Dialect/XTenNN/IR/XTenNNOps.td @@ -257,6 +257,22 @@ def XTenNN_LoadExternalConstOp: XTenNN_Op<"load_external_const", [ let assemblyFormat = [{ attr-dict `->` type($output) }]; } +def XTenNN_KernelOp : XTenNN_Op<"kernel", []> { + let summary = "Defines an opaque kernel"; + let description = [{ + The `xten_nn.kernel` operation defines an opaque computation with a name. + ``` + }]; + + let arguments = (ins + Variadic:$arguments, + StrAttr:$name + ); + let results = (outs Variadic:$results); + + let assemblyFormat = [{ $name custom(type($arguments), $arguments) attr-dict `->` type($results) }]; +} + //===----------------------------------------------------------------------===// // Ops that are missing from the TOSA standard //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp index 83afc80a..b8e375dd 100644 --- a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp +++ b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp @@ -172,6 +172,56 @@ static void printEnclaveOp(OpAsmPrinter &p, EnclaveOp op) { }; } + +//===----------------------------------------------------------------------===// +// KernelOp +//===----------------------------------------------------------------------===// + +/// Parses a list of ssa values with their types. +/// `(` (ssa-id `:` type (`,` ssa-id `:` type)*)? `)` +/// +/// This method is used by the tablegen assembly format for the kernel op. +static ParseResult parseKernelArgumentList( + OpAsmParser &parser, SmallVectorImpl &types, + SmallVectorImpl &arguments) { + if (parser.parseLParen()) + return failure(); + + if (succeeded(parser.parseOptionalRParen())) + return success(); + + while(true) { + OpAsmParser::UnresolvedOperand argument; + Type type; + if (parser.parseOperand(argument) || + parser.parseColon() || + parser.parseType(type)) + return failure(); + + types.push_back(type); + arguments.push_back(argument); + + if (succeeded(parser.parseOptionalRParen())) + return success(); + + if (parser.parseComma()) + return failure(); + } +} + +/// Prints a list of ssa values with their types. +/// `(` (ssa-id `:` type (`,` ssa-id `:` type)*)? `)` +/// +/// This method is used by the tablegen assembly format for the kernel op. +static void printKernelArgumentList(OpAsmPrinter &printer, Operation *op, + TypeRange types, + OperandRange arguments) { + printer << "("; + llvm::interleaveComma(llvm::zip(arguments, types), printer, + [&](const auto &a) { printer << get<0>(a) << " : " << get<1>(a); }); + printer << ")"; +} + #define GET_OP_CLASSES #include "xten/Dialect/XTenNN/IR/XTenNNOps.cpp.inc" @@ -435,4 +485,4 @@ LogicalResult amd::xten_nn::ResizeOp::verify() { } return success(); -} \ No newline at end of file +} diff --git a/test/Dialect/XTenNN/ops.mlir b/test/Dialect/XTenNN/ops.mlir index 884a1077..a824fb1f 100644 --- a/test/Dialect/XTenNN/ops.mlir +++ b/test/Dialect/XTenNN/ops.mlir @@ -18,3 +18,19 @@ func.func @subgraph_empty(%arg0: tensor<2xi64>) -> tensor<2xi64> { %sum = xten_nn.subgraph (%arg0 : tensor<2xi64>) -> tensor<2xi64> return %sum : tensor<2xi64> } + + +// ----- + +// CHECK-LABEL: kernel +func.func @kernel(%arg0: tensor<2xi64>, %arg1 : tensor<4xi64>) { + %a = xten_nn.kernel "myKernel" () -> tensor<2xi64> + // CHECK: xten_nn.kernel "myKernel" () -> tensor<2xi64> + %b = xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>) -> tensor<2xi64> + // CHECK: xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>) -> tensor<2xi64> + %c = xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>) {attr = 4 : i32} -> tensor<2xi64> + // CHECK: xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>) {attr = 4 : i32} -> tensor<2xi64> + %d:2 = xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>, %arg1 : tensor<4xi64>) -> tensor<2xi64>, tensor<1xi64> + // CHECK: xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>, %arg1 : tensor<4xi64>) -> tensor<2xi64>, tensor<1xi64> + return +} diff --git a/test/Dialect/XTenNN/ops_invalid.mlir b/test/Dialect/XTenNN/ops_invalid.mlir index 70b1dd2f..1a4fe949 100644 --- a/test/Dialect/XTenNN/ops_invalid.mlir +++ b/test/Dialect/XTenNN/ops_invalid.mlir @@ -24,3 +24,18 @@ func.func @mish_int(%arg0: tensor<1x10xi4>) -> tensor<1x10xi4> { %0 = xten_nn.mish %arg0 : (tensor<1x10xi4>) -> tensor<1x10xi4> return %0 : tensor<1x10xi4> } + +// ----- + +func.func @kernel_missing_parenthesis() { + // expected-error@+1 {{expected '('}} + %a = xten_nn.kernel "myKernel" -> tensor<2xi64> +} + +// ----- + +func.func @kernel_missing_name() { + // expected-error@+1 {{custom op 'xten_nn.kernel' invalid kind of attribute specified}} + %b = xten_nn.kernel () -> tensor<2xi64> + return +} diff --git a/tools/aten-opt/aten-opt.cpp b/tools/aten-opt/aten-opt.cpp index e9e45e27..b0b91048 100644 --- a/tools/aten-opt/aten-opt.cpp +++ b/tools/aten-opt/aten-opt.cpp @@ -52,9 +52,7 @@ int main(int argc, char **argv) { DialectRegistry registry; registerAllDialects(registry); mlir::registerAllDialects(registry); - registry.insert(); + registry.insert(); return failed(MlirOptMain(argc, argv, "MLIR modular optimizer driver\n", registry)); From 1e1ab449fdfba605582dd9d1732c7f27c36798af Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Wed, 7 Aug 2024 16:34:12 +0200 Subject: [PATCH 2/4] More parsing tests --- test/Dialect/XTenNN/ops_invalid.mlir | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/test/Dialect/XTenNN/ops_invalid.mlir b/test/Dialect/XTenNN/ops_invalid.mlir index 1a4fe949..3599fb40 100644 --- a/test/Dialect/XTenNN/ops_invalid.mlir +++ b/test/Dialect/XTenNN/ops_invalid.mlir @@ -34,6 +34,20 @@ func.func @kernel_missing_parenthesis() { // ----- +func.func @kernel_missing_type(%arg0: i8, %arg1: i8) { + // expected-error@+1 {{expected ':'}} + %a = xten_nn.kernel "myKernel" (%arg0, %arg1) -> tensor<2xi64> +} + +// ----- + +func.func @kernel_trailing_comma(%arg0: i8) { + // expected-error@+1 {{expected SSA operand}} + %a = xten_nn.kernel "myKernel" (%arg0 :i8, ) -> tensor<2xi64> +} + +// ----- + func.func @kernel_missing_name() { // expected-error@+1 {{custom op 'xten_nn.kernel' invalid kind of attribute specified}} %b = xten_nn.kernel () -> tensor<2xi64> From 24719b788e8e823d5c8fe4b2d9f44787017e0cc2 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Wed, 7 Aug 2024 16:39:16 +0200 Subject: [PATCH 3/4] Fix description --- include/xten/Dialect/XTenNN/IR/XTenNNOps.td | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/include/xten/Dialect/XTenNN/IR/XTenNNOps.td b/include/xten/Dialect/XTenNN/IR/XTenNNOps.td index 944d90ae..fca1f3fb 100644 --- a/include/xten/Dialect/XTenNN/IR/XTenNNOps.td +++ b/include/xten/Dialect/XTenNN/IR/XTenNNOps.td @@ -258,10 +258,9 @@ def XTenNN_LoadExternalConstOp: XTenNN_Op<"load_external_const", [ } def XTenNN_KernelOp : XTenNN_Op<"kernel", []> { - let summary = "Defines an opaque kernel"; + let summary = "An opaque kernel"; let description = [{ The `xten_nn.kernel` operation defines an opaque computation with a name. - ``` }]; let arguments = (ins From ca2c07283159c857c2e770ebf7b24ea75bd4fe19 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Fri, 9 Aug 2024 11:22:35 +0200 Subject: [PATCH 4/4] REview comments --- include/xten/Dialect/XTenNN/IR/XTenNNOps.td | 7 +- lib/Dialect/XTenNN/IR/XTenNNOps.cpp | 89 +++++++++++++-------- test/Dialect/XTenNN/ops_invalid.mlir | 13 ++- tools/aten-opt/aten-opt.cpp | 4 +- 4 files changed, 76 insertions(+), 37 deletions(-) diff --git a/include/xten/Dialect/XTenNN/IR/XTenNNOps.td b/include/xten/Dialect/XTenNN/IR/XTenNNOps.td index fca1f3fb..db7c72f2 100644 --- a/include/xten/Dialect/XTenNN/IR/XTenNNOps.td +++ b/include/xten/Dialect/XTenNN/IR/XTenNNOps.td @@ -261,6 +261,11 @@ def XTenNN_KernelOp : XTenNN_Op<"kernel", []> { let summary = "An opaque kernel"; let description = [{ The `xten_nn.kernel` operation defines an opaque computation with a name. + Example: + ``` + %c = xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>) {attr = 4 : i32} -> tensor<2xi64> + %d:2 = xten_nn.kernel "frob" (%arg0 : tensor<2xi64>, %arg1 : tensor<4xi64>) -> tensor<2xi64>, tensor<1xi64> + ``` }]; let arguments = (ins @@ -269,7 +274,7 @@ def XTenNN_KernelOp : XTenNN_Op<"kernel", []> { ); let results = (outs Variadic:$results); - let assemblyFormat = [{ $name custom(type($arguments), $arguments) attr-dict `->` type($results) }]; + let hasCustomAssemblyFormat = 1; } //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp index b8e375dd..f16ec614 100644 --- a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp +++ b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp @@ -181,45 +181,70 @@ static void printEnclaveOp(OpAsmPrinter &p, EnclaveOp op) { /// `(` (ssa-id `:` type (`,` ssa-id `:` type)*)? `)` /// /// This method is used by the tablegen assembly format for the kernel op. -static ParseResult parseKernelArgumentList( - OpAsmParser &parser, SmallVectorImpl &types, - SmallVectorImpl &arguments) { - if (parser.parseLParen()) - return failure(); - - if (succeeded(parser.parseOptionalRParen())) - return success(); - - while(true) { - OpAsmParser::UnresolvedOperand argument; - Type type; - if (parser.parseOperand(argument) || - parser.parseColon() || - parser.parseType(type)) - return failure(); - - types.push_back(type); - arguments.push_back(argument); - - if (succeeded(parser.parseOptionalRParen())) - return success(); +static ParseResult parseKernelArgumentList(OpAsmParser &p, + SmallVectorImpl &operands) { + return p.parseCommaSeparatedList( + OpAsmParser::Delimiter::Paren, + [&]() -> ParseResult { + OpAsmParser::UnresolvedOperand operand; + Type type; + if (p.parseOperand(operand)) + return failure(); + if (p.parseOptionalColon()) + return p.emitError(p.getCurrentLocation(), + "expected ':`, (argument format is val : type)"); - if (parser.parseComma()) - return failure(); - } + if (p.parseType(type) || p.resolveOperand(operand, type, operands)) + return failure(); + return success(); + }, + " in argument list"); } /// Prints a list of ssa values with their types. /// `(` (ssa-id `:` type (`,` ssa-id `:` type)*)? `)` /// /// This method is used by the tablegen assembly format for the kernel op. -static void printKernelArgumentList(OpAsmPrinter &printer, Operation *op, - TypeRange types, - OperandRange arguments) { - printer << "("; - llvm::interleaveComma(llvm::zip(arguments, types), printer, - [&](const auto &a) { printer << get<0>(a) << " : " << get<1>(a); }); - printer << ")"; +static void printKernelArgumentList(OpAsmPrinter &p, TypeRange types, + OperandRange arguments) { + p << "("; + llvm::interleaveComma(llvm::zip(arguments, types), p, [&](const auto &a) { + p << get<0>(a) << " : " << get<1>(a); + }); + p << ")"; +} + +// Parse +// $name custom(type($arguments), $arguments) attr-dict +// `->` type($results) +ParseResult KernelOp::parse(OpAsmParser &p, OperationState &result) { + StringAttr name; + if (p.parseAttribute(name, "name", result.attributes)) + return failure(); + + if (parseKernelArgumentList(p, result.operands) || + p.parseOptionalAttrDict(result.attributes) || p.parseArrow() || + p.parseTypeList(result.types)) + return failure(); + + return success(); +} + +// Parse +// $name custom(type($arguments), $arguments) attr-dict +// `->` type($results) +void KernelOp::print(OpAsmPrinter &p) { + p << ' '; + p << getNameAttr(); + p << ' '; + printKernelArgumentList(p, getOperandTypes(), getOperands()); + p << ' '; + SmallVector elidedAttrs = {"name"}; + p.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs); + if (getOperation()->getAttrs().size() > elidedAttrs.size()) + p << ' '; + p << "-> "; + p << getResultTypes(); } #define GET_OP_CLASSES diff --git a/test/Dialect/XTenNN/ops_invalid.mlir b/test/Dialect/XTenNN/ops_invalid.mlir index 3599fb40..78d9bab4 100644 --- a/test/Dialect/XTenNN/ops_invalid.mlir +++ b/test/Dialect/XTenNN/ops_invalid.mlir @@ -34,13 +34,20 @@ func.func @kernel_missing_parenthesis() { // ----- -func.func @kernel_missing_type(%arg0: i8, %arg1: i8) { - // expected-error@+1 {{expected ':'}} +func.func @kernel_missing_colon(%arg0: i8, %arg1: i8) { + // expected-error@+1 {{expected ':`, (argument format is val : type)}} %a = xten_nn.kernel "myKernel" (%arg0, %arg1) -> tensor<2xi64> } // ----- +func.func @kernel_missing_type(%arg0: i8, %arg1: i8) { + // expected-error@+1 {{expected non-function type}} + %a = xten_nn.kernel "myKernel" (%arg0 : ) -> tensor<2xi64> +} + +// ----- + func.func @kernel_trailing_comma(%arg0: i8) { // expected-error@+1 {{expected SSA operand}} %a = xten_nn.kernel "myKernel" (%arg0 :i8, ) -> tensor<2xi64> @@ -49,7 +56,7 @@ func.func @kernel_trailing_comma(%arg0: i8) { // ----- func.func @kernel_missing_name() { - // expected-error@+1 {{custom op 'xten_nn.kernel' invalid kind of attribute specified}} + // expected-error@+1 {{'xten_nn.kernel' invalid kind of attribute specified}} %b = xten_nn.kernel () -> tensor<2xi64> return } diff --git a/tools/aten-opt/aten-opt.cpp b/tools/aten-opt/aten-opt.cpp index b0b91048..e9e45e27 100644 --- a/tools/aten-opt/aten-opt.cpp +++ b/tools/aten-opt/aten-opt.cpp @@ -52,7 +52,9 @@ int main(int argc, char **argv) { DialectRegistry registry; registerAllDialects(registry); mlir::registerAllDialects(registry); - registry.insert(); + registry.insert(); return failed(MlirOptMain(argc, argv, "MLIR modular optimizer driver\n", registry));