Skip to content

Commit

Permalink
Add vector_ext dialect (iree-org#15599)
Browse files Browse the repository at this point in the history
This patch introduces the vector_ext dialect. The purpose of this
dialect is to have a place for experimenting with things beyond what the
upstream vector dialect provides.

In this particular PR, two new features are added
1. An explicit IR representation of the high dimensional layout that
looks like this #iree_vector_ext.per_dim_layout<"BatchX"<"LaneX"<"VecY",
2>, 4>, 4> The nesting makes clear what the innermost dimensions are and
their corresponding shapes.

2. Adds a layout conflict resolution operator. During layout analysis,
this operator can be used to resolve any differences in layout. The
lowering of this operator is not provided but the semantics are that
given a vector with an existing layout and a desired layout, the
operator transforms the vector to the desired layout.
  • Loading branch information
harsh-nod authored Nov 16, 2023
1 parent 561841c commit ac95484
Show file tree
Hide file tree
Showing 20 changed files with 504 additions and 0 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Tools/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ iree_compiler_cc_library(
"//llvm-external-projects/iree-dialects:IREELinalgExtPasses",
"//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
"//llvm-external-projects/iree-dialects:IREELinalgTransformDialectPasses",
"//llvm-external-projects/iree-dialects:IREEVectorExtDialect",
"@llvm-project//mlir:IR",
],
)
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Tools/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ iree_cc_library(
IREELinalgExtTransforms
IREELinalgTransformDialect
IREELinalgTransformDialectPasses
IREEVectorExtDialect
MLIRIR
iree::compiler::Bindings::Native::Transforms
iree::compiler::Bindings::TFLite::Transforms
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Tools/init_iree_dialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
#include "iree-dialects/Dialect/LinalgTransform/Passes.h"
#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h"
#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Interfaces/Interfaces.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
Expand Down Expand Up @@ -49,6 +50,7 @@ inline void registerIreeDialects(DialectRegistry &registry) {
IREE::Util::UtilDialect,
IREE::VM::VMDialect,
IREE::VMVX::VMVXDialect,
IREE::VectorExt::IREEVectorExtDialect,
IREE::Vulkan::VulkanDialect>();
// clang-format on

Expand Down
81 changes: 81 additions & 0 deletions llvm-external-projects/iree-dialects/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ filegroup(
"include/iree-dialects/Dialect/Input/*.td",
"include/iree-dialects/Dialect/LinalgExt/IR/*.td",
"include/iree-dialects/Dialect/LinalgExt/Passes/*.td",
"include/iree-dialects/Dialect/VectorExt/IR/*.td",
]),
)

Expand All @@ -42,6 +43,7 @@ td_library(
"include/iree-dialects/Dialect/LinalgExt/IR/*.td",
"include/iree-dialects/Dialect/LinalgExt/Passes/*.td",
"include/iree-dialects/Dialect/LinalgTransform/*.td",
"include/iree-dialects/Dialect/VectorExt/IR/*.td",
"python/iree/compiler/dialects/*.td",
]),
includes = ["include"],
Expand Down Expand Up @@ -618,6 +620,84 @@ cc_library(
],
)

################################################################################
# IREEVectorExt Dialect
################################################################################

