From ac9548469e8e46f8ad6ea29c1c7144a774ebba19 Mon Sep 17 00:00:00 2001 From: harsh-nod Date: Wed, 15 Nov 2023 21:17:04 -0800 Subject: [PATCH] Add vector_ext dialect (#15599) This patch introduces the vector_ext dialect. The purpose of this dialect is to have a place for experimenting with things beyond what the upstream vector dialect provides. In this particular PR, two new features are added 1. An explicit IR representation of the high dimensional layout that looks like this #iree_vector_ext.per_dim_layout<"BatchX"<"LaneX"<"VecY", 2>, 4>, 4> The nesting makes clear what the innermost dimensions are and their corresponding shapes. 2. Adds a layout conflict resolution operator. During layout analysis, this operator can be used to resolve any differences in layout. The lowering of this operator is not provided but the semantics are that given a vector with an existing layout and a desired layout, the operator transforms the vector to the desired layout. --- compiler/src/iree/compiler/Tools/BUILD.bazel | 1 + .../src/iree/compiler/Tools/CMakeLists.txt | 1 + .../iree/compiler/Tools/init_iree_dialects.h | 2 + .../iree-dialects/BUILD.bazel | 81 +++++++++++++++++ .../iree-dialects/Dialect/CMakeLists.txt | 1 + .../Dialect/VectorExt/CMakeLists.txt | 2 + .../Dialect/VectorExt/IR/CMakeLists.txt | 36 ++++++++ .../Dialect/VectorExt/IR/VectorExtBase.td | 72 +++++++++++++++ .../Dialect/VectorExt/IR/VectorExtDialect.h | 17 ++++ .../Dialect/VectorExt/IR/VectorExtOps.h | 30 +++++++ .../Dialect/VectorExt/IR/VectorExtOps.td | 47 ++++++++++ .../iree-dialects/lib/Dialect/CMakeLists.txt | 1 + .../lib/Dialect/VectorExt/CMakeLists.txt | 2 + .../lib/Dialect/VectorExt/IR/CMakeLists.txt | 15 ++++ .../Dialect/VectorExt/IR/VectorExtDialect.cpp | 90 +++++++++++++++++++ .../lib/Dialect/VectorExt/IR/VectorExtOps.cpp | 50 +++++++++++ .../test/Dialect/iree_vector_ext/invalid.mlir | 31 +++++++ .../Dialect/iree_vector_ext/roundtrip.mlir | 22 +++++ .../tools/iree-dialects-opt/CMakeLists.txt | 1 + .../iree-dialects-opt/iree-dialects-opt.cpp | 2 + 20 files changed, 504 insertions(+) create mode 100644 llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/CMakeLists.txt create mode 100644 llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/CMakeLists.txt create mode 100644 llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtBase.td create mode 100644 llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h create mode 100644 llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h create mode 100644 llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.td create mode 100644 llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/CMakeLists.txt create mode 100644 llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/CMakeLists.txt create mode 100644 llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp create mode 100644 llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp create mode 100644 llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/invalid.mlir create mode 100644 llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir diff --git a/compiler/src/iree/compiler/Tools/BUILD.bazel b/compiler/src/iree/compiler/Tools/BUILD.bazel index 297fd7fe52d8..ec3db192b01e 100644 --- a/compiler/src/iree/compiler/Tools/BUILD.bazel +++ b/compiler/src/iree/compiler/Tools/BUILD.bazel @@ -66,6 +66,7 @@ iree_compiler_cc_library( "//llvm-external-projects/iree-dialects:IREELinalgExtPasses", "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "//llvm-external-projects/iree-dialects:IREELinalgTransformDialectPasses", + "//llvm-external-projects/iree-dialects:IREEVectorExtDialect", "@llvm-project//mlir:IR", ], ) diff --git a/compiler/src/iree/compiler/Tools/CMakeLists.txt b/compiler/src/iree/compiler/Tools/CMakeLists.txt index 475a06f53b44..d7259334bf3a 100644 --- a/compiler/src/iree/compiler/Tools/CMakeLists.txt +++ b/compiler/src/iree/compiler/Tools/CMakeLists.txt @@ -53,6 +53,7 @@ iree_cc_library( IREELinalgExtTransforms IREELinalgTransformDialect IREELinalgTransformDialectPasses + IREEVectorExtDialect MLIRIR iree::compiler::Bindings::Native::Transforms iree::compiler::Bindings::TFLite::Transforms diff --git a/compiler/src/iree/compiler/Tools/init_iree_dialects.h b/compiler/src/iree/compiler/Tools/init_iree_dialects.h index 7d19e14736d2..b454421fb329 100644 --- a/compiler/src/iree/compiler/Tools/init_iree_dialects.h +++ b/compiler/src/iree/compiler/Tools/init_iree_dialects.h @@ -16,6 +16,7 @@ #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h" #include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h" #include "iree-dialects/Dialect/LinalgTransform/Passes.h" +#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h" #include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h" #include "iree/compiler/Codegen/Interfaces/Interfaces.h" #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" @@ -49,6 +50,7 @@ inline void registerIreeDialects(DialectRegistry ®istry) { IREE::Util::UtilDialect, IREE::VM::VMDialect, IREE::VMVX::VMVXDialect, + IREE::VectorExt::IREEVectorExtDialect, IREE::Vulkan::VulkanDialect>(); // clang-format on diff --git a/llvm-external-projects/iree-dialects/BUILD.bazel b/llvm-external-projects/iree-dialects/BUILD.bazel index 68c5395c9f03..5cba3b0ba0cf 100644 --- a/llvm-external-projects/iree-dialects/BUILD.bazel +++ b/llvm-external-projects/iree-dialects/BUILD.bazel @@ -32,6 +32,7 @@ filegroup( "include/iree-dialects/Dialect/Input/*.td", "include/iree-dialects/Dialect/LinalgExt/IR/*.td", "include/iree-dialects/Dialect/LinalgExt/Passes/*.td", + "include/iree-dialects/Dialect/VectorExt/IR/*.td", ]), ) @@ -42,6 +43,7 @@ td_library( "include/iree-dialects/Dialect/LinalgExt/IR/*.td", "include/iree-dialects/Dialect/LinalgExt/Passes/*.td", "include/iree-dialects/Dialect/LinalgTransform/*.td", + "include/iree-dialects/Dialect/VectorExt/IR/*.td", "python/iree/compiler/dialects/*.td", ]), includes = ["include"], @@ -618,6 +620,84 @@ cc_library( ], ) +################################################################################ +# IREEVectorExt Dialect +################################################################################ + +gentbl_cc_library( + name = "IREEVectorExtIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + [ + "--dialect=iree_vector_ext", + "--gen-dialect-decls", + ], + "include/iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h.inc", + ), + ( + [ + "--dialect=iree_vector_ext", + "--gen-dialect-defs", + ], + "include/iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.cpp.inc", + ), + ( + ["--gen-attrdef-decls"], + "include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/iree-dialects/Dialect/VectorExt/IR/VectorExtEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/iree-dialects/Dialect/VectorExt/IR/VectorExtEnums.cpp.inc", + ), + ( + ["--gen-op-decls"], + "include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h.inc", + ), + ( + ["--gen-op-defs"], + "include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.cpp.inc", + ), + ( + ["--gen-typedef-decls"], + "include/iree-dialects/Dialect/VectorExt/IR/VectorExtTypes.h.inc", + ), + ( + ["--gen-typedef-defs"], + "include/iree-dialects/Dialect/VectorExt/IR/VectorExtTypes.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.td", + deps = [ + ":TdFiles", + ], +) + +cc_library( + name = "IREEVectorExtDialect", + srcs = glob([ + "lib/Dialect/VectorExt/IR/*.cpp", + ]), + hdrs = glob([ + "include/iree-dialects/Dialect/VectorExt/IR/*.h", + ]), + includes = ["include"], + deps = [ + ":IREEVectorExtIncGen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + ################################################################################ # CAPI ################################################################################ @@ -681,6 +761,7 @@ cc_binary( ":IREELinalgExtTransformOps", ":IREELinalgTransformDialect", ":IREELinalgTransformDialectPasses", + ":IREEVectorExtDialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:ArithDialect", diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt index 16d52d437fde..18881bd9e2d4 100644 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(Input) add_subdirectory(LinalgExt) add_subdirectory(LinalgTransform) +add_subdirectory(VectorExt) diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/CMakeLists.txt new file mode 100644 index 000000000000..9ba3d84620ba --- /dev/null +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) + diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/CMakeLists.txt new file mode 100644 index 000000000000..0b29b25dda5e --- /dev/null +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/CMakeLists.txt @@ -0,0 +1,36 @@ +function(_add_dialect) + set(LLVM_TARGET_DEFINITIONS VectorExtOps.td) + mlir_tablegen(VectorExtAttrs.h.inc -gen-attrdef-decls) + mlir_tablegen(VectorExtAttrs.cpp.inc -gen-attrdef-defs) + mlir_tablegen(VectorExtEnums.h.inc -gen-enum-decls) + mlir_tablegen(VectorExtEnums.cpp.inc -gen-enum-defs) + mlir_tablegen(VectorExtOps.h.inc -gen-op-decls) + mlir_tablegen(VectorExtOps.cpp.inc -gen-op-defs) + mlir_tablegen(VectorExtTypes.h.inc -gen-typedef-decls) + mlir_tablegen(VectorExtTypes.cpp.inc -gen-typedef-defs) + mlir_tablegen(VectorExtDialect.h.inc --gen-dialect-decls --dialect=iree_vector_ext) + mlir_tablegen(VectorExtDialect.cpp.inc --gen-dialect-defs --dialect=iree_vector_ext) + add_public_tablegen_target(IREEVectorExtIncGen) + add_dependencies(mlir-headers IREEVectorExtIncGen) +endfunction() + +function(_add_doc) + set(LLVM_TARGET_DEFINITIONS VectorExtOps.td) + set(_FLAGS + "--strip-prefix=::mlir::iree_compiler::IREE::" + ) + mlir_tablegen(VectorExtOps.md -gen-dialect-doc ${_FLAGS}) + set(GEN_DOC_FILE ${IREE_DIALECTS_BINARY_DIR}/docs/Dialects/VectorExtOps.md) + add_custom_command( + OUTPUT ${GEN_DOC_FILE} + COMMAND ${CMAKE_COMMAND} -E copy + ${CMAKE_CURRENT_BINARY_DIR}/VectorExtOps.md + ${GEN_DOC_FILE} + DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/VectorExtOps.md) + add_custom_target(VectorExtOpsDocGen DEPENDS ${GEN_DOC_FILE}) + add_dependencies(iree-dialects-doc VectorExtOpsDocGen) +endfunction() + +_add_dialect() +_add_doc() + diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtBase.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtBase.td new file mode 100644 index 000000000000..1de6d817827f --- /dev/null +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtBase.td @@ -0,0 +1,72 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_DIALECT_VECTOREXT_BASE +#define IREE_DIALECT_VECTOREXT_BASE + +include "mlir/IR/OpAsmInterface.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/EnumAttr.td" + +//===----------------------------------------------------------------------===// +// Dialect definition +//===----------------------------------------------------------------------===// + +def IREEVectorExt_Dialect : Dialect { + let name = "iree_vector_ext"; + let cppNamespace = "::mlir::iree_compiler::IREE::VectorExt"; + let summary = [{ + IREE Vector Extensions. + }]; + let description = [{ + A dialect designed for experimenting with vector operations + beyond what is currently available in the Vector Dialect. + }]; + let useDefaultAttributePrinterParser = 1; +} + +//===---------------------------------------------------------------------===// +// Vector layout attributes +//===---------------------------------------------------------------------===// + +class IREEVectorExt_Attr traits = []> + : AttrDef; + +def PerDimLayoutAttr : IREEVectorExt_Attr<"PerDimLayout"> { + let mnemonic = "per_dim_layout"; + let summary = [{high-dimensional vector register layout for a given vector dimension}]; + let description = [{ + This attribute describes the per dimension register layout for a given vector + that could be prescribed by an operator such as matrix multiplication. + This is a way to explicitly represent the layout in the IR + when it is in the SIMD form prior to converting to the SIMT form so that + we can reason about layouts, propagating layouts and layout conflicts. + }]; + let parameters = (ins + ArrayRefParameter<"std::string", "labels for the high dimensional layout dims">:$labels, + ArrayRefParameter<"int64_t", "shapes for the high dimensional layout dims">:$shapes + ); + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 0; +} + +def LayoutAttr : IREEVectorExt_Attr<"Layout"> { + let mnemonic = "layout"; + let summary = [{high-dimensional vector register layout for a given vector}]; + let description = [{ + This contains a complete specification of the layout for a given vector, + whereas the attribute above only specifies the per dimension layout. + }]; + let parameters = (ins + ArrayRefParameter<"PerDimLayoutAttr", "layout for each dimension of the vector">:$layouts + ); + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 0; +} + +#endif // IREE_DIALECT_VECTOREXT_BASE + diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h new file mode 100644 index 000000000000..82bdccc049d1 --- /dev/null +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h @@ -0,0 +1,17 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_DIALECTS_DIALECT_VECTOREXT_IR_VECTOREXTDIALECT_H_ +#define IREE_DIALECTS_DIALECT_VECTOREXT_IR_VECTOREXTDIALECT_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" + +// clang-format off: must be included after all LLVM/MLIR headers +#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h.inc" // IWYU pragma: keep +// clang-format on + +#endif // IREE_DIALECTS_DIALECT_VECTOREXT_IR_VECTOREXTDIALECT_H_ diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h new file mode 100644 index 000000000000..b1f4b6f46a6a --- /dev/null +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h @@ -0,0 +1,30 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_DIALECTS_DIALECT_VECTOREXT_IR_VECTOREXTOPS_H_ +#define IREE_DIALECTS_DIALECT_VECTOREXT_IR_VECTOREXTOPS_H_ + +#include "mlir/Bytecode/BytecodeImplementation.h" +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" + +// clang-format off + +#include "iree-dialects/Dialect/VectorExt/IR/VectorExtEnums.h.inc" // IWYU pragma: export + +#define GET_ATTRDEF_CLASSES +#include "iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.h.inc" // IWYU pragma: export + +#define GET_OP_CLASSES +#include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h.inc" // IWYU pragma: export + +// clang-format on + +#endif // IREE_DIALECTS_DIALECT_VECTOREXT_IR_VECTOREXTOPS_H_ diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.td new file mode 100644 index 000000000000..77476d07c7b0 --- /dev/null +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.td @@ -0,0 +1,47 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_DIALECT_VECTOREXT_OPS +#define IREE_DIALECT_VECTOREXT_OPS + +include "iree-dialects/Dialect/VectorExt/IR/VectorExtBase.td" + +//===----------------------------------------------------------------------===// +// Base class. +//===----------------------------------------------------------------------===// + +class IREEVectorExt_PureOp traits = []> : + Op { +} + +//===----------------------------------------------------------------------===// +// Layout ops. +//===----------------------------------------------------------------------===// + +def IREEVectorExt_LayoutConflictResolutionOp : IREEVectorExt_PureOp<"layout_conflict_resolution"> { + let summary = "Layout Conflict Resolution operator"; + let description = [{ + The layout conflict resolution operator takes a vector and a + desired layout and transforms the vector to one with the + desired layout. + }]; + let arguments = (ins + AnyVector:$input, + LayoutAttr:$sourceLayout, + LayoutAttr:$desiredLayout + ); + let results = (outs + AnyVector:$output + ); + let extraClassDeclaration = [{ + + }]; + let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)"; + let hasVerifier = 1; +} + +#endif // IREE_DIALECT_VECTOREXT_OPS + diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt index 16d52d437fde..18881bd9e2d4 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt +++ b/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(Input) add_subdirectory(LinalgExt) add_subdirectory(LinalgTransform) +add_subdirectory(VectorExt) diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/CMakeLists.txt new file mode 100644 index 000000000000..9ba3d84620ba --- /dev/null +++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) + diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/CMakeLists.txt new file mode 100644 index 000000000000..8f8f9bb0fbf7 --- /dev/null +++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/CMakeLists.txt @@ -0,0 +1,15 @@ +add_mlir_library(IREEVectorExtDialect + VectorExtDialect.cpp + VectorExtOps.cpp + + ADDITIONAL_HEADER_DIRS + ${IREE_DIALECTS_SOURCE_DIR}/include + + DEPENDS + IREEVectorExtIncGen + + LINK_LIBS PUBLIC + MLIRIR +) + +iree_dialects_target_includes(IREEVectorExtDialect) diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp new file mode 100644 index 000000000000..06740ee54e31 --- /dev/null +++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp @@ -0,0 +1,90 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h" +#include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h" +#include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::iree_compiler::IREE::VectorExt; + +#include "iree-dialects/Dialect/VectorExt/IR/VectorExtEnums.cpp.inc" // IWYU pragma: keep + +#define GET_ATTRDEF_CLASSES +#include "iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.cpp.inc" // IWYU pragma: keep + +void IREEVectorExtDialect::initialize() { + + addAttributes< +#define GET_ATTRDEF_LIST +#include "iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.cpp.inc" + >(); + +#define GET_OP_LIST + addOperations< +#include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.cpp.inc" + >(); +} + +#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.cpp.inc" + +// Parses an attribute with syntax +// <"BatchX"<"VecX", 2>, 4> +Attribute PerDimLayoutAttr::parse(AsmParser &parser, Type type) { + SmallVector dimNames; + SmallVector dimShapes; + std::string name; + while (!(parser.parseOptionalLess() || parser.parseOptionalString(&name))) { + dimNames.push_back(name); + } + int64_t dim; + while (!(parser.parseOptionalComma() || parser.parseInteger(dim) || + parser.parseGreater())) { + dimShapes.push_back(dim); + } + std::reverse(dimShapes.begin(), dimShapes.end()); + return PerDimLayoutAttr::get(parser.getContext(), dimNames, dimShapes); +} + +void PerDimLayoutAttr::print(AsmPrinter &printer) const { + for (auto label : getLabels()) + printer << "<" << label; + for (auto shape : llvm::reverse(getShapes())) + printer << ", " << shape << ">"; +} + +// Parses an attribute with syntax +// #layout<<"BatchX"<"VecX", 2>, 4>, <"BatchY"<"VecZ", 4>,2>>> +Attribute LayoutAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess()) + return {}; + SmallVector layout; + PerDimLayoutAttr perDimLayout; + while (!(parser.parseAttribute(perDimLayout, type))) { + layout.push_back(perDimLayout); + if (parser.parseOptionalComma()) + break; + } + if ((parser.parseGreater())) + return {}; + return LayoutAttr::get(parser.getContext(), layout); +} + +static void printArray(AsmPrinter &printer, + ArrayRef layouts) { + printer << "<"; + for (auto layout : llvm::enumerate(layouts)) { + printer << layout.value(); + if (layout.index() < layouts.size() - 1) + printer << ", "; + } + printer << ">"; +} + +void LayoutAttr::print(AsmPrinter &printer) const { + printArray(printer, getLayouts()); +} diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp new file mode 100644 index 000000000000..2d453451c9c4 --- /dev/null +++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp @@ -0,0 +1,50 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h" +#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h" +#include + +using namespace mlir; +using namespace mlir::iree_compiler::IREE::VectorExt; +namespace IREE = mlir::iree_compiler::IREE; + +//===----------------------------------------------------------------------===// +// LayoutConflictResolutionOp +//===----------------------------------------------------------------------===// + +LogicalResult validateLayout(Operation *op, StringRef label, LayoutAttr layout, + ArrayRef inputShape) { + for (auto perDimLayout : llvm::enumerate(layout.getLayouts())) { + ArrayRef shape = perDimLayout.value().getShapes(); + int64_t computedShape = + std::reduce(shape.begin(), shape.end(), 1, std::multiplies()); + int64_t expectedShape = inputShape[perDimLayout.index()]; + if (computedShape != expectedShape) { + return op->emitError("The " + label + + " layout shape does not match the input shape. " + "Expected shape to be ") + << std::to_string(expectedShape) << ", got " + << std::to_string(computedShape); + } + } + return success(); +} + +// Validate that the desired layout has the same shape as the input. +LogicalResult LayoutConflictResolutionOp::verify() { + Operation *op = getOperation(); + ArrayRef inputShape = + cast(getInput().getType()).getShape(); + if (succeeded(validateLayout(op, "source", getSourceLayout(), inputShape))) + return validateLayout(op, "desired", getDesiredLayout(), inputShape); + return failure(); +} + +// clang-format off +#define GET_OP_CLASSES +#include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.cpp.inc" // IWYU pragma: keep +// clang-format: on diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/invalid.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/invalid.mlir new file mode 100644 index 000000000000..79cf6694772f --- /dev/null +++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/invalid.mlir @@ -0,0 +1,31 @@ +// RUN: iree-dialects-opt --split-input-file --verify-diagnostics %s + +#row_layout1 = #iree_vector_ext.per_dim_layout<"BatchX"<"LaneX"<"VecY", 1>, 1>, 1> +#col_layout1 = #iree_vector_ext.per_dim_layout<"BatchY"<"LaneY"<"VecX", 4>, 2>, 4> +#layout1 = #iree_vector_ext.layout<#row_layout1, #col_layout1> +#layout2 = #iree_vector_ext.layout<#col_layout1, #col_layout1> +func.func @invalid_desired_layout(%lhs: memref<32x32xf16>, %rhs: memref<32x32xf16>) -> vector<32x32xf16> { + %cst_0 = arith.constant 0.0 : f16 + %c0 = arith.constant 0 : index + %result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16> + // expected-error @+1 {{The desired layout shape does not match the input shape. Expected shape to be 32, got 1}} + %2 = iree_vector_ext.layout_conflict_resolution %result {desiredLayout = #layout1, sourceLayout = #layout2} : vector<32x32xf16> -> vector<32x32xf16> + return %2 : vector<32x32xf16> +} + +// ----- + +#row_layout1 = #iree_vector_ext.per_dim_layout<"BatchX"<"LaneX"<"VecY", 1>, 1>, 1> +#col_layout1 = #iree_vector_ext.per_dim_layout<"BatchY"<"LaneY"<"VecX", 4>, 2>, 4> +#layout1 = #iree_vector_ext.layout<#row_layout1, #col_layout1> +#layout2 = #iree_vector_ext.layout<#col_layout1, #col_layout1> +func.func @invalid_source_layout(%lhs: memref<32x32xf16>, %rhs: memref<32x32xf16>) -> vector<32x32xf16> { + %cst_0 = arith.constant 0.0 : f16 + %c0 = arith.constant 0 : index + %result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16> + // expected-error @+1 {{The source layout shape does not match the input shape. Expected shape to be 32, got 1}} + %2 = iree_vector_ext.layout_conflict_resolution %result {desiredLayout = #layout2, sourceLayout = #layout1} : vector<32x32xf16> -> vector<32x32xf16> + return %2 : vector<32x32xf16> +} + +// ----- diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir new file mode 100644 index 000000000000..63a0f0ba1b92 --- /dev/null +++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir @@ -0,0 +1,22 @@ +// RUN: iree-dialects-opt --split-input-file %s | FileCheck %s + +#row_layout1 = #iree_vector_ext.per_dim_layout<"BatchX"<"LaneX"<"VecY", 2>, 4>, 4> +#col_layout1 = #iree_vector_ext.per_dim_layout<"BatchY"<"LaneY"<"VecX", 4>, 2>, 4> +#layout1 = #iree_vector_ext.layout<#row_layout1, #col_layout1> +#layout2 = #iree_vector_ext.layout<#col_layout1, #row_layout1> +func.func @specify_layout(%lhs: memref<32x32xf16>, %rhs: memref<32x32xf16>) -> vector<32x32xf16> { + %cst_0 = arith.constant 0.0 : f16 + %c0 = arith.constant 0 : index + %result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16> + %2 = iree_vector_ext.layout_conflict_resolution %result {sourceLayout = #layout1, desiredLayout = #layout2} : vector<32x32xf16> -> vector<32x32xf16> + return %2 : vector<32x32xf16> +} + +// CHECK-LABEL: func.func @specify_layout +// CHECK: iree_vector_ext.layout_conflict_resolution +// CHECK: desiredLayout = #iree_vector_ext.layout<#iree_vector_ext.per_dim_layout, 2>, 4>, +// CHECK-SAME: #iree_vector_ext.per_dim_layout, 4>, 4>> +// CHECK: sourceLayout = #iree_vector_ext.layout<#iree_vector_ext.per_dim_layout, 4>, 4>, +// CHECK-SAME: #iree_vector_ext.per_dim_layout, 2>, 4>> + +// ----- diff --git a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt index 548b60e9273f..5789bc5bef63 100644 --- a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt +++ b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt @@ -8,6 +8,7 @@ set(LIBS IREELinalgTransformDialect IREELinalgTransformDialectPasses IREETransformsTestPasses + IREEVectorExtDialect # Core dialects. MLIRAffineDialect MLIRArithDialect diff --git a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp index b7e73f36fb8b..085c31eca865 100644 --- a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp +++ b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp @@ -10,6 +10,7 @@ #include "iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h" #include "iree-dialects/Dialect/LinalgTransform/Passes.h" #include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h" +#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Async/IR/Async.h" @@ -59,6 +60,7 @@ int main(int argc, char **argv) { // Local dialects mlir::iree_compiler::IREE::Input::IREEInputDialect, mlir::iree_compiler::IREE::LinalgExt::IREELinalgExtDialect, + mlir::iree_compiler::IREE::VectorExt::IREEVectorExtDialect, // Upstream dialects mlir::async::AsyncDialect, mlir::arith::ArithDialect,