From 27825d89641a603abc8cf4f7dc6289799a3e59f8 Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Tue, 1 Oct 2024 13:50:08 +0100 Subject: [PATCH 1/9] Add kernel instantiation arguments --- include/xten/Dialect/XTenNN/IR/XTenNNOps.td | 4 ++- lib/Dialect/XTenNN/IR/XTenNNOps.cpp | 35 ++++++++++++++++++--- test/Dialect/XTenNN/ops.mlir | 2 ++ 3 files changed, 36 insertions(+), 5 deletions(-) diff --git a/include/xten/Dialect/XTenNN/IR/XTenNNOps.td b/include/xten/Dialect/XTenNN/IR/XTenNNOps.td index 8c61b8be..5af78e41 100644 --- a/include/xten/Dialect/XTenNN/IR/XTenNNOps.td +++ b/include/xten/Dialect/XTenNN/IR/XTenNNOps.td @@ -265,12 +265,14 @@ def XTenNN_KernelOp : XTenNN_Op<"kernel", []> { ``` %c = xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>) {attr = 4 : i32} -> tensor<2xi64> %d:2 = xten_nn.kernel "frob" (%arg0 : tensor<2xi64>, %arg1 : tensor<4xi64>) -> tensor<2xi64>, tensor<1xi64> + %e = xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args {N = 42 : i32} {attr = 4 : i32} -> tensor<2xi64> ``` }]; let arguments = (ins Variadic:$arguments, - StrAttr:$name + StrAttr:$name, + OptionalAttr:$instantiation_args ); let results = (outs Variadic:$results); diff --git a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp index cc4c67af..b061c918 100644 --- a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp +++ b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp @@ -10,16 +10,21 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/SymbolTable.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/FunctionImplementation.h" #include "mlir/Support/LogicalResult.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" @@ -222,8 +227,18 @@ ParseResult KernelOp::parse(OpAsmParser &p, OperationState &result) { if (p.parseAttribute(name, "name", result.attributes)) return failure(); - if (parseKernelArgumentList(p, result.operands) || - p.parseOptionalAttrDict(result.attributes)) + if (parseKernelArgumentList(p, result.operands)) + return failure(); + + if(succeeded(p.parseOptionalKeyword("instantiation_args"))) { + NamedAttrList instantiationArgs; + if(p.parseOptionalAttrDict(instantiationArgs)) + return failure(); + DictionaryAttr dictAttr = DictionaryAttr::get(p.getContext(), instantiationArgs); + result.addAttribute("instantiation_args", dictAttr); + } + + if(p.parseOptionalAttrDict(result.attributes)) return failure(); // If the op has no results, the `-> type($results)` is absent. @@ -245,9 +260,21 @@ void KernelOp::print(OpAsmPrinter &p) { p << ' '; printKernelArgumentList(p, getOperandTypes(), getOperands()); p << ' '; - SmallVector elidedAttrs = {"name"}; + auto instantiationArgs = getInstantiationArgs(); + if(instantiationArgs != std::nullopt && !(instantiationArgs->empty())) { + p << "instantiation_args "; + p.printOptionalAttrDict(instantiationArgs->getValue()); + p << ' '; + } + SmallVector elidedAttrs = {"name", "instantiation_args"}; p.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs); - if (getOperation()->getAttrs().size() > elidedAttrs.size()) + if (llvm::any_of( + getOperation()->getAttrs(), [&elidedAttrs](NamedAttribute a) { + auto name = a.getName(); + return llvm::any_of(elidedAttrs, [&name](StringRef elidedName) { + return name == elidedName; + }); + })) p << ' '; if (getNumResults()) { p << "-> "; diff --git a/test/Dialect/XTenNN/ops.mlir b/test/Dialect/XTenNN/ops.mlir index ce152adf..94451555 100644 --- a/test/Dialect/XTenNN/ops.mlir +++ b/test/Dialect/XTenNN/ops.mlir @@ -34,6 +34,8 @@ func.func @kernel(%arg0: tensor<2xi64>, %arg1 : tensor<4xi64>) { // CHECK: xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>) {attr = 4 : i32} -> tensor<2xi64> %d:2 = xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>, %arg1 : tensor<4xi64>) -> tensor<2xi64>, tensor<1xi64> // CHECK: xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>, %arg1 : tensor<4xi64>) -> tensor<2xi64>, tensor<1xi64> + %e = xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args {N = 42 : i32, idx = 56 : index} {attr = 4 : i32} -> tensor<2xi64> + // CHECK: xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args {N = 42 : i32, idx = 56 : index} {attr = 4 : i32} -> tensor<2xi64> return } From 2d6bab29de970b7ae7d9a7bb8765140fa9e5ba1f Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Wed, 2 Oct 2024 12:49:58 +0100 Subject: [PATCH 2/9] Add instantiation args and names as two separate lists --- include/xten/Dialect/XTenNN/IR/XTenNNOps.td | 4 +- lib/Dialect/XTenNN/IR/XTenNNOps.cpp | 111 +++++++++++++++++--- test/Dialect/XTenNN/ops.mlir | 2 + test/Dialect/XTenNN/ops_invalid.mlir | 16 +++ 4 files changed, 116 insertions(+), 17 deletions(-) diff --git a/include/xten/Dialect/XTenNN/IR/XTenNNOps.td b/include/xten/Dialect/XTenNN/IR/XTenNNOps.td index 5af78e41..55220c5d 100644 --- a/include/xten/Dialect/XTenNN/IR/XTenNNOps.td +++ b/include/xten/Dialect/XTenNN/IR/XTenNNOps.td @@ -272,11 +272,13 @@ def XTenNN_KernelOp : XTenNN_Op<"kernel", []> { let arguments = (ins Variadic:$arguments, StrAttr:$name, - OptionalAttr:$instantiation_args + OptionalAttr:$instantiation_args, + OptionalAttr:$instantiation_arg_names ); let results = (outs Variadic:$results); let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp index b061c918..466b627e 100644 --- a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp +++ b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp @@ -10,13 +10,14 @@ // //===----------------------------------------------------------------------===// -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" +#include "xten/Dialect/XTenNN/IR/XTenNNOps.h" + #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/SymbolTable.h" @@ -24,14 +25,16 @@ #include "mlir/Interfaces/FunctionImplementation.h" #include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" #include "xten/Dialect/XTenNN/IR/XTenNN.h" #include "xten/Dialect/XTenNN/IR/XTenNNBase.h" -#include "xten/Dialect/XTenNN/IR/XTenNNOps.h" #include "xten/Dialect/XTenNN/Interfaces/EnclaveOpInterfaces.h" + #include using namespace mlir; @@ -219,6 +222,61 @@ static void printKernelArgumentList(OpAsmPrinter &p, TypeRange types, p << ")"; } +static ParseResult parseKernelInstantiationArgs(OpAsmParser &p, + SmallVector &values, + SmallVector &names) { + if (failed(p.parseLBrace())) + return failure(); + + if (failed(p.parseCommaSeparatedList([&p, &names, &values]() { + std::string name; + bool hasName = false; + if (succeeded(p.parseOptionalKeywordOrString(&name))) { + hasName = true; + if (failed(p.parseEqual())) + return failure(); + } + Attribute attr; + auto res = p.parseOptionalAttribute(attr); + if (res.has_value() && succeeded(*res)) { + if (hasName) + names.push_back(StringAttr::get(p.getContext(), name)); + values.push_back(attr); + } + if (res.has_value() && failed(*res)) + return failure(); + + return success(); + }))) { + return failure(); + } + + if (failed(p.parseRBrace())) + return failure(); + + return success(); +} + +static void +printKernelInstantiationArgs(OpAsmPrinter &p, + ArrayRef instantiationArgs, + ArrayRef instantiationArgNames) { + if (!instantiationArgs.empty()) { + p << "instantiation_args {"; + auto zipped = llvm::zip_longest(instantiationArgNames, instantiationArgs); + for (auto iter = zipped.begin(); iter != zipped.end(); ++iter) { + if (iter != zipped.begin()) + p << ", "; + auto [name, value] = *iter; + if (name) + p << *name << " = "; + if (value) + p.printAttribute(*value); + } + p << '}'; + } +} + // Parse // $name custom(type($arguments), $arguments) attr-dict // `->` type($results) @@ -230,15 +288,20 @@ ParseResult KernelOp::parse(OpAsmParser &p, OperationState &result) { if (parseKernelArgumentList(p, result.operands)) return failure(); - if(succeeded(p.parseOptionalKeyword("instantiation_args"))) { - NamedAttrList instantiationArgs; - if(p.parseOptionalAttrDict(instantiationArgs)) + if (succeeded(p.parseOptionalKeyword("instantiation_args"))) { + SmallVector values; + SmallVector names; + if (failed(parseKernelInstantiationArgs(p, values, names))) return failure(); - DictionaryAttr dictAttr = DictionaryAttr::get(p.getContext(), instantiationArgs); - result.addAttribute("instantiation_args", dictAttr); + result.addAttribute("instantiation_args", + ArrayAttr::get(p.getContext(), values)); + if (!names.empty()) { + result.addAttribute("instantiation_arg_names", + ArrayAttr::get(p.getContext(), names)); + } } - if(p.parseOptionalAttrDict(result.attributes)) + if (p.parseOptionalAttrDict(result.attributes)) return failure(); // If the op has no results, the `-> type($results)` is absent. @@ -261,12 +324,17 @@ void KernelOp::print(OpAsmPrinter &p) { printKernelArgumentList(p, getOperandTypes(), getOperands()); p << ' '; auto instantiationArgs = getInstantiationArgs(); - if(instantiationArgs != std::nullopt && !(instantiationArgs->empty())) { - p << "instantiation_args "; - p.printOptionalAttrDict(instantiationArgs->getValue()); + auto instantiationArgNames = getInstantiationArgNames(); + if (instantiationArgs != std::nullopt) { + printKernelInstantiationArgs(p, instantiationArgs->getValue(), + (instantiationArgNames == std::nullopt) + ? ArrayRef() + : instantiationArgNames->getValue()); p << ' '; } - SmallVector elidedAttrs = {"name", "instantiation_args"}; + + SmallVector elidedAttrs = {"name", "instantiation_args", + "instantiation_arg_names"}; p.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs); if (llvm::any_of( getOperation()->getAttrs(), [&elidedAttrs](NamedAttribute a) { @@ -282,6 +350,19 @@ void KernelOp::print(OpAsmPrinter &p) { } } +LogicalResult KernelOp::verify() { + if (getInstantiationArgNames().has_value()) { + if (!getInstantiationArgs().has_value()) + return emitOpError("cannot have instantiation arg names without instantiation args"); + if (!(getInstantiationArgNamesAttr().empty() || + getInstantiationArgNamesAttr().size() == + getInstantiationArgsAttr().size())) + return emitOpError("instantiation arg names must be either empty or as long as instantiation args"); + } + + return success(); +} + #define GET_OP_CLASSES #include "xten/Dialect/XTenNN/IR/XTenNNOps.cpp.inc" @@ -293,9 +374,7 @@ ParseResult SubgraphOp::parse(OpAsmParser &p, OperationState &result) { return parseEnclaveOp(p, result); } -void SubgraphOp::print(OpAsmPrinter &p) { - printEnclaveOp(p, *this); -} +void SubgraphOp::print(OpAsmPrinter &p) { printEnclaveOp(p, *this); } LogicalResult SubgraphOp::verify() { Block *optBody = this->getOptionalEnclaveBody(); diff --git a/test/Dialect/XTenNN/ops.mlir b/test/Dialect/XTenNN/ops.mlir index 94451555..7b581e75 100644 --- a/test/Dialect/XTenNN/ops.mlir +++ b/test/Dialect/XTenNN/ops.mlir @@ -36,6 +36,8 @@ func.func @kernel(%arg0: tensor<2xi64>, %arg1 : tensor<4xi64>) { // CHECK: xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>, %arg1 : tensor<4xi64>) -> tensor<2xi64>, tensor<1xi64> %e = xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args {N = 42 : i32, idx = 56 : index} {attr = 4 : i32} -> tensor<2xi64> // CHECK: xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args {N = 42 : i32, idx = 56 : index} {attr = 4 : i32} -> tensor<2xi64> + %f = xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args {42 : i32, 56 : index} {attr = 4 : i32} -> tensor<2xi64> + // CHECK: xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args {42 : i32, 56 : index} {attr = 4 : i32} -> tensor<2xi64> return } diff --git a/test/Dialect/XTenNN/ops_invalid.mlir b/test/Dialect/XTenNN/ops_invalid.mlir index ffd91190..3fdd6303 100644 --- a/test/Dialect/XTenNN/ops_invalid.mlir +++ b/test/Dialect/XTenNN/ops_invalid.mlir @@ -70,6 +70,22 @@ func.func @kernel_missing_result(%arg0: i8, %arg1: i8) { // ----- +func.func @kernel_instantiation_list_different_length(%arg0: i8, %arg1: i8) { + // expected-error@+1 {{instantiation arg names must be either empty or as long as instantiation args}} + %x = xten_nn.kernel "myKernel" () instantiation_args {N = 42 : i32, 51 : index} -> i32 + return +} + +// ----- + +func.func @kernel_instantiation_list_non_empty(%arg0: i8, %arg1: i8) { + // expected-error@+1 {{cannot have instantiation arg names without instantiation args}} + %x = xten_nn.kernel "myKernel" () {instantiation_arg_names = ["N"]} -> i32 + return +} + +// ----- + func.func @topk_wrong_output_shape(%arg0: tensor<10x10xf32>) { %k = arith.constant 7 : i64 // expected-error@+2 {{failed to infer returned types}} From da65e6a5420e5805589a4b828e35a5ba35504a94 Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Wed, 2 Oct 2024 12:57:36 +0100 Subject: [PATCH 3/9] Formatting, comments --- lib/Dialect/XTenNN/IR/XTenNNOps.cpp | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp index 466b627e..9020fe11 100644 --- a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp +++ b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp @@ -222,9 +222,11 @@ static void printKernelArgumentList(OpAsmPrinter &p, TypeRange types, p << ")"; } +// Parse +// {((name = )?value, )*((name = )?value)} static ParseResult parseKernelInstantiationArgs(OpAsmParser &p, - SmallVector &values, - SmallVector &names) { + SmallVector &values, + SmallVector &names) { if (failed(p.parseLBrace())) return failure(); @@ -257,6 +259,8 @@ static ParseResult parseKernelInstantiationArgs(OpAsmParser &p, return success(); } +// Print +// instantiation_args {((name = )?value, )*((name = )?value)} static void printKernelInstantiationArgs(OpAsmPrinter &p, ArrayRef instantiationArgs, @@ -278,7 +282,8 @@ printKernelInstantiationArgs(OpAsmPrinter &p, } // Parse -// $name custom(type($arguments), $arguments) attr-dict +// $name custom(type($arguments), $arguments) +// (instantiation_args custom)? attr-dict // `->` type($results) ParseResult KernelOp::parse(OpAsmParser &p, OperationState &result) { StringAttr name; @@ -314,8 +319,9 @@ ParseResult KernelOp::parse(OpAsmParser &p, OperationState &result) { return success(); } -// Parse -// $name custom(type($arguments), $arguments) attr-dict +// Print +// $name custom(type($arguments), $arguments) +// (instantiation_args custom)? attr-dict // `->` type($results) void KernelOp::print(OpAsmPrinter &p) { p << ' '; @@ -353,11 +359,13 @@ void KernelOp::print(OpAsmPrinter &p) { LogicalResult KernelOp::verify() { if (getInstantiationArgNames().has_value()) { if (!getInstantiationArgs().has_value()) - return emitOpError("cannot have instantiation arg names without instantiation args"); + return emitOpError( + "cannot have instantiation arg names without instantiation args"); if (!(getInstantiationArgNamesAttr().empty() || getInstantiationArgNamesAttr().size() == getInstantiationArgsAttr().size())) - return emitOpError("instantiation arg names must be either empty or as long as instantiation args"); + return emitOpError("instantiation arg names must be either empty or as " + "long as instantiation args"); } return success(); From 7f37c7cb9abb855763309cd7d634babe285743fb Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Wed, 2 Oct 2024 13:00:43 +0100 Subject: [PATCH 4/9] Add example without types --- test/Dialect/XTenNN/ops.mlir | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/Dialect/XTenNN/ops.mlir b/test/Dialect/XTenNN/ops.mlir index 7b581e75..97aae9db 100644 --- a/test/Dialect/XTenNN/ops.mlir +++ b/test/Dialect/XTenNN/ops.mlir @@ -38,6 +38,8 @@ func.func @kernel(%arg0: tensor<2xi64>, %arg1 : tensor<4xi64>) { // CHECK: xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args {N = 42 : i32, idx = 56 : index} {attr = 4 : i32} -> tensor<2xi64> %f = xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args {42 : i32, 56 : index} {attr = 4 : i32} -> tensor<2xi64> // CHECK: xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args {42 : i32, 56 : index} {attr = 4 : i32} -> tensor<2xi64> + %g = xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args {42, 56} {attr = 4 : i32} -> tensor<2xi64> + // CHECK: xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args {42, 56} {attr = 4 : i32} -> tensor<2xi64> return } From 5b9ba2c159d97dff383272f43404b6225728d46f Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Wed, 2 Oct 2024 13:03:05 +0100 Subject: [PATCH 5/9] Add test with untyped parameters --- test/Dialect/XTenNN/ops.mlir | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/Dialect/XTenNN/ops.mlir b/test/Dialect/XTenNN/ops.mlir index 97aae9db..2207fb24 100644 --- a/test/Dialect/XTenNN/ops.mlir +++ b/test/Dialect/XTenNN/ops.mlir @@ -35,11 +35,11 @@ func.func @kernel(%arg0: tensor<2xi64>, %arg1 : tensor<4xi64>) { %d:2 = xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>, %arg1 : tensor<4xi64>) -> tensor<2xi64>, tensor<1xi64> // CHECK: xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>, %arg1 : tensor<4xi64>) -> tensor<2xi64>, tensor<1xi64> %e = xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args {N = 42 : i32, idx = 56 : index} {attr = 4 : i32} -> tensor<2xi64> - // CHECK: xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args {N = 42 : i32, idx = 56 : index} {attr = 4 : i32} -> tensor<2xi64> + // CHECK: xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args {"N" = 42 : i32, "idx" = 56 : index} {attr = 4 : i32} -> tensor<2xi64> %f = xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args {42 : i32, 56 : index} {attr = 4 : i32} -> tensor<2xi64> // CHECK: xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args {42 : i32, 56 : index} {attr = 4 : i32} -> tensor<2xi64> - %g = xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args {42, 56} {attr = 4 : i32} -> tensor<2xi64> - // CHECK: xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args {42, 56} {attr = 4 : i32} -> tensor<2xi64> + %g = xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args {42, 56, 1.0} {attr = 4 : i32} -> tensor<2xi64> + // CHECK: xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args {42 : i64, 56 : i64, 1.000000e+00 : f64} {attr = 4 : i32} -> tensor<2xi64> return } From 1a5a5420af89db6a1824580fb9c654f847cb15b1 Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Wed, 2 Oct 2024 13:39:13 +0100 Subject: [PATCH 6/9] Use square brackets, expect quoted string --- lib/Dialect/XTenNN/IR/XTenNNOps.cpp | 20 +++++++++++--------- test/Dialect/XTenNN/ops.mlir | 12 ++++++------ test/Dialect/XTenNN/ops_invalid.mlir | 11 ++++++++++- 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp index 9020fe11..00b0266b 100644 --- a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp +++ b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp @@ -223,17 +223,17 @@ static void printKernelArgumentList(OpAsmPrinter &p, TypeRange types, } // Parse -// {((name = )?value, )*((name = )?value)} +// [((name = )?value, )*((name = )?value)] static ParseResult parseKernelInstantiationArgs(OpAsmParser &p, SmallVector &values, SmallVector &names) { - if (failed(p.parseLBrace())) + if (failed(p.parseLSquare())) return failure(); if (failed(p.parseCommaSeparatedList([&p, &names, &values]() { std::string name; bool hasName = false; - if (succeeded(p.parseOptionalKeywordOrString(&name))) { + if (succeeded(p.parseOptionalString(&name))) { hasName = true; if (failed(p.parseEqual())) return failure(); @@ -253,20 +253,20 @@ static ParseResult parseKernelInstantiationArgs(OpAsmParser &p, return failure(); } - if (failed(p.parseRBrace())) + if (failed(p.parseRSquare())) return failure(); return success(); } // Print -// instantiation_args {((name = )?value, )*((name = )?value)} +// instantiation_args [((name = )?value, )*((name = )?value)] static void printKernelInstantiationArgs(OpAsmPrinter &p, ArrayRef instantiationArgs, ArrayRef instantiationArgNames) { if (!instantiationArgs.empty()) { - p << "instantiation_args {"; + p << "instantiation_args ["; auto zipped = llvm::zip_longest(instantiationArgNames, instantiationArgs); for (auto iter = zipped.begin(); iter != zipped.end(); ++iter) { if (iter != zipped.begin()) @@ -277,7 +277,7 @@ printKernelInstantiationArgs(OpAsmPrinter &p, if (value) p.printAttribute(*value); } - p << '}'; + p << ']'; } } @@ -358,14 +358,16 @@ void KernelOp::print(OpAsmPrinter &p) { LogicalResult KernelOp::verify() { if (getInstantiationArgNames().has_value()) { - if (!getInstantiationArgs().has_value()) + if (!getInstantiationArgs().has_value()) { return emitOpError( "cannot have instantiation arg names without instantiation args"); + } if (!(getInstantiationArgNamesAttr().empty() || getInstantiationArgNamesAttr().size() == - getInstantiationArgsAttr().size())) + getInstantiationArgsAttr().size())) { return emitOpError("instantiation arg names must be either empty or as " "long as instantiation args"); + } } return success(); diff --git a/test/Dialect/XTenNN/ops.mlir b/test/Dialect/XTenNN/ops.mlir index 2207fb24..b6ea7969 100644 --- a/test/Dialect/XTenNN/ops.mlir +++ b/test/Dialect/XTenNN/ops.mlir @@ -34,12 +34,12 @@ func.func @kernel(%arg0: tensor<2xi64>, %arg1 : tensor<4xi64>) { // CHECK: xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>) {attr = 4 : i32} -> tensor<2xi64> %d:2 = xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>, %arg1 : tensor<4xi64>) -> tensor<2xi64>, tensor<1xi64> // CHECK: xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>, %arg1 : tensor<4xi64>) -> tensor<2xi64>, tensor<1xi64> - %e = xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args {N = 42 : i32, idx = 56 : index} {attr = 4 : i32} -> tensor<2xi64> - // CHECK: xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args {"N" = 42 : i32, "idx" = 56 : index} {attr = 4 : i32} -> tensor<2xi64> - %f = xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args {42 : i32, 56 : index} {attr = 4 : i32} -> tensor<2xi64> - // CHECK: xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args {42 : i32, 56 : index} {attr = 4 : i32} -> tensor<2xi64> - %g = xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args {42, 56, 1.0} {attr = 4 : i32} -> tensor<2xi64> - // CHECK: xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args {42 : i64, 56 : i64, 1.000000e+00 : f64} {attr = 4 : i32} -> tensor<2xi64> + %e = xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args ["N" = 42 : i32, "idx" = 56 : index] {attr = 4 : i32} -> tensor<2xi64> + // CHECK: xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args ["N" = 42 : i32, "idx" = 56 : index] {attr = 4 : i32} -> tensor<2xi64> + %f = xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args [42 : i32, 56 : index] {attr = 4 : i32} -> tensor<2xi64> + // CHECK: xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args [42 : i32, 56 : index] {attr = 4 : i32} -> tensor<2xi64> + %g = xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args [42, 56, 1.0] {attr = 4 : i32} -> tensor<2xi64> + // CHECK: xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args [42 : i64, 56 : i64, 1.000000e+00 : f64] {attr = 4 : i32} -> tensor<2xi64> return } diff --git a/test/Dialect/XTenNN/ops_invalid.mlir b/test/Dialect/XTenNN/ops_invalid.mlir index 3fdd6303..6815f172 100644 --- a/test/Dialect/XTenNN/ops_invalid.mlir +++ b/test/Dialect/XTenNN/ops_invalid.mlir @@ -72,7 +72,7 @@ func.func @kernel_missing_result(%arg0: i8, %arg1: i8) { func.func @kernel_instantiation_list_different_length(%arg0: i8, %arg1: i8) { // expected-error@+1 {{instantiation arg names must be either empty or as long as instantiation args}} - %x = xten_nn.kernel "myKernel" () instantiation_args {N = 42 : i32, 51 : index} -> i32 + %x = xten_nn.kernel "myKernel" () instantiation_args ["N" = 42 : i32, 51 : index] -> i32 return } @@ -84,6 +84,15 @@ func.func @kernel_instantiation_list_non_empty(%arg0: i8, %arg1: i8) { return } +// ----- + +func.func @kernel_instantiation_list_non_quoted(%arg0: i8, %arg1: i8) { + // expected-error@+1 {{expected ']'}} + %x = xten_nn.kernel "myKernel" () instantiation_args [N = 42 : i32] -> i32 + return +} + + // ----- func.func @topk_wrong_output_shape(%arg0: tensor<10x10xf32>) { From 6d557e5d7f31caf2fdccc20560d85fad58b91bb3 Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Wed, 2 Oct 2024 13:42:53 +0100 Subject: [PATCH 7/9] Move push_back --- lib/Dialect/XTenNN/IR/XTenNNOps.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp index 00b0266b..6876716e 100644 --- a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp +++ b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp @@ -232,17 +232,14 @@ static ParseResult parseKernelInstantiationArgs(OpAsmParser &p, if (failed(p.parseCommaSeparatedList([&p, &names, &values]() { std::string name; - bool hasName = false; if (succeeded(p.parseOptionalString(&name))) { - hasName = true; + names.push_back(StringAttr::get(p.getContext(), name)); if (failed(p.parseEqual())) return failure(); } Attribute attr; auto res = p.parseOptionalAttribute(attr); if (res.has_value() && succeeded(*res)) { - if (hasName) - names.push_back(StringAttr::get(p.getContext(), name)); values.push_back(attr); } if (res.has_value() && failed(*res)) From dbda7d4853f6297bb642ad6f45ce88381d92395b Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Wed, 2 Oct 2024 13:49:19 +0100 Subject: [PATCH 8/9] clang-format --- lib/Dialect/XTenNN/IR/XTenNNOps.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp index 6876716e..2af2a988 100644 --- a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp +++ b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp @@ -233,7 +233,7 @@ static ParseResult parseKernelInstantiationArgs(OpAsmParser &p, if (failed(p.parseCommaSeparatedList([&p, &names, &values]() { std::string name; if (succeeded(p.parseOptionalString(&name))) { - names.push_back(StringAttr::get(p.getContext(), name)); + names.push_back(StringAttr::get(p.getContext(), name)); if (failed(p.parseEqual())) return failure(); } From bb61b2febfe3571fe683a007f596acf50fe67efc Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Wed, 2 Oct 2024 16:32:19 +0100 Subject: [PATCH 9/9] Add builder without optional arguments --- include/xten/Dialect/XTenNN/IR/XTenNNOps.td | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/include/xten/Dialect/XTenNN/IR/XTenNNOps.td b/include/xten/Dialect/XTenNN/IR/XTenNNOps.td index 55220c5d..364655f8 100644 --- a/include/xten/Dialect/XTenNN/IR/XTenNNOps.td +++ b/include/xten/Dialect/XTenNN/IR/XTenNNOps.td @@ -277,6 +277,17 @@ def XTenNN_KernelOp : XTenNN_Op<"kernel", []> { ); let results = (outs Variadic:$results); + let builders = [ + OpBuilder<(ins + "::mlir::TypeRange":$results, + "::mlir::ValueRange":$arguments, + "::llvm::StringRef":$name), [{ + build($_builder, $_state, results, arguments, name, + ::mlir::ArrayAttr(), ::mlir::ArrayAttr()); + }] + > + ]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; }