diff --git a/compiler/src/CMakeLists.txt b/compiler/src/CMakeLists.txt index c4581a5..76828f3 100644 --- a/compiler/src/CMakeLists.txt +++ b/compiler/src/CMakeLists.txt @@ -9,3 +9,4 @@ set(IREE_PACKAGE_ROOT_PREFIX "") set(IREE_COMPILER_TABLEGEN_INCLUDE_DIRS "${CMAKE_CURRENT_SOURCE_DIR}") add_subdirectory(openxla/compiler/nvgpu) +add_subdirectory(openxla/compiler/async) diff --git a/compiler/src/openxla/compiler/async/BUILD.bazel b/compiler/src/openxla/compiler/async/BUILD.bazel new file mode 100644 index 0000000..9890e39 --- /dev/null +++ b/compiler/src/openxla/compiler/async/BUILD.bazel @@ -0,0 +1,38 @@ +# Copyright 2023 The OpenXLA Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("@iree_core//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_register_plugin") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "defs", + includes = ["../../.."], +) + +cc_library( + name = "registration", + srcs = [ + "PluginRegistration.cpp", + ], + deps = [ + ":defs", + "//compiler/src/openxla/compiler/async/Dialect/Async/IR", + "//compiler/src/openxla/compiler/async/Dialect/Async/Transforms", + "@iree_core//compiler/src/iree/compiler/PluginAPI", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + ], +) + +iree_compiler_register_plugin( + plugin_id = "openxla_async", + target = ":registration", +) diff --git a/compiler/src/openxla/compiler/async/CMakeLists.txt b/compiler/src/openxla/compiler/async/CMakeLists.txt new file mode 100644 index 0000000..3a1a80b --- /dev/null +++ b/compiler/src/openxla/compiler/async/CMakeLists.txt @@ -0,0 +1,44 @@ +################################################################################ +# Autogenerated by ../iree/build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/openxla/compiler/async/BUILD.bazel # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_cc_library( + NAME + defs + INCLUDES + "$" + "$" + PUBLIC +) + +iree_cc_library( + NAME + registration + SRCS + "PluginRegistration.cpp" + DEPS + ::defs + MLIRIR + MLIRPass + iree::compiler::PluginAPI + openxla::compiler::async::Dialect::Async::IR + openxla::compiler::async::Dialect::Async::Transforms + PUBLIC +) + +iree_compiler_register_plugin( + PLUGIN_ID + openxla_async + TARGET + ::registration +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/compiler/src/openxla/compiler/async/Dialect/Async/BUILD.bazel b/compiler/src/openxla/compiler/async/Dialect/Async/BUILD.bazel new file mode 100644 index 0000000..ff111ea --- /dev/null +++ b/compiler/src/openxla/compiler/async/Dialect/Async/BUILD.bazel @@ -0,0 +1,11 @@ +# Copyright 2023 The OpenXLA Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) diff --git a/compiler/src/openxla/compiler/async/Dialect/Async/CMakeLists.txt b/compiler/src/openxla/compiler/async/Dialect/Async/CMakeLists.txt new file mode 100644 index 0000000..bb8dc1b --- /dev/null +++ b/compiler/src/openxla/compiler/async/Dialect/Async/CMakeLists.txt @@ -0,0 +1,13 @@ +################################################################################ +# Autogenerated by ../iree/build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/openxla/compiler/async/Dialect/Async/BUILD.bazel # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/compiler/src/openxla/compiler/async/Dialect/Async/Conversion/BUILD.bazel b/compiler/src/openxla/compiler/async/Dialect/Async/Conversion/BUILD.bazel new file mode 100644 index 0000000..9757b60 --- /dev/null +++ b/compiler/src/openxla/compiler/async/Dialect/Async/Conversion/BUILD.bazel @@ -0,0 +1,28 @@ +# Copyright 2023 The OpenXLA Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "AsyncToRuntime", + srcs = ["ConvertAsyncToRuntime.cpp"], + hdrs = ["ConvertAsyncToRuntime.h"], + deps = [ + "//compiler/src/openxla/compiler/async/Dialect/Async/IR", + "@iree_core//compiler/src/iree/compiler/Dialect/Util/IR", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + ], +) diff --git a/compiler/src/openxla/compiler/async/Dialect/Async/Conversion/CMakeLists.txt b/compiler/src/openxla/compiler/async/Dialect/Async/Conversion/CMakeLists.txt new file mode 100644 index 0000000..73e3d8e --- /dev/null +++ b/compiler/src/openxla/compiler/async/Dialect/Async/Conversion/CMakeLists.txt @@ -0,0 +1,31 @@ +################################################################################ +# Autogenerated by ../iree/build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/openxla/compiler/async/Dialect/Async/Conversion/BUILD.bazel # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_cc_library( + NAME + AsyncToRuntime + HDRS + "ConvertAsyncToRuntime.h" + SRCS + "ConvertAsyncToRuntime.cpp" + DEPS + LLVMSupport + MLIRFuncDialect + MLIRIR + MLIRPass + MLIRTransforms + iree::compiler::Dialect::Util::IR + openxla::compiler::async::Dialect::Async::IR + PUBLIC +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/compiler/src/openxla/compiler/async/Dialect/Async/Conversion/ConvertAsyncToRuntime.cpp b/compiler/src/openxla/compiler/async/Dialect/Async/Conversion/ConvertAsyncToRuntime.cpp new file mode 100644 index 0000000..2d20bef --- /dev/null +++ b/compiler/src/openxla/compiler/async/Dialect/Async/Conversion/ConvertAsyncToRuntime.cpp @@ -0,0 +1,249 @@ +// Copyright 2023 The OpenXLA Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "openxla/compiler/async/Dialect/Async/Conversion/ConvertAsyncToRuntime.h" + +#include "iree/compiler/Dialect/Util/IR/UtilDialect.h" +#include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Value.h" +#include "mlir/Transforms/DialectConversion.h" +#include "openxla/compiler/async/Dialect/Async/IR/Async.h" + +namespace openxla::compiler::async { + +using namespace mlir; +using namespace mlir::iree_compiler; + +namespace { + +//===----------------------------------------------------------------------===// +// AsyncAPI for importing Async VM module functions +//===----------------------------------------------------------------------===// + +class AsyncAPI { + public: + // Import `@async.token.await` into the module + func::FuncOp getValueAwait(PatternRewriter &rewriter, ModuleOp module); + // Import `@async.value.query` into the module + func::FuncOp getValueQuery(PatternRewriter &rewriter, ModuleOp module); + // Import `@async.value.load.i32` into the module + func::FuncOp getValueLoadI32(PatternRewriter &rewriter, ModuleOp module); + // Import `@async.value.load.ref` into the module + func::FuncOp getValueLoadRef(PatternRewriter &rewriter, ModuleOp module); + + SymbolTable &symTable(ModuleOp module); + + bool isScalarType(Type type) { return type.isIntOrIndexOrFloat(); } + + private: + func::FuncOp addDecl(PatternRewriter &rewriter, ModuleOp module, + StringAttr name, FunctionType function_type); + SymbolTableCollection symTable_; +}; + +SymbolTable &AsyncAPI::symTable(ModuleOp module) { + return symTable_.getSymbolTable(module); +} + +func::FuncOp AsyncAPI::addDecl(PatternRewriter &rewriter, ModuleOp module, + StringAttr name, FunctionType function_type) { + if (auto fn = symTable_.lookupNearestSymbolFrom(module, name)) + return fn; + + ImplicitLocOpBuilder b(UnknownLoc::get(module->getContext()), rewriter); + b.setInsertionPointToEnd(module.getBody()); + + auto fn = b.create(name, function_type); + fn.setPrivate(); + symTable(module).insert(fn); + return fn; +} + +func::FuncOp AsyncAPI::getValueQuery(PatternRewriter &rewriter, + ModuleOp module) { + MLIRContext *ctx = module->getContext(); + SmallVector args{ValueType::get(ctx)}; + SmallVector rets{IntegerType::get(ctx, 32)}; + + auto functionType = FunctionType::get(ctx, args, rets); + + return addDecl(rewriter, module, StringAttr::get(ctx, "async.value.query"), + functionType); +} + +func::FuncOp AsyncAPI::getValueAwait(PatternRewriter &rewriter, + ModuleOp module) { + MLIRContext *ctx = module->getContext(); + SmallVector args{ValueType::get(ctx)}; + auto functionType = FunctionType::get(ctx, args, /*rets=*/{}); + + return addDecl(rewriter, module, StringAttr::get(ctx, "async.value.await"), + functionType); +} + +func::FuncOp AsyncAPI::getValueLoadI32(PatternRewriter &rewriter, + ModuleOp module) { + MLIRContext *ctx = module->getContext(); + SmallVector args{ValueType::get(ctx)}; + SmallVector rets{IntegerType::get(ctx, 32)}; + auto functionType = FunctionType::get(ctx, args, rets); + + return addDecl(rewriter, module, StringAttr::get(ctx, "async.value.load.i32"), + functionType); +} + +func::FuncOp AsyncAPI::getValueLoadRef(PatternRewriter &rewriter, + ModuleOp module) { + MLIRContext *ctx = module->getContext(); + SmallVector args{ValueType::get(ctx)}; + SmallVector rets{IREE::Util::ObjectType::get(ctx)}; + auto functionType = FunctionType::get(ctx, args, rets); + + return addDecl(rewriter, module, StringAttr::get(ctx, "async.value.load.ref"), + functionType); +} + +//===----------------------------------------------------------------------===// +// Base class for all Async op conversions +//===----------------------------------------------------------------------===// + +template +struct AsyncOpConversionPattern : public OpConversionPattern { + AsyncOpConversionPattern(TypeConverter &typeConverter, MLIRContext *ctx, + std::shared_ptr api) + : OpConversionPattern(typeConverter, ctx), api(std::move(api)) {} + + std::shared_ptr api; +}; + +//===----------------------------------------------------------------------===// +// Lowering for `async.await` with a token operand. +//===----------------------------------------------------------------------===// + +struct ConvertTokenAwaitOp : public AsyncOpConversionPattern { + using AsyncOpConversionPattern::AsyncOpConversionPattern; + + LogicalResult matchAndRewrite( + AwaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isa(op.getOperand().getType())) { + return rewriter.notifyMatchFailure(op, "unsupported awaitable type"); + } + ModuleOp module = op->getParentOfType(); + ImplicitLocOpBuilder b(op->getLoc(), rewriter); + + auto awaitFuncOp = api->getValueAwait(rewriter, module); + b.create(awaitFuncOp.getSymName(), TypeRange{}, + adaptor.getOperands()); + auto queryFuncOp = api->getValueQuery(rewriter, module); + auto queryOp = b.create(queryFuncOp.getSymName(), + queryFuncOp.getResultTypes(), + adaptor.getOperands()); + b.create(queryOp.getResult(0), + "failed to wait on async token"); + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Lowering for `async.await` with a async scalar value operand. +//===----------------------------------------------------------------------===// + +struct ConvertScalarAwaitOp : public AsyncOpConversionPattern { + using AsyncOpConversionPattern::AsyncOpConversionPattern; + + LogicalResult matchAndRewrite( + AwaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isa(op.getOperand().getType())) { + return rewriter.notifyMatchFailure(op, "unsupported awaitable type"); + } + ModuleOp module = op->getParentOfType(); + auto resultType = op.getResultType(); + ImplicitLocOpBuilder b(op->getLoc(), rewriter); + if (resultType->isInteger(32)) { + auto awaitFuncOp = api->getValueAwait(rewriter, module); + b.create(awaitFuncOp.getSymName(), TypeRange{}, + adaptor.getOperands()); + auto queryFuncOp = api->getValueQuery(rewriter, module); + auto queryOp = b.create(queryFuncOp.getSymName(), + queryFuncOp.getResultTypes(), + adaptor.getOperands()); + b.create(queryOp.getResult(0), + "failed to wait on async value"); + auto loadFuncOp = api->getValueLoadI32(rewriter, module); + rewriter.replaceOpWithNewOp( + op, loadFuncOp.getSymName(), *resultType, adaptor.getOperands()); + } else { + return rewriter.notifyMatchFailure(op, + "unsupported awaitable scalar type"); + } + + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Lowering for `async.await` with a async value of custom type operand. +//===----------------------------------------------------------------------===// + +struct ConvertObjectAwaitOp : public AsyncOpConversionPattern { + using AsyncOpConversionPattern::AsyncOpConversionPattern; + + LogicalResult matchAndRewrite( + AwaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isa(op.getOperand().getType())) { + return rewriter.notifyMatchFailure(op, "unsupported awaitable type"); + } + auto resultType = op.getResultType(); + if (!resultType || api->isScalarType(*resultType)) { + return rewriter.notifyMatchFailure(op, "unsupported async value type"); + } + ModuleOp module = op->getParentOfType(); + MLIRContext *ctx = rewriter.getContext(); + ImplicitLocOpBuilder b(op->getLoc(), rewriter); + auto awaitFuncOp = api->getValueAwait(rewriter, module); + b.create(awaitFuncOp.getSymName(), TypeRange{}, + adaptor.getOperands()); + auto queryFuncOp = api->getValueQuery(rewriter, module); + auto queryOp = b.create(queryFuncOp.getSymName(), + queryFuncOp.getResultTypes(), + adaptor.getOperands()); + b.create(queryOp.getResult(0), + "failed to wait on async value"); + auto loadFuncOp = api->getValueLoadRef(rewriter, module); + auto callOp = b.create(loadFuncOp.getSymName(), + IREE::Util::ObjectType::get(ctx), + adaptor.getOperands()); + rewriter.replaceOpWithNewOp(op, op.getResultTypes(), + callOp.getResult(0)); + return success(); + } +}; +} // namespace + +void populateAsyncToRuntimePatterns(mlir::TypeConverter &typeConverter, + mlir::RewritePatternSet &patterns) { + MLIRContext *ctx = patterns.getContext(); + auto api = std::make_shared(); + + patterns.insert(typeConverter, ctx, api); + patterns.insert(typeConverter, ctx, api); + patterns.insert(typeConverter, ctx, api); +} + +} // namespace openxla::compiler::async diff --git a/compiler/src/openxla/compiler/async/Dialect/Async/Conversion/ConvertAsyncToRuntime.h b/compiler/src/openxla/compiler/async/Dialect/Async/Conversion/ConvertAsyncToRuntime.h new file mode 100644 index 0000000..4e48c8e --- /dev/null +++ b/compiler/src/openxla/compiler/async/Dialect/Async/Conversion/ConvertAsyncToRuntime.h @@ -0,0 +1,23 @@ +// Copyright 2023 The OpenXLA Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef OPENXLA_COMPILER_DIALECT_ASYNC_CONVERSION_CONVERT_ASYNC_TO_RUNTIME_H_ +#define OPENXLA_COMPILER_DIALECT_ASYNC_CONVERSION_CONVERT_ASYNC_TO_RUNTIME_H_ + +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace openxla::compiler::async { + +// Appends Async dialect to async runtime patterns to the given pattern list. +// Conversion patterns lower from Async dialect operations to function calls +// corresponding to the async runtime (implemented as a custom VM module). +void populateAsyncToRuntimePatterns(mlir::TypeConverter &typeConverter, + mlir::RewritePatternSet &patterns); + +} // namespace openxla::compiler::async + +#endif // OPENXLA_COMPILER_DIALECT_ASYNC_CONVERSION_CONVERT_ASYNC_TO_RUNTIME_H_ diff --git a/compiler/src/openxla/compiler/async/Dialect/Async/IR/Async.cpp b/compiler/src/openxla/compiler/async/Dialect/Async/IR/Async.cpp new file mode 100644 index 0000000..bbf4614 --- /dev/null +++ b/compiler/src/openxla/compiler/async/Dialect/Async/IR/Async.cpp @@ -0,0 +1,116 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "openxla/compiler/async/Dialect/Async/IR/Async.h" + +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/FunctionImplementation.h" +#include "mlir/IR/IRMapping.h" + +using namespace mlir; +using namespace openxla::compiler::async; + +#include "openxla/compiler/async/Dialect/Async/IR/AsyncDialect.cpp.inc" + +void AsyncDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "openxla/compiler/async/Dialect/Async/IR/AsyncOps.cpp.inc" + >(); + addTypes< +#define GET_TYPEDEF_LIST +#include "openxla/compiler/async/Dialect/Async/IR/AsyncTypes.cpp.inc" + >(); +} + +//===----------------------------------------------------------------------===// +/// AwaitOp +//===----------------------------------------------------------------------===// + +void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand, + ArrayRef attrs) { + result.addOperands({operand}); + result.attributes.append(attrs.begin(), attrs.end()); + + // Add unwrapped async.value type to the returned values types. + if (auto valueType = operand.getType().dyn_cast()) + result.addTypes(valueType.getValueType()); +} + +static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType, + Type &resultType) { + if (parser.parseType(operandType)) return failure(); + + // Add unwrapped async.value type to the returned values types. + if (auto valueType = operandType.dyn_cast()) + resultType = valueType.getValueType(); + + return success(); +} + +static void printAwaitResultType(OpAsmPrinter &p, Operation *op, + Type operandType, Type resultType) { + p << operandType; +} + +LogicalResult AwaitOp::verify() { + Type argType = getOperand().getType(); + + // Awaiting on a token does not have any results. + if (argType.isa() && !getResultTypes().empty()) + return emitOpError("awaiting on a token must have empty result"); + + // Awaiting on a value unwraps the async value type. + if (auto value = argType.dyn_cast()) { + if (*getResultType() != value.getValueType()) + return emitOpError() << "result type " << *getResultType() + << " does not match async value type " + << value.getValueType(); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "openxla/compiler/async/Dialect/Async/IR/AsyncOps.cpp.inc" + +//===----------------------------------------------------------------------===// +// TableGen'd type method definitions +//===----------------------------------------------------------------------===// + +#define GET_TYPEDEF_CLASSES +#include "openxla/compiler/async/Dialect/Async/IR/AsyncTypes.cpp.inc" + +void ValueType::print(AsmPrinter &printer) const { + // Opaque async value type (`!async.value`) + if (!getValueType()) return; + + printer << "<"; + printer.printType(getValueType()); + printer << '>'; +} + +Type ValueType::parse(mlir::AsmParser &parser) { + // Opaque async value type (`async.value`) + if (failed(parser.parseOptionalLess())) { + return ValueType::get(parser.getContext()); + } + + Type ty; + if (parser.parseType(ty) || parser.parseGreater()) { + parser.emitError(parser.getNameLoc(), "failed to parse async value type"); + return Type(); + } + return ValueType::get(ty); +} + +bool ValueType::isOpaque() { return !getValueType(); } diff --git a/compiler/src/openxla/compiler/async/Dialect/Async/IR/Async.h b/compiler/src/openxla/compiler/async/Dialect/Async/IR/Async.h new file mode 100644 index 0000000..b4be23b --- /dev/null +++ b/compiler/src/openxla/compiler/async/Dialect/Async/IR/Async.h @@ -0,0 +1,45 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef ASYNC_ASYNC_H_ +#define ASYNC_ASYNC_H_ + +#include "iree/compiler/Dialect/Util/IR/UtilTraits.h" +#include "iree/compiler/Dialect/Util/IR/UtilTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/FunctionInterfaces.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +//===----------------------------------------------------------------------===// +// Async Dialect Types +//===----------------------------------------------------------------------===// + +#define GET_TYPEDEF_CLASSES +#include "openxla/compiler/async/Dialect/Async/IR/AsyncTypes.h.inc" + +//===----------------------------------------------------------------------===// +// Async Dialect +//===----------------------------------------------------------------------===// + +#include "openxla/compiler/async/Dialect/Async/IR/AsyncDialect.h.inc" + +//===----------------------------------------------------------------------===// +// Async Dialect Operations +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "openxla/compiler/async/Dialect/Async/IR/AsyncOps.h.inc" + +#endif // ASYNC_ASYNC_H_ diff --git a/compiler/src/openxla/compiler/async/Dialect/Async/IR/AsyncDialect.td b/compiler/src/openxla/compiler/async/Dialect/Async/IR/AsyncDialect.td new file mode 100644 index 0000000..448fb17 --- /dev/null +++ b/compiler/src/openxla/compiler/async/Dialect/Async/IR/AsyncDialect.td @@ -0,0 +1,24 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef ASYNC_DIALECT +#define ASYNC_DIALECT + +include "mlir/IR/OpBase.td" + +def AsyncDialect : Dialect { + let name = "async"; + let cppNamespace = "::openxla::compiler::async"; + + let summary = "Types and operations for async dialect"; + let description = [{ + This dialect contains operations for modeling asynchronous execution. + }]; + + let useDefaultTypePrinterParser = 1; +} + +#endif // ASYNC_DIALECT diff --git a/compiler/src/openxla/compiler/async/Dialect/Async/IR/AsyncOps.td b/compiler/src/openxla/compiler/async/Dialect/Async/IR/AsyncOps.td new file mode 100644 index 0000000..175c942 --- /dev/null +++ b/compiler/src/openxla/compiler/async/Dialect/Async/IR/AsyncOps.td @@ -0,0 +1,74 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef ASYNC_OPS +#define ASYNC_OPS + +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/CallInterfaces.td" +include "mlir/IR/SymbolInterfaces.td" +include "mlir/IR/FunctionInterfaces.td" +include "mlir/IR/OpAsmInterface.td" + +include "iree/compiler/Dialect/Util/IR/UtilAttrs.td" + +include "openxla/compiler/async/Dialect/Async/IR/AsyncDialect.td" +include "openxla/compiler/async/Dialect/Async/IR/AsyncTypes.td" + +//===----------------------------------------------------------------------===// +// Async op definitions +//===----------------------------------------------------------------------===// + +// Base class for the operation in this dialect +class Async_Op traits = [Util_YieldPoint]> : + Op; + +def Async_AwaitOp : Async_Op<"await"> { + let summary = "waits for the argument to become ready"; + let description = [{ + The `async.await` operation yields the caller until its argument + becomes ready, and for the `async.value` arguments it unwraps + the underlying value + + Example: + + ```mlir + %0 = ... : !async.token + async.await %0 : !async.token + + %1 = ... : !async.value + %2 = async.await %1 : !async.value + ``` + }]; + + let arguments = (ins Async_AnyValueOrTokenType:$operand); + let results = (outs Optional:$result); + + let skipDefaultBuilders = 1; + let hasVerifier = 1; + + let builders = [ + OpBuilder<(ins "mlir::Value":$operand, + CArg<"mlir::ArrayRef", "{}">:$attrs)>, + ]; + + let extraClassDeclaration = [{ + std::optional getResultType() { + if (getResultTypes().empty()) return std::nullopt; + return getResultTypes()[0]; + } + }]; + + let assemblyFormat = [{ + $operand `:` custom( + type($operand), type($result) + ) attr-dict + }]; +} + +#endif // ASYNC_OPS diff --git a/compiler/src/openxla/compiler/async/Dialect/Async/IR/AsyncTypes.td b/compiler/src/openxla/compiler/async/Dialect/Async/IR/AsyncTypes.td new file mode 100644 index 0000000..e70bd2a --- /dev/null +++ b/compiler/src/openxla/compiler/async/Dialect/Async/IR/AsyncTypes.td @@ -0,0 +1,64 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + + +#ifndef ASYNC_TYPES +#define ASYNC_TYPES + +include "mlir/IR/AttrTypeBase.td" +include "openxla/compiler/async/Dialect/Async/IR/AsyncDialect.td" + +//===----------------------------------------------------------------------===// +// Async Types +//===----------------------------------------------------------------------===// + +class Async_Type : TypeDef { + let mnemonic = typeMnemonic; +} + +def Async_TokenType : Async_Type<"Token", "token"> { + let summary = "async token type"; + let description = [{ + `async.token` is a type returned by asynchronous operations, and it becomes + `available` when the asynchronous operations that created it is completed. + }]; +} + +def Async_ValueType : Async_Type<"Value", "value"> { + let summary = "async value type"; + let description = [{ + `async.value` represents a value returned by asynchronous operations, + which may or may not be available currently, but will be available at some + point in the future. + + `valueType` can be omitted from the type when lowering to runtime function + calls (just a `!async.value`). At runtime, valueType becomes a property of + reference counted runtime values. + }]; + + let parameters = (ins "Type":$valueType); + let builders = [ + TypeBuilder<(ins), [{ + return $_get($_ctxt, Type()); + }]>, + TypeBuilderWithInferredContext<(ins "Type":$valueType), [{ + return $_get(valueType.getContext(), valueType); + }]> + ]; + + let extraClassDeclaration = [{ + bool isOpaque(); + }]; + + let hasCustomAssemblyFormat = 1; + let skipDefaultBuilders = 1; +} + +def Async_AnyValueOrTokenType : AnyTypeOf<[Async_ValueType, + Async_TokenType]>; + +#endif // ASYNC_TYPES diff --git a/compiler/src/openxla/compiler/async/Dialect/Async/IR/BUILD.bazel b/compiler/src/openxla/compiler/async/Dialect/Async/IR/BUILD.bazel new file mode 100644 index 0000000..9d3da4b --- /dev/null +++ b/compiler/src/openxla/compiler/async/Dialect/Async/IR/BUILD.bazel @@ -0,0 +1,106 @@ +# Copyright 2023 The OpenXLA Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("@iree_core//build_tools/bazel:build_defs.oss.bzl", "iree_cc_library", "iree_gentbl_cc_library", "iree_tablegen_doc", "iree_td_library") +load("@iree_core//build_tools/bazel:enforce_glob.bzl", "enforce_glob") + + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_td_library( + name = "td_files", + srcs = enforce_glob( + [ + "AsyncDialect.td", + "AsyncOps.td", + "AsyncTypes.td", + ], + include = ["*.td"], + ), + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + +cc_library( + name = "IR", + srcs = [ + "Async.cpp", + ], + hdrs = [ + "Async.h", + ], + textual_hdrs = [ + "AsyncOps.h.inc", + "AsyncOps.cpp.inc", + "AsyncDialect.cpp.inc", + "AsyncDialect.h.inc", + "AsyncTypes.cpp.inc", + "AsyncTypes.h.inc", + ], + deps = [ + ":AsyncOpsGen", + ":AsyncTypesGen", + "//compiler/src/openxla/compiler/async:defs", + "@iree_core//compiler/src/iree/compiler/Dialect/Util/IR", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:FuncDialect", + ], +) + +iree_gentbl_cc_library( + name = "AsyncOpsGen", + tbl_outs = [ + ( + [ + "--gen-dialect-decls", + "--dialect=async", + ], + "AsyncDialect.h.inc", + ), + ( + [ + "--gen-dialect-defs", + "--dialect=async", + ], + "AsyncDialect.cpp.inc", + ), + ( + ["--gen-op-decls"], + "AsyncOps.h.inc", + ), + ( + ["--gen-op-defs"], + "AsyncOps.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "AsyncOps.td", + deps = [":td_files"], +) + +iree_gentbl_cc_library( + name = "AsyncTypesGen", + tbl_outs = [ + ( + ["--gen-typedef-decls"], + "AsyncTypes.h.inc", + ), + ( + ["--gen-typedef-defs"], + "AsyncTypes.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "AsyncTypes.td", + deps = [":td_files"], +) diff --git a/compiler/src/openxla/compiler/async/Dialect/Async/IR/CMakeLists.txt b/compiler/src/openxla/compiler/async/Dialect/Async/IR/CMakeLists.txt new file mode 100644 index 0000000..d600d4f --- /dev/null +++ b/compiler/src/openxla/compiler/async/Dialect/Async/IR/CMakeLists.txt @@ -0,0 +1,61 @@ +################################################################################ +# Autogenerated by ../iree/build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/openxla/compiler/async/Dialect/Async/IR/BUILD.bazel # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_cc_library( + NAME + IR + HDRS + "Async.h" + TEXTUAL_HDRS + "AsyncDialect.cpp.inc" + "AsyncDialect.h.inc" + "AsyncOps.cpp.inc" + "AsyncOps.h.inc" + "AsyncTypes.cpp.inc" + "AsyncTypes.h.inc" + SRCS + "Async.cpp" + DEPS + ::AsyncOpsGen + ::AsyncTypesGen + LLVMSupport + MLIRFuncDialect + MLIRIR + MLIRSupport + iree::compiler::Dialect::Util::IR + openxla::compiler::async::defs + PUBLIC +) + +iree_tablegen_library( + NAME + AsyncOpsGen + TD_FILE + "AsyncOps.td" + OUTS + --gen-dialect-decls --dialect=async AsyncDialect.h.inc + --gen-dialect-defs --dialect=async AsyncDialect.cpp.inc + --gen-op-decls AsyncOps.h.inc + --gen-op-defs AsyncOps.cpp.inc +) + +iree_tablegen_library( + NAME + AsyncTypesGen + TD_FILE + "AsyncTypes.td" + OUTS + --gen-typedef-decls AsyncTypes.h.inc + --gen-typedef-defs AsyncTypes.cpp.inc +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/compiler/src/openxla/compiler/async/Dialect/Async/IR/test/BUILD.bazel b/compiler/src/openxla/compiler/async/Dialect/Async/IR/test/BUILD.bazel new file mode 100644 index 0000000..43773ce --- /dev/null +++ b/compiler/src/openxla/compiler/async/Dialect/Async/IR/test/BUILD.bazel @@ -0,0 +1,29 @@ +# Copyright 2023 The OpenXLA Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") + +package( + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_lit_test_suite( + name = "lit", + srcs = enforce_glob( + [ + "async_ops.mlir", + "async_verify.mlir", + ], + include = ["*.mlir"], + ), + cfg = "//compiler:lit.cfg.py", + tools = [ + "@iree_core//tools:iree-opt", + "@llvm-project//llvm:FileCheck", + ], +) diff --git a/compiler/src/openxla/compiler/async/Dialect/Async/IR/test/CMakeLists.txt b/compiler/src/openxla/compiler/async/Dialect/Async/IR/test/CMakeLists.txt new file mode 100644 index 0000000..aea9599 --- /dev/null +++ b/compiler/src/openxla/compiler/async/Dialect/Async/IR/test/CMakeLists.txt @@ -0,0 +1,24 @@ +################################################################################ +# Autogenerated by ../iree/build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/openxla/compiler/async/Dialect/Async/IR/test/BUILD.bazel # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_lit_test_suite( + NAME + lit + SRCS + "async_ops.mlir" + "async_verify.mlir" + TOOLS + FileCheck + iree-opt +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/compiler/src/openxla/compiler/async/Dialect/Async/IR/test/async_ops.mlir b/compiler/src/openxla/compiler/async/Dialect/Async/IR/test/async_ops.mlir new file mode 100644 index 0000000..ca2622f --- /dev/null +++ b/compiler/src/openxla/compiler/async/Dialect/Async/IR/test/async_ops.mlir @@ -0,0 +1,34 @@ +// RUN: iree-opt --iree-plugin=openxla-async --split-input-file %s \ +// RUN: | FileCheck %s + +// CHECK-LABEL: @identity_token +func.func @identity_token(%arg0: !async.token) -> !async.token { + // CHECK: return %arg0 : !async.token + return %arg0 : !async.token +} + +// ----- + +// CHECK-LABEL: @identity_value +func.func @identity_value(%arg0 : !async.value) -> !async.value { + // CHECK: return %arg0 : !async.value + return %arg0 : !async.value +} + +// ----- + +// CHECK-LABEL: @await_token +func.func @await_token(%arg0: !async.token) { + // CHECK: async.await %arg0 + async.await %arg0 : !async.token + return +} + +// ----- + +// CHECK-LABEL: @await_value +func.func @await_value(%arg0: !async.value) -> f32 { + // CHECK: async.await %arg0 + %0 = async.await %arg0 : !async.value + return %0 : f32 +} diff --git a/compiler/src/openxla/compiler/async/Dialect/Async/IR/test/async_verify.mlir b/compiler/src/openxla/compiler/async/Dialect/Async/IR/test/async_verify.mlir new file mode 100644 index 0000000..eef171d --- /dev/null +++ b/compiler/src/openxla/compiler/async/Dialect/Async/IR/test/async_verify.mlir @@ -0,0 +1,21 @@ +// RUN: iree-opt %s --iree-plugin=openxla-async -split-input-file -verify-diagnostics | FileCheck %s + +// FileCheck test must have at least one CHECK statement. +// CHECK-LABEL: @no_op +func.func @no_op(%arg0: !async.token) { + return +} + +// ----- + +func.func @wrong_async_await_arg_type(%arg0: f32) { + // expected-error @+1 {{'async.await' op operand #0 must be async value type or async token type, but got 'f32'}} + async.await %arg0 : f32 +} + +// ----- + +func.func @wrong_async_await_result_type(%arg0: !async.value) { + // expected-error @+1 {{'async.await' op result type 'f64' does not match async value type 'f32'}} + %0 = "async.await"(%arg0): (!async.value) -> f64 +} diff --git a/compiler/src/openxla/compiler/async/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/compiler/src/openxla/compiler/async/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp new file mode 100644 index 0000000..8f19406 --- /dev/null +++ b/compiler/src/openxla/compiler/async/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp @@ -0,0 +1,75 @@ +// Copyright 2023 The OpenXLA Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Util/IR/UtilDialect.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Builders.h" +#include "openxla/compiler/async/Dialect/Async/Conversion/ConvertAsyncToRuntime.h" +#include "openxla/compiler/async/Dialect/Async/IR/Async.h" +#include "openxla/compiler/async/Dialect/Async/Transforms/Passes.h" + +#define GEN_PASS_DEF_ASYNCTOASYNCRUNTIME +#include "openxla/compiler/async/Dialect/Async/Transforms/Passes.h.inc" + +using namespace mlir; +using namespace mlir::iree_compiler; + +namespace openxla::compiler::async { + +namespace { + +class AsyncToAsyncRuntimePass + : public ::impl::AsyncToAsyncRuntimeBase { + public: + AsyncToAsyncRuntimePass() = default; + void runOnOperation() override; +}; + +} // namespace + +//===----------------------------------------------------------------------===// +void AsyncToAsyncRuntimePass::runOnOperation() { + if (getOperation().getBody()->empty()) return; + + auto *context = &getContext(); + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + typeConverter.addConversion( + [](TokenType token) { return ValueType::get(token.getContext()); }); + typeConverter.addConversion( + [](ValueType value) { return ValueType::get(value.getContext()); }); + + // Ensure all async dialect operations go away. + ConversionTarget conversionTarget(*context); + conversionTarget.addIllegalDialect(); + conversionTarget + .addLegalDialect(); + conversionTarget.addDynamicallyLegalOp([&](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()); + }); + conversionTarget.addLegalDialect(); + conversionTarget.addLegalDialect(); + + RewritePatternSet patterns(context); + populateAsyncToRuntimePatterns(typeConverter, patterns); + populateFunctionOpInterfaceTypeConversionPattern(patterns, + typeConverter); + + if (failed(applyPartialConversion(getOperation(), conversionTarget, + std::move(patterns)))) { + getOperation().emitError() << "conversion from async to runtime failed"; + return signalPassFailure(); + } +} + +std::unique_ptr> createAsyncToAsyncRuntimePass() { + return std::make_unique(); +} + +} // namespace openxla::compiler::async diff --git a/compiler/src/openxla/compiler/async/Dialect/Async/Transforms/BUILD.bazel b/compiler/src/openxla/compiler/async/Dialect/Async/Transforms/BUILD.bazel new file mode 100644 index 0000000..f50df33 --- /dev/null +++ b/compiler/src/openxla/compiler/async/Dialect/Async/Transforms/BUILD.bazel @@ -0,0 +1,65 @@ +# Copyright 2023 The OpenXLA Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load( + "//build_tools/bazel:build_defs.oss.bzl", + "iree_gentbl_cc_library", + "iree_td_library", +) + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_td_library( + name = "td_files", + srcs = [ + "Passes.td", + ], + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + +cc_library( + name = "Transforms", + srcs = [ + "AsyncToAsyncRuntime.cpp", + ], + hdrs = [ + "Passes.h", + "Passes.h.inc", + ], + deps = [ + ":PassesIncGen", + "//compiler/src/openxla/compiler/async:defs", + "//compiler/src/openxla/compiler/async/Dialect/Async/IR", + "//compiler/src/openxla/compiler/async/Dialect/Async/Conversion:AsyncToRuntime", + "@iree_core//compiler/src/iree/compiler/Dialect/Util/IR", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TransformDialect", + ], +) + +iree_gentbl_cc_library( + name = "PassesIncGen", + tbl_outs = [ + ( + ["--gen-pass-decls"], + "Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Passes.td", + deps = [ + ":td_files", + "@llvm-project//mlir:PassBaseTdFiles", + ], +) diff --git a/compiler/src/openxla/compiler/async/Dialect/Async/Transforms/CMakeLists.txt b/compiler/src/openxla/compiler/async/Dialect/Async/Transforms/CMakeLists.txt new file mode 100644 index 0000000..bce0e4e --- /dev/null +++ b/compiler/src/openxla/compiler/async/Dialect/Async/Transforms/CMakeLists.txt @@ -0,0 +1,43 @@ +################################################################################ +# Autogenerated by ../iree/build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/openxla/compiler/async/Dialect/Async/Transforms/BUILD.bazel # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_cc_library( + NAME + Transforms + HDRS + "Passes.h" + "Passes.h.inc" + SRCS + "AsyncToAsyncRuntime.cpp" + DEPS + ::PassesIncGen + LLVMSupport + MLIRIR + MLIRPass + MLIRTransformDialect + iree::compiler::Dialect::Util::IR + openxla::compiler::async::Dialect::Async::Conversion::AsyncToRuntime + openxla::compiler::async::Dialect::Async::IR + openxla::compiler::async::defs + PUBLIC +) + +iree_tablegen_library( + NAME + PassesIncGen + TD_FILE + "Passes.td" + OUTS + --gen-pass-decls Passes.h.inc +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/compiler/src/openxla/compiler/async/Dialect/Async/Transforms/Passes.h b/compiler/src/openxla/compiler/async/Dialect/Async/Transforms/Passes.h new file mode 100644 index 0000000..e278e60 --- /dev/null +++ b/compiler/src/openxla/compiler/async/Dialect/Async/Transforms/Passes.h @@ -0,0 +1,25 @@ +// Copyright 2023 The OpenXLA Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===----------------------------------------------------------------------===// +// +// This header file defines prototypes that expose pass constructors. +// +//===----------------------------------------------------------------------===// + +#ifndef OPENXLA_ASYNC_TRANSFORMS_PASSES_H_ +#define OPENXLA_ASYNC_TRANSFORMS_PASSES_H_ + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" + +namespace openxla::compiler::async { +std::unique_ptr> +createAsyncToAsyncRuntimePass(); +} // namespace openxla::compiler::async + +#endif // OPENXLA_ASYNC_TRANSFORMS_PASSES_H_ diff --git a/compiler/src/openxla/compiler/async/Dialect/Async/Transforms/Passes.td b/compiler/src/openxla/compiler/async/Dialect/Async/Transforms/Passes.td new file mode 100644 index 0000000..b64330a --- /dev/null +++ b/compiler/src/openxla/compiler/async/Dialect/Async/Transforms/Passes.td @@ -0,0 +1,20 @@ +// Copyright 2023 The OpenXLA Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef OPENXLA_ASYNC_TRANSFORMS_PASSES_TD_ +#define OPENXLA_ASYNC_TRANSFORMS_PASSES_TD_ + +include "mlir/Pass/PassBase.td" + +def AsyncToAsyncRuntime : Pass<"async-to-async-runtime", "mlir::ModuleOp"> { + let summary = "Lower all high level async operations to" + "the explicit async.runtime operations"; + let constructor = "openxla::compiler::async::createAsyncToAsyncRuntimePass()"; + let dependentDialects = ["::openxla::compiler::async::AsyncDialect", "mlir::func::FuncDialect", + "mlir::iree_compiler::IREE::Util::UtilDialect"]; +} + +#endif // OPENXLA_ASYNC_TRANSFORMS_PASSES_TD_ diff --git a/compiler/src/openxla/compiler/async/Dialect/Async/Transforms/test/BUILD.bazel b/compiler/src/openxla/compiler/async/Dialect/Async/Transforms/test/BUILD.bazel new file mode 100644 index 0000000..7f40815 --- /dev/null +++ b/compiler/src/openxla/compiler/async/Dialect/Async/Transforms/test/BUILD.bazel @@ -0,0 +1,29 @@ +# Copyright 2023 The OpenXLA Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") + +package( + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_lit_test_suite( + name = "lit", + srcs = enforce_glob( + [ + "async_to_runtime.mlir", + ], + include = ["*.mlir"], + ), + cfg = "//compiler:lit.cfg.py", + labels = ["hostonly"], + tools = [ + "@iree_core//tools:iree-opt", + "@llvm-project//llvm:FileCheck", + ], +) diff --git a/compiler/src/openxla/compiler/async/Dialect/Async/Transforms/test/CMakeLists.txt b/compiler/src/openxla/compiler/async/Dialect/Async/Transforms/test/CMakeLists.txt new file mode 100644 index 0000000..d922e8c --- /dev/null +++ b/compiler/src/openxla/compiler/async/Dialect/Async/Transforms/test/CMakeLists.txt @@ -0,0 +1,23 @@ +################################################################################ +# Autogenerated by ../iree/build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/openxla/compiler/async/Dialect/Async/Transforms/test/BUILD.bazel# +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_lit_test_suite( + NAME + lit + SRCS + "async_to_runtime.mlir" + TOOLS + FileCheck + iree-opt +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/compiler/src/openxla/compiler/async/Dialect/Async/Transforms/test/async_to_runtime.mlir b/compiler/src/openxla/compiler/async/Dialect/Async/Transforms/test/async_to_runtime.mlir new file mode 100644 index 0000000..40f24b3 --- /dev/null +++ b/compiler/src/openxla/compiler/async/Dialect/Async/Transforms/test/async_to_runtime.mlir @@ -0,0 +1,50 @@ +// RUN: iree-opt %s --iree-plugin=openxla-async --split-input-file \ +// RUN: --async-to-async-runtime | \ +// RUN: FileCheck %s + +func.func @await_token(%arg0: !async.token){ + async.await %arg0 : !async.token + return +} + +// CHECK: func.func @await_token(%[[ARG:.*]]: !async.value) { +// CHECK: call @async.value.await(%[[ARG]]) : (!async.value) -> () +// CHECK: %[[R0:.*]] = call @async.value.query(%[[ARG]]) : (!async.value) -> i32 +// CHECK: util.status.check_ok %[[R0:.*]] +// CHECK: return +// CHECK: } +// CHECK: func.func private @async.value.await(!async.value) + +// ----- + +func.func @await_scalar_value(%arg0: !async.value) -> i32 { + %0 = async.await %arg0 : !async.value + return %0 : i32 +} + +// CHECK: func.func @await_scalar_value(%[[ARG:.*]]: !async.value) -> i32 { +// CHECK: call @async.value.await(%[[ARG]]) : (!async.value) -> () +// CHECK: %[[R0:.*]] = call @async.value.query(%[[ARG]]) : (!async.value) -> i32 +// CHECK: util.status.check_ok %[[R0:.*]] +// CHECK: %[[R1:.*]] = call @async.value.load.i32(%[[ARG]]) : (!async.value) -> i32 +// CHECK: return %[[R1]] : i32 +// CHECK: } +// CHECK: func.func private @async.value.await(!async.value) +// CHECK: func.func private @async.value.query(!async.value) -> i32 +// CHECK: func.func private @async.value.load.i32(!async.value) -> i32 + +// ----- + +func.func @await_memref_value(%arg0: !async.value>) -> memref<2xi32> { + %0 = async.await %arg0 : !async.value> + return %0 : memref<2xi32> +} + +// CHECK: func.func @await_memref_value(%[[ARG:.*]]: !async.value) +// CHECK-SAME: -> memref<2xi32> { +// CHECK: call @async.value.await(%[[ARG]]) : (!async.value) -> () +// CHECK: %[[R0:.*]] = call @async.value.load.ref(%[[ARG]]) : (!async.value) -> !util.object +// CHECK: %[[R1:.*]] = util.cast %[[R0]] : !util.object to memref<2xi32> +// CHECK: return %[[R1]] : memref<2xi32> +// CHECK: } +// CHECK: func.func private @async.value.load.ref(!async.value) -> !util.object diff --git a/compiler/src/openxla/compiler/async/Dialect/BUILD.bazel b/compiler/src/openxla/compiler/async/Dialect/BUILD.bazel new file mode 100644 index 0000000..e69de29 diff --git a/compiler/src/openxla/compiler/async/Dialect/CMakeLists.txt b/compiler/src/openxla/compiler/async/Dialect/CMakeLists.txt new file mode 100644 index 0000000..1066f60 --- /dev/null +++ b/compiler/src/openxla/compiler/async/Dialect/CMakeLists.txt @@ -0,0 +1,13 @@ +################################################################################ +# Autogenerated by ../iree/build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/openxla/compiler/async/Dialect/BUILD.bazel # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/compiler/src/openxla/compiler/async/PluginRegistration.cpp b/compiler/src/openxla/compiler/async/PluginRegistration.cpp new file mode 100644 index 0000000..6480c24 --- /dev/null +++ b/compiler/src/openxla/compiler/async/PluginRegistration.cpp @@ -0,0 +1,53 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/PluginAPI/Client.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/Pass.h" +#include "openxla/compiler/async/Dialect/Async/IR/Async.h" +#include "openxla/compiler/async/Dialect/Async/Transforms/Passes.h" + +using namespace mlir; +using namespace mlir::iree_compiler; + +namespace detail { +namespace { + +#define GEN_PASS_REGISTRATION +#include "openxla/compiler/async/Dialect/Async/Transforms/Passes.h.inc" + +} // namespace +} // namespace detail + +namespace { + +struct AsyncOptions { + void bindOptions(OptionsBinder &binder) {} +}; + +struct AsyncSession : public PluginSession { + static void registerPasses() { ::detail::registerPasses(); } + + void onRegisterDialects(DialectRegistry ®istry) override { + registry.insert(); + } + + void extendPreprocessingPassPipeline(OpPassManager &pm) override { + pm.addPass(openxla::compiler::async::createAsyncToAsyncRuntimePass()); + } +}; + +} // namespace + +IREE_DEFINE_COMPILER_OPTION_FLAGS(AsyncOptions); + +extern "C" bool iree_register_compiler_plugin_openxla_async( + mlir::iree_compiler::PluginRegistrar *registrar) { + registrar->registerPlugin("openxla-async"); + return true; +}