Skip to content

Commit

Permalink
REview comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mgehre-amd committed Aug 9, 2024
1 parent 24719b7 commit ca2c072
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 37 deletions.
7 changes: 6 additions & 1 deletion include/xten/Dialect/XTenNN/IR/XTenNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,11 @@ def XTenNN_KernelOp : XTenNN_Op<"kernel", []> {
let summary = "An opaque kernel";
let description = [{
The `xten_nn.kernel` operation defines an opaque computation with a name.
Example:
```
%c = xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>) {attr = 4 : i32} -> tensor<2xi64>
%d:2 = xten_nn.kernel "frob" (%arg0 : tensor<2xi64>, %arg1 : tensor<4xi64>) -> tensor<2xi64>, tensor<1xi64>
```
}];

let arguments = (ins
Expand All @@ -269,7 +274,7 @@ def XTenNN_KernelOp : XTenNN_Op<"kernel", []> {
);
let results = (outs Variadic<AnyType>:$results);

let assemblyFormat = [{ $name custom<KernelArgumentList>(type($arguments), $arguments) attr-dict `->` type($results) }];
let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
89 changes: 57 additions & 32 deletions lib/Dialect/XTenNN/IR/XTenNNOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,45 +181,70 @@ static void printEnclaveOp(OpAsmPrinter &p, EnclaveOp op) {
/// `(` (ssa-id `:` type (`,` ssa-id `:` type)*)? `)`
///
/// This method is used by the tablegen assembly format for the kernel op.
static ParseResult parseKernelArgumentList(
OpAsmParser &parser, SmallVectorImpl<Type> &types,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &arguments) {
if (parser.parseLParen())
return failure();

if (succeeded(parser.parseOptionalRParen()))
return success();

while(true) {
OpAsmParser::UnresolvedOperand argument;
Type type;
if (parser.parseOperand(argument) ||
parser.parseColon() ||
parser.parseType(type))
return failure();

types.push_back(type);
arguments.push_back(argument);

if (succeeded(parser.parseOptionalRParen()))
return success();
static ParseResult parseKernelArgumentList(OpAsmParser &p,
SmallVectorImpl<Value> &operands) {
return p.parseCommaSeparatedList(
OpAsmParser::Delimiter::Paren,
[&]() -> ParseResult {
OpAsmParser::UnresolvedOperand operand;
Type type;
if (p.parseOperand(operand))
return failure();
if (p.parseOptionalColon())
return p.emitError(p.getCurrentLocation(),
"expected ':`, (argument format is val : type)");

if (parser.parseComma())
return failure();
}
if (p.parseType(type) || p.resolveOperand(operand, type, operands))
return failure();
return success();
},
" in argument list");
}

/// Prints a list of ssa values with their types.
/// `(` (ssa-id `:` type (`,` ssa-id `:` type)*)? `)`
///
/// This method is used by the tablegen assembly format for the kernel op.
static void printKernelArgumentList(OpAsmPrinter &printer, Operation *op,
TypeRange types,
OperandRange arguments) {
printer << "(";
llvm::interleaveComma(llvm::zip(arguments, types), printer,
[&](const auto &a) { printer << get<0>(a) << " : " << get<1>(a); });
printer << ")";
static void printKernelArgumentList(OpAsmPrinter &p, TypeRange types,
OperandRange arguments) {
p << "(";
llvm::interleaveComma(llvm::zip(arguments, types), p, [&](const auto &a) {
p << get<0>(a) << " : " << get<1>(a);
});
p << ")";
}

// Parse
// $name custom<KernelArgumentList>(type($arguments), $arguments) attr-dict
// `->` type($results)
ParseResult KernelOp::parse(OpAsmParser &p, OperationState &result) {
StringAttr name;
if (p.parseAttribute(name, "name", result.attributes))
return failure();

if (parseKernelArgumentList(p, result.operands) ||
p.parseOptionalAttrDict(result.attributes) || p.parseArrow() ||
p.parseTypeList(result.types))
return failure();

return success();
}

// Parse
// $name custom<KernelArgumentList>(type($arguments), $arguments) attr-dict
// `->` type($results)
void KernelOp::print(OpAsmPrinter &p) {
p << ' ';
p << getNameAttr();
p << ' ';
printKernelArgumentList(p, getOperandTypes(), getOperands());
p << ' ';
SmallVector<StringRef> elidedAttrs = {"name"};
p.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs);
if (getOperation()->getAttrs().size() > elidedAttrs.size())
p << ' ';
p << "-> ";
p << getResultTypes();
}

#define GET_OP_CLASSES
Expand Down
13 changes: 10 additions & 3 deletions test/Dialect/XTenNN/ops_invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,20 @@ func.func @kernel_missing_parenthesis() {

// -----

func.func @kernel_missing_type(%arg0: i8, %arg1: i8) {
// expected-error@+1 {{expected ':'}}
func.func @kernel_missing_colon(%arg0: i8, %arg1: i8) {
// expected-error@+1 {{expected ':`, (argument format is val : type)}}
%a = xten_nn.kernel "myKernel" (%arg0, %arg1) -> tensor<2xi64>
}

// -----

func.func @kernel_missing_type(%arg0: i8, %arg1: i8) {
// expected-error@+1 {{expected non-function type}}
%a = xten_nn.kernel "myKernel" (%arg0 : ) -> tensor<2xi64>
}

// -----

func.func @kernel_trailing_comma(%arg0: i8) {
// expected-error@+1 {{expected SSA operand}}
%a = xten_nn.kernel "myKernel" (%arg0 :i8, ) -> tensor<2xi64>
Expand All @@ -49,7 +56,7 @@ func.func @kernel_trailing_comma(%arg0: i8) {
// -----

func.func @kernel_missing_name() {
// expected-error@+1 {{custom op 'xten_nn.kernel' invalid kind of attribute specified}}
// expected-error@+1 {{'xten_nn.kernel' invalid kind of attribute specified}}
%b = xten_nn.kernel () -> tensor<2xi64>
return
}
4 changes: 3 additions & 1 deletion tools/aten-opt/aten-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ int main(int argc, char **argv) {
DialectRegistry registry;
registerAllDialects(registry);
mlir::registerAllDialects(registry);
registry.insert<amd::xten_nn::XTenNNDialect>();
registry.insert<amd::xten_nn::XTenNNDialect,
torch::Torch::TorchDialect,
torch::TorchConversion::TorchConversionDialect>();

return failed(MlirOptMain(argc, argv, "MLIR modular optimizer driver\n",
registry));
Expand Down

0 comments on commit ca2c072

Please sign in to comment.