Skip to content

Commit

Permalink
Merge pull request #60 from Xilinx/matthias.kernel
Browse files Browse the repository at this point in the history
Add xten_nn.kernel op
  • Loading branch information
mgehre-amd authored Aug 9, 2024
2 parents 4f0c873 + ca2c072 commit 1f540af
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 1 deletion.
20 changes: 20 additions & 0 deletions include/xten/Dialect/XTenNN/IR/XTenNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,26 @@ def XTenNN_LoadExternalConstOp: XTenNN_Op<"load_external_const", [
let assemblyFormat = [{ attr-dict `->` type($output) }];
}

def XTenNN_KernelOp : XTenNN_Op<"kernel", []> {
let summary = "An opaque kernel";
let description = [{
The `xten_nn.kernel` operation defines an opaque computation with a name.
Example:
```
%c = xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>) {attr = 4 : i32} -> tensor<2xi64>
%d:2 = xten_nn.kernel "frob" (%arg0 : tensor<2xi64>, %arg1 : tensor<4xi64>) -> tensor<2xi64>, tensor<1xi64>
```
}];

let arguments = (ins
Variadic<AnyType>:$arguments,
StrAttr:$name
);
let results = (outs Variadic<AnyType>:$results);

let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
// Ops that are missing from the TOSA standard
//===----------------------------------------------------------------------===//
Expand Down
77 changes: 76 additions & 1 deletion lib/Dialect/XTenNN/IR/XTenNNOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,81 @@ static void printEnclaveOp(OpAsmPrinter &p, EnclaveOp op) {
};
}


//===----------------------------------------------------------------------===//
// KernelOp
//===----------------------------------------------------------------------===//

/// Parses a list of ssa values with their types.
/// `(` (ssa-id `:` type (`,` ssa-id `:` type)*)? `)`
///
/// This method is used by the tablegen assembly format for the kernel op.
static ParseResult parseKernelArgumentList(OpAsmParser &p,
SmallVectorImpl<Value> &operands) {
return p.parseCommaSeparatedList(
OpAsmParser::Delimiter::Paren,
[&]() -> ParseResult {
OpAsmParser::UnresolvedOperand operand;
Type type;
if (p.parseOperand(operand))
return failure();
if (p.parseOptionalColon())
return p.emitError(p.getCurrentLocation(),
"expected ':`, (argument format is val : type)");

if (p.parseType(type) || p.resolveOperand(operand, type, operands))
return failure();
return success();
},
" in argument list");
}

/// Prints a list of ssa values with their types.
/// `(` (ssa-id `:` type (`,` ssa-id `:` type)*)? `)`
///
/// This method is used by the tablegen assembly format for the kernel op.
static void printKernelArgumentList(OpAsmPrinter &p, TypeRange types,
OperandRange arguments) {
p << "(";
llvm::interleaveComma(llvm::zip(arguments, types), p, [&](const auto &a) {
p << get<0>(a) << " : " << get<1>(a);
});
p << ")";
}

// Parse
// $name custom<KernelArgumentList>(type($arguments), $arguments) attr-dict
// `->` type($results)
ParseResult KernelOp::parse(OpAsmParser &p, OperationState &result) {
StringAttr name;
if (p.parseAttribute(name, "name", result.attributes))
return failure();

if (parseKernelArgumentList(p, result.operands) ||
p.parseOptionalAttrDict(result.attributes) || p.parseArrow() ||
p.parseTypeList(result.types))
return failure();

return success();
}

// Parse
// $name custom<KernelArgumentList>(type($arguments), $arguments) attr-dict
// `->` type($results)
void KernelOp::print(OpAsmPrinter &p) {
p << ' ';
p << getNameAttr();
p << ' ';
printKernelArgumentList(p, getOperandTypes(), getOperands());
p << ' ';
SmallVector<StringRef> elidedAttrs = {"name"};
p.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs);
if (getOperation()->getAttrs().size() > elidedAttrs.size())
p << ' ';
p << "-> ";
p << getResultTypes();
}

#define GET_OP_CLASSES
#include "xten/Dialect/XTenNN/IR/XTenNNOps.cpp.inc"

Expand Down Expand Up @@ -435,4 +510,4 @@ LogicalResult amd::xten_nn::ResizeOp::verify() {
}

return success();
}
}
16 changes: 16 additions & 0 deletions test/Dialect/XTenNN/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,19 @@ func.func @subgraph_empty(%arg0: tensor<2xi64>) -> tensor<2xi64> {
%sum = xten_nn.subgraph (%arg0 : tensor<2xi64>) -> tensor<2xi64>
return %sum : tensor<2xi64>
}


// -----

// CHECK-LABEL: kernel
func.func @kernel(%arg0: tensor<2xi64>, %arg1 : tensor<4xi64>) {
%a = xten_nn.kernel "myKernel" () -> tensor<2xi64>
// CHECK: xten_nn.kernel "myKernel" () -> tensor<2xi64>
%b = xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>) -> tensor<2xi64>
// CHECK: xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>) -> tensor<2xi64>
%c = xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>) {attr = 4 : i32} -> tensor<2xi64>
// CHECK: xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>) {attr = 4 : i32} -> tensor<2xi64>
%d:2 = xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>, %arg1 : tensor<4xi64>) -> tensor<2xi64>, tensor<1xi64>
// CHECK: xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>, %arg1 : tensor<4xi64>) -> tensor<2xi64>, tensor<1xi64>
return
}
36 changes: 36 additions & 0 deletions test/Dialect/XTenNN/ops_invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,39 @@ func.func @mish_int(%arg0: tensor<1x10xi4>) -> tensor<1x10xi4> {
%0 = xten_nn.mish %arg0 : (tensor<1x10xi4>) -> tensor<1x10xi4>
return %0 : tensor<1x10xi4>
}

// -----

func.func @kernel_missing_parenthesis() {
// expected-error@+1 {{expected '('}}
%a = xten_nn.kernel "myKernel" -> tensor<2xi64>
}

// -----

func.func @kernel_missing_colon(%arg0: i8, %arg1: i8) {
// expected-error@+1 {{expected ':`, (argument format is val : type)}}
%a = xten_nn.kernel "myKernel" (%arg0, %arg1) -> tensor<2xi64>
}

// -----

func.func @kernel_missing_type(%arg0: i8, %arg1: i8) {
// expected-error@+1 {{expected non-function type}}
%a = xten_nn.kernel "myKernel" (%arg0 : ) -> tensor<2xi64>
}

// -----

func.func @kernel_trailing_comma(%arg0: i8) {
// expected-error@+1 {{expected SSA operand}}
%a = xten_nn.kernel "myKernel" (%arg0 :i8, ) -> tensor<2xi64>
}

// -----

func.func @kernel_missing_name() {
// expected-error@+1 {{'xten_nn.kernel' invalid kind of attribute specified}}
%b = xten_nn.kernel () -> tensor<2xi64>
return
}

0 comments on commit 1f540af

Please sign in to comment.