gentbl_cc_library(
name = "IREEVectorExtIncGen",
strip_include_prefix = "include",
tbl_outs = [
(
[
"--dialect=iree_vector_ext",
"--gen-dialect-decls",
],
"include/iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h.inc",
),
(
[
"--dialect=iree_vector_ext",
"--gen-dialect-defs",
],
"include/iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.cpp.inc",
),
(
["--gen-attrdef-decls"],
"include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.h.inc",
),
(
["--gen-attrdef-defs"],
"include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.cpp.inc",
),
(
["--gen-enum-decls"],
"include/iree-dialects/Dialect/VectorExt/IR/VectorExtEnums.h.inc",
),
(
["--gen-enum-defs"],
"include/iree-dialects/Dialect/VectorExt/IR/VectorExtEnums.cpp.inc",
),
(
["--gen-op-decls"],
"include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h.inc",
),
(
["--gen-op-defs"],
"include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.cpp.inc",
),
(
["--gen-typedef-decls"],
"include/iree-dialects/Dialect/VectorExt/IR/VectorExtTypes.h.inc",
),
(
["--gen-typedef-defs"],
"include/iree-dialects/Dialect/VectorExt/IR/VectorExtTypes.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.td",
deps = [
":TdFiles",
],
)

cc_library(
name = "IREEVectorExtDialect",
srcs = glob([
"lib/Dialect/VectorExt/IR/*.cpp",
]),
hdrs = glob([
"include/iree-dialects/Dialect/VectorExt/IR/*.h",
]),
includes = ["include"],
deps = [
":IREEVectorExtIncGen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)

################################################################################
# CAPI
################################################################################
Expand Down Expand Up @@ -681,6 +761,7 @@ cc_binary(
":IREELinalgExtTransformOps",
":IREELinalgTransformDialect",
":IREELinalgTransformDialectPasses",
":IREEVectorExtDialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:ArithDialect",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
add_subdirectory(Input)
add_subdirectory(LinalgExt)
add_subdirectory(LinalgTransform)
add_subdirectory(VectorExt)
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
add_subdirectory(IR)

Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
function(_add_dialect)
set(LLVM_TARGET_DEFINITIONS VectorExtOps.td)
mlir_tablegen(VectorExtAttrs.h.inc -gen-attrdef-decls)
mlir_tablegen(VectorExtAttrs.cpp.inc -gen-attrdef-defs)
mlir_tablegen(VectorExtEnums.h.inc -gen-enum-decls)
mlir_tablegen(VectorExtEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(VectorExtOps.h.inc -gen-op-decls)
mlir_tablegen(VectorExtOps.cpp.inc -gen-op-defs)
mlir_tablegen(VectorExtTypes.h.inc -gen-typedef-decls)
mlir_tablegen(VectorExtTypes.cpp.inc -gen-typedef-defs)
mlir_tablegen(VectorExtDialect.h.inc --gen-dialect-decls --dialect=iree_vector_ext)
mlir_tablegen(VectorExtDialect.cpp.inc --gen-dialect-defs --dialect=iree_vector_ext)
add_public_tablegen_target(IREEVectorExtIncGen)
add_dependencies(mlir-headers IREEVectorExtIncGen)
endfunction()

function(_add_doc)
set(LLVM_TARGET_DEFINITIONS VectorExtOps.td)
set(_FLAGS
"--strip-prefix=::mlir::iree_compiler::IREE::"
)
mlir_tablegen(VectorExtOps.md -gen-dialect-doc ${_FLAGS})
set(GEN_DOC_FILE ${IREE_DIALECTS_BINARY_DIR}/docs/Dialects/VectorExtOps.md)
add_custom_command(
OUTPUT ${GEN_DOC_FILE}
COMMAND ${CMAKE_COMMAND} -E copy
${CMAKE_CURRENT_BINARY_DIR}/VectorExtOps.md
${GEN_DOC_FILE}
DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/VectorExtOps.md)
add_custom_target(VectorExtOpsDocGen DEPENDS ${GEN_DOC_FILE})
add_dependencies(iree-dialects-doc VectorExtOpsDocGen)
endfunction()

_add_dialect()
_add_doc()

Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_DIALECT_VECTOREXT_BASE
#define IREE_DIALECT_VECTOREXT_BASE

include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/EnumAttr.td"

//===----------------------------------------------------------------------===//
// Dialect definition
//===----------------------------------------------------------------------===//

def IREEVectorExt_Dialect : Dialect {
let name = "iree_vector_ext";
let cppNamespace = "::mlir::iree_compiler::IREE::VectorExt";
let summary = [{
IREE Vector Extensions.
}];
let description = [{
A dialect designed for experimenting with vector operations
beyond what is currently available in the Vector Dialect.
}];
let useDefaultAttributePrinterParser = 1;
}

//===---------------------------------------------------------------------===//
// Vector layout attributes
//===---------------------------------------------------------------------===//

class IREEVectorExt_Attr<string name, list<Trait> traits = []>
: AttrDef<IREEVectorExt_Dialect, name, traits>;

def PerDimLayoutAttr : IREEVectorExt_Attr<"PerDimLayout"> {
let mnemonic = "per_dim_layout";
let summary = [{high-dimensional vector register layout for a given vector dimension}];
let description = [{
This attribute describes the per dimension register layout for a given vector
that could be prescribed by an operator such as matrix multiplication.
This is a way to explicitly represent the layout in the IR
when it is in the SIMD form prior to converting to the SIMT form so that
we can reason about layouts, propagating layouts and layout conflicts.
}];
let parameters = (ins
ArrayRefParameter<"std::string", "labels for the high dimensional layout dims">:$labels,
ArrayRefParameter<"int64_t", "shapes for the high dimensional layout dims">:$shapes
);
let hasCustomAssemblyFormat = 1;
let genVerifyDecl = 0;
}

def LayoutAttr : IREEVectorExt_Attr<"Layout"> {
let mnemonic = "layout";
let summary = [{high-dimensional vector register layout for a given vector}];
let description = [{
This contains a complete specification of the layout for a given vector,
whereas the attribute above only specifies the per dimension layout.
}];
let parameters = (ins
ArrayRefParameter<"PerDimLayoutAttr", "layout for each dimension of the vector">:$layouts
);
let hasCustomAssemblyFormat = 1;
let genVerifyDecl = 0;
}

#endif // IREE_DIALECT_VECTOREXT_BASE

Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_DIALECTS_DIALECT_VECTOREXT_IR_VECTOREXTDIALECT_H_
#define IREE_DIALECTS_DIALECT_VECTOREXT_IR_VECTOREXTDIALECT_H_

#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"

// clang-format off: must be included after all LLVM/MLIR headers
#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h.inc" // IWYU pragma: keep
// clang-format on

#endif // IREE_DIALECTS_DIALECT_VECTOREXT_IR_VECTOREXTDIALECT_H_
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_DIALECTS_DIALECT_VECTOREXT_IR_VECTOREXTOPS_H_
#define IREE_DIALECTS_DIALECT_VECTOREXT_IR_VECTOREXTOPS_H_

#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"

// clang-format off

#include "iree-dialects/Dialect/VectorExt/IR/VectorExtEnums.h.inc" // IWYU pragma: export

#define GET_ATTRDEF_CLASSES
#include "iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.h.inc" // IWYU pragma: export

#define GET_OP_CLASSES
#include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h.inc" // IWYU pragma: export

// clang-format on

#endif // IREE_DIALECTS_DIALECT_VECTOREXT_IR_VECTOREXTOPS_H_
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_DIALECT_VECTOREXT_OPS
#define IREE_DIALECT_VECTOREXT_OPS

include "iree-dialects/Dialect/VectorExt/IR/VectorExtBase.td"

//===----------------------------------------------------------------------===//
// Base class.
//===----------------------------------------------------------------------===//

class IREEVectorExt_PureOp<string mnemonic, list<Trait> traits = []> :
Op<IREEVectorExt_Dialect, mnemonic, traits> {
}

//===----------------------------------------------------------------------===//
// Layout ops.
//===----------------------------------------------------------------------===//

def IREEVectorExt_LayoutConflictResolutionOp : IREEVectorExt_PureOp<"layout_conflict_resolution"> {
let summary = "Layout Conflict Resolution operator";
let description = [{
The layout conflict resolution operator takes a vector and a
desired layout and transforms the vector to one with the
desired layout.
}];
let arguments = (ins
AnyVector:$input,
LayoutAttr:$sourceLayout,
LayoutAttr:$desiredLayout
);
let results = (outs
AnyVector:$output
);
let extraClassDeclaration = [{

}];
let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)";
let hasVerifier = 1;
}

#endif // IREE_DIALECT_VECTOREXT_OPS

Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
add_subdirectory(Input)
add_subdirectory(LinalgExt)
add_subdirectory(LinalgTransform)
add_subdirectory(VectorExt)
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
add_subdirectory(IR)

Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
add_mlir_library(IREEVectorExtDialect
VectorExtDialect.cpp
VectorExtOps.cpp

ADDITIONAL_HEADER_DIRS
${IREE_DIALECTS_SOURCE_DIR}/include

DEPENDS
IREEVectorExtIncGen

LINK_LIBS PUBLIC
MLIRIR
)

iree_dialects_target_includes(IREEVectorExtDialect)
Loading

0 comments on commit ac95484

Please sign in to comment.