Skip to content

Commit

Permalink
[aievec] Add AIEml intrinsic for bf16/f32 types
Browse files Browse the repository at this point in the history
  • Loading branch information
jsetoain committed Nov 7, 2023
1 parent ae57cce commit 4dd1644
Show file tree
Hide file tree
Showing 5 changed files with 395 additions and 1 deletion.
38 changes: 38 additions & 0 deletions include/aie/Dialect/AIEVec/IR/AIEVecOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define AIEVEC_OPS

include "aie/Dialect/AIEVec/IR/AIEVecTypes.td"
include "aie/Dialect/AIEVec/IR/AIEVecTypeConstraints.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

// Base class for AIE dialect ops.
Expand Down Expand Up @@ -808,4 +809,41 @@ def AIEVec_BandOp:
let hasVerifier = 0;
}

// Matrix-multiply-and-accumulate op. Operand and result types are restricted
// by the AIE2MatMul{LHS,RHS,ACC} type constraints; the op traits additionally
// require that all operands have the same rank, that the result type equals
// the accumulator type, that the shapes form a valid contraction, and that
// the (lhs, rhs, acc) type triple is one of the combinations listed in
// IsValidAIE2MatMulShapeAndType. All checks are declarative, hence
// `hasVerifier = 0`.
def AIEVec_MatMulOp:
  AIEVec_Op<"matmul", [
    Pure,
    AllRanksMatch<["lhs", "rhs", "acc"]>,
    AllTypesMatch<["acc", "result"]>,
    ShapesCompatibleWithContraction<"lhs", "rhs", "acc">,
    IsValidAIE2MatMulShapeAndType<"lhs", "rhs", "acc">
  ]>,
  Arguments<(ins AIE2MatMulLHS:$lhs,
                 AIE2MatMulRHS:$rhs,
                 AIE2MatMulACC:$acc)>,
  Results<(outs AIE2MatMulACC:$result)> {
  let summary = "AIEML matrix-multiply and accumulate";
  let description = [{
    AMD AIEv2-specific intrinsic that performs a matrix multiplication
    between `lhs` and `rhs`, and accumulates the result in `acc`.

    Currently, this intrinsic supports the following type combinations:

         lhs            |         rhs        |    Accumulator
    :------------------:|:------------------:|:-----------------:
     `vector<4x16xi8>`  | `vector<16x8xi4>`  | `vector<4x8xi32>`
     `vector<4x8xi8>`   | `vector<8x8xi8>`   | `vector<4x8xi32>`
     `vector<4x4xi16>`  | `vector<4x8xi8>`   | `vector<4x8xi32>`
     `vector<4x2xi16>`  | `vector<2x8xi16>`  | `vector<4x8xi32>`
     `vector<2x8xi16>`  | `vector<8x8xi8>`   | `vector<2x8xi64>`
     `vector<4x8xi16>`  | `vector<8x4xi8>`   | `vector<4x4xi64>`
     `vector<2x4xi16>`  | `vector<4x8xi16>`  | `vector<2x8xi64>`
     `vector<4x4xi16>`  | `vector<4x4xi16>`  | `vector<4x4xi64>`
     `vector<4x2xi32>`  | `vector<2x4xi16>`  | `vector<4x4xi64>`
     `vector<4x8xbf16>` | `vector<8x4xbf16>` | `vector<4x4xf32>`
  }];
  let assemblyFormat = [{$lhs `,` $rhs `,` $acc attr-dict `:` type($lhs) `,`
                         type($rhs) `into` type($acc)}];
  let hasVerifier = 0;
}

#endif // AIEVEC_OPS
158 changes: 158 additions & 0 deletions include/aie/Dialect/AIEVec/IR/AIEVecTypeConstraints.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
//===- AIEVecTypeConstraints.td - AIEVec type constraints--*- tablegen -*-====//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// (c) Copyright 2023 AMD Inc.
//

//===----------------------------------------------------------------------===//
// Extra type constraint definitions for AIEVec operations.
//===----------------------------------------------------------------------===//

include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/OpBase.td"

// 4-bit integer type constraint built from the `I<width>` class of the
// included MLIR definitions; used below as the element type of
// AIE2MatMulOperand_16x8xi4.
def I4 : I<4>;

// Builds the C++ expression that yields the shape of the *type* bound to the
// placeholder `$<name>` (e.g. `$_self`). The expression casts to ShapedType
// unconditionally, so it must only be evaluated on types already known to be
// shaped. Names are fully qualified so the generated code does not depend on
// which using-directives are in effect where the predicate is expanded.
class TypeShape<string name> :
  StrFunc<"::llvm::cast<::mlir::ShapedType>($" # name # ").getShape()">;

// Predicate that holds when the shape of `$_self` is exactly `shape`
// (same rank and same extent in every dimension).
//
// Notice: This predicate class assumes that the type has been verified to be a
// ShapedType; the shape expression it embeds casts unconditionally.
// `ArrayRef` is fully qualified so the generated C++ compiles regardless of
// the using-directives in effect at the expansion site.
class VectorOfShape<list<int> shape> :
  CPred<TypeShape<"_self">.result # " == ::llvm::ArrayRef<int64_t>({" #
        !interleave(shape, ", ") # "})">;

// Predicate on the element type of `$_self`: it reuses the predicate of `type`
// with every `$_self` leaf replaced by the `ElementType<"_self">` expression,
// so the check that `type` performs is applied to the element type instead of
// the type itself. `ElementType` comes from the included MLIR definitions.
//
// Notice: This predicate class assumes that the type has been verified to be a
// ShapedType
class VectorOfElementType<Type type> :
SubstLeaves<"$_self", ElementType<"_self">.result, type.predicate>;

// Type constraint for an `::mlir::VectorType` whose shape is exactly `shape`
// and whose element type satisfies `type`.
//
// The predicate checks that the type is a vector *before* the shape and
// element-type predicates run. Those predicates cast to ShapedType
// unconditionally, and `And<>` is emitted as a short-circuiting `&&` chain, so
// the leading check keeps the cast from being evaluated on a non-vector type
// (e.g. when this constraint is one alternative of an `AnyTypeOf` applied to
// an arbitrary operand type).
//
// The summary names both the shape and the element type, e.g.
// "vector of shape <4x8> and element type <summary of `type`>".
class VectorOfShapeAndType<list<int> shape, Type type> :
  Type<And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
            VectorOfShape<shape>,
            VectorOfElementType<type>]>,
       "vector of shape <" # !interleave(shape, "x") # "> and element type " #
         type.summary,
       "::mlir::VectorType">;


// Concrete shape/element-type constraints used as building blocks for the
// AIE2MatMulLHS / AIE2MatMulRHS / AIE2MatMulACC operand constraints and for
// the combinations listed in IsValidAIE2MatMulShapeAndType. Each def is named
// AIE2MatMulOperand_<shape>x<element type>.
//
// Notice: These predicate definitions assume that the type has been verified to
// be a ShapedType

// Types listed in AIE2MatMulLHS (some of them are also listed in
// AIE2MatMulRHS).
def AIE2MatMulOperand_4x16xi8 : VectorOfShapeAndType<[4, 16], I8>;
def AIE2MatMulOperand_4x8xi8 : VectorOfShapeAndType<[4, 8], I8>;
def AIE2MatMulOperand_4x4xi16 : VectorOfShapeAndType<[4, 4], I16>;
def AIE2MatMulOperand_4x2xi16 : VectorOfShapeAndType<[4, 2], I16>;
def AIE2MatMulOperand_2x8xi16 : VectorOfShapeAndType<[2, 8], I16>;
def AIE2MatMulOperand_4x8xi16 : VectorOfShapeAndType<[4, 8], I16>;
def AIE2MatMulOperand_2x4xi16 : VectorOfShapeAndType<[2, 4], I16>;
def AIE2MatMulOperand_4x2xi32 : VectorOfShapeAndType<[4, 2], I32>;
def AIE2MatMulOperand_4x8xbf16 : VectorOfShapeAndType<[4, 8], BF16>;

// Types listed only in AIE2MatMulRHS.
def AIE2MatMulOperand_16x8xi4 : VectorOfShapeAndType<[16, 8], I4>;
def AIE2MatMulOperand_8x8xi8 : VectorOfShapeAndType<[8, 8], I8>;
def AIE2MatMulOperand_8x4xi8 : VectorOfShapeAndType<[8, 4], I8>;
def AIE2MatMulOperand_8x4xbf16 : VectorOfShapeAndType<[8, 4], BF16>;

// Types listed in AIE2MatMulACC.
// NOTE(review): AIE2MatMulOperand_4x4xi32 is accepted by AIE2MatMulACC but does
// not appear in any combination of IsValidAIE2MatMulShapeAndType - confirm
// whether it is intentionally reserved or can be dropped.
def AIE2MatMulOperand_4x8xi32 : VectorOfShapeAndType<[4, 8], I32>;
def AIE2MatMulOperand_4x4xi32 : VectorOfShapeAndType<[4, 4], I32>;
def AIE2MatMulOperand_2x8xi64 : VectorOfShapeAndType<[2, 8], I64>;
def AIE2MatMulOperand_4x4xi64 : VectorOfShapeAndType<[4, 4], I64>;
def AIE2MatMulOperand_4x4xf32 : VectorOfShapeAndType<[4, 4], F32>;

// Constraint on the `lhs` operand of the matmul op: the operand must match one
// of the vector shape/element-type pairs listed here. Whether a given lhs type
// is valid together with the rhs and accumulator types is checked separately
// by IsValidAIE2MatMulShapeAndType.
def AIE2MatMulLHS : AnyTypeOf<
    [AIE2MatMulOperand_4x16xi8, AIE2MatMulOperand_4x8xi8,
     AIE2MatMulOperand_4x4xi16, AIE2MatMulOperand_4x2xi16,
     AIE2MatMulOperand_2x8xi16, AIE2MatMulOperand_4x8xi16,
     AIE2MatMulOperand_2x4xi16, AIE2MatMulOperand_4x2xi32,
     AIE2MatMulOperand_4x8xbf16],
    "a vector compatible with a lhs operand of " #
    "matrix-multiply and accumulate",
    "::mlir::VectorType">;

// Constraint on the `rhs` operand of the matmul op: the operand must match one
// of the vector shape/element-type pairs listed here. Compatibility with the
// lhs and accumulator types is checked separately by
// IsValidAIE2MatMulShapeAndType.
def AIE2MatMulRHS : AnyTypeOf<
    [AIE2MatMulOperand_16x8xi4, AIE2MatMulOperand_8x8xi8,
     AIE2MatMulOperand_4x8xi8, AIE2MatMulOperand_2x8xi16,
     AIE2MatMulOperand_8x4xi8, AIE2MatMulOperand_4x8xi16,
     AIE2MatMulOperand_4x4xi16, AIE2MatMulOperand_2x4xi16,
     AIE2MatMulOperand_8x4xbf16],
    "a vector compatible with a rhs operand of " #
    "matrix-multiply and accumulate",
    "::mlir::VectorType">;

// Constraint on the `acc` operand (and the result) of the matmul op: the type
// must match one of the accumulator shape/element-type pairs listed here.
def AIE2MatMulACC : AnyTypeOf<
    [AIE2MatMulOperand_4x8xi32, AIE2MatMulOperand_4x4xi32,
     AIE2MatMulOperand_2x8xi64, AIE2MatMulOperand_4x4xi64,
     AIE2MatMulOperand_4x4xf32],
    "a vector compatible with an accumulator of " #
    "matrix-multiply and accumulate",
    "::mlir::VectorType">;

// Predicate that holds when dimension `ld` of the shape of operand `lhs`
// equals dimension `rd` of the shape of operand `rhs`. `lhs` and `rhs` are
// operand names; the shape expressions come from the `Shape<...>` helper of
// the included MLIR definitions.
// NOTE(review): the generated code indexes the shapes directly, so both
// operands presumably need a rank greater than `ld` / `rd` - confirm that the
// operand type constraints guarantee this wherever the predicate is used.
class ShapeDimsMatch<string lhs, int ld, string rhs, int rd> :
CPred<Shape<lhs>.result # "[" # ld # "] == " #
Shape<rhs>.result # "[" # rd # "]">;

// Op trait verifying that the shapes of operands `lhs`, `rhs` and `acc` are
// consistent with a 2-D contraction (MxK) x (KxN) = (MxN):
//   * lhs dim 1 == rhs dim 0   (shared K dimension)
//   * lhs dim 0 == acc dim 0   (M dimension)
//   * rhs dim 1 == acc dim 1   (N dimension)
// On failure the diagnostic reads
// "failed to verify that [<lhs> x <rhs> = <acc>] is a valid contraction".
class ShapesCompatibleWithContraction<string lhs, string rhs, string acc> :
PredOpTrait<"[" # lhs # " x " # rhs # " = " # acc #
"] is a valid contraction",
And<[ShapeDimsMatch<lhs, 1, rhs, 0>,
ShapeDimsMatch<lhs, 0, acc, 0>,
ShapeDimsMatch<rhs, 1, acc, 1>]>>;

// Builds the C++ expression that yields the type of the operand bound to
// `$<name>`, cast to `::mlir::VectorType`. The cast is unconditional, so the
// expression must only be evaluated for operands already verified to be
// vectors. Names are fully qualified so the generated code does not depend on
// the using-directives in effect where the predicate is expanded.
class VectorType<string name> :
  StrFunc<"::llvm::cast<::mlir::VectorType>($" # name # ".getType())">;

// Predicate that holds when the types of the three named operands satisfy the
// given type constraints: `op1` against `t1`, `op2` against `t2` and `op3`
// against `t3`. Each type constraint's predicate is reused with its `$_self`
// leaves replaced by the `VectorType<opN>` expression for the corresponding
// operand, and the three results are combined with a logical AND.
class VectorTypesMatch<string op1, Type t1,
string op2, Type t2,
string op3, Type t3> :
And<[SubstLeaves<"$_self", VectorType<op1>.result, t1.predicate>,
SubstLeaves<"$_self", VectorType<op2>.result, t2.predicate>,
SubstLeaves<"$_self", VectorType<op3>.result, t3.predicate>]>;

// Op trait verifying that the types of operands `lhs`, `rhs` and `acc` form
// one of the (lhs, rhs, accumulator) combinations listed below. The individual
// operand type constraints only check each operand in isolation; this trait
// rejects triples whose members are individually valid but not supported
// together. On failure the diagnostic reads
// "failed to verify that <lhs> x <rhs> = <acc> is a valid AIE2
// matrix-multiply and accumulate op".
class IsValidAIE2MatMulShapeAndType<string lhs, string rhs, string acc> :
PredOpTrait<lhs # " x " # rhs # " = " # acc # " is a valid AIE2 " #
"matrix-multiply and accumulate op",
// Combinations with an i32 accumulator.
Or<[VectorTypesMatch<lhs, AIE2MatMulOperand_4x16xi8,
rhs, AIE2MatMulOperand_16x8xi4,
acc, AIE2MatMulOperand_4x8xi32>,
VectorTypesMatch<lhs, AIE2MatMulOperand_4x8xi8,
rhs, AIE2MatMulOperand_8x8xi8,
acc, AIE2MatMulOperand_4x8xi32>,
VectorTypesMatch<lhs, AIE2MatMulOperand_4x4xi16,
rhs, AIE2MatMulOperand_4x8xi8,
acc, AIE2MatMulOperand_4x8xi32>,
VectorTypesMatch<lhs, AIE2MatMulOperand_4x2xi16,
rhs, AIE2MatMulOperand_2x8xi16,
acc, AIE2MatMulOperand_4x8xi32>,

// Combinations with an i64 accumulator.
VectorTypesMatch<lhs, AIE2MatMulOperand_2x8xi16,
rhs, AIE2MatMulOperand_8x8xi8,
acc, AIE2MatMulOperand_2x8xi64>,
VectorTypesMatch<lhs, AIE2MatMulOperand_4x8xi16,
rhs, AIE2MatMulOperand_8x4xi8,
acc, AIE2MatMulOperand_4x4xi64>,
VectorTypesMatch<lhs, AIE2MatMulOperand_2x4xi16,
rhs, AIE2MatMulOperand_4x8xi16,
acc, AIE2MatMulOperand_2x8xi64>,
VectorTypesMatch<lhs, AIE2MatMulOperand_4x4xi16,
rhs, AIE2MatMulOperand_4x4xi16,
acc, AIE2MatMulOperand_4x4xi64>,
VectorTypesMatch<lhs, AIE2MatMulOperand_4x2xi32,
rhs, AIE2MatMulOperand_2x4xi16,
acc, AIE2MatMulOperand_4x4xi64>,

// Combination with an f32 accumulator (bf16 inputs).
VectorTypesMatch<lhs, AIE2MatMulOperand_4x8xbf16,
rhs, AIE2MatMulOperand_8x4xbf16,
acc, AIE2MatMulOperand_4x4xf32>]>>;
2 changes: 1 addition & 1 deletion include/aie/Dialect/AIEVec/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
# (c) Copyright 2022 Xilinx Inc.

add_mlir_dialect(AIEVecOps aievec)
# NOTE(review): the two add_mlir_doc lines below are the removed and the added
# side of a one-line change (1 addition & 1 deletion); only the variant that
# passes -dialect=aievec is expected to remain in the file.
add_mlir_doc(AIEVecOps AIEVecDialect ./ -gen-dialect-doc)
add_mlir_doc(AIEVecOps AIEVecDialect ./ -gen-dialect-doc -dialect=aievec)
39 changes: 39 additions & 0 deletions test/dialect/AIEVec/invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// RUN: aie-opt %s -split-input-file -verify-diagnostics

// f16 is not an element type accepted for the lhs operand of aievec.matmul,
// so the operand type constraint on operand #0 must reject this op.
func.func @invalidElementType(%lhs : vector<4x8xf16>,
                              %rhs : vector<8x4xf16>,
                              %acc : vector<4x4xf32>) -> vector<4x4xf32> {
  // expected-error @+1 {{op operand #0 must be a vector compatible with a lhs operand of matrix-multiply and accumulate, but got 'vector<4x8xf16>'}}
  %res = aievec.matmul %lhs, %rhs, %acc : vector<4x8xf16>, vector<8x4xf16>
                                          into vector<4x4xf32>
  return %res : vector<4x4xf32>
}

// -----

// bf16 is an accepted element type, but 4x4 is not an accepted lhs shape for
// it (only 4x8xbf16 is listed), so operand #0 must be rejected.
func.func @invalidShape(%lhs : vector<4x4xbf16>,
                        %rhs : vector<4x4xbf16>,
                        %acc : vector<4x4xf32>) -> vector<4x4xf32> {
  // expected-error @+1 {{op operand #0 must be a vector compatible with a lhs operand of matrix-multiply and accumulate, but got 'vector<4x4xbf16>'}}
  %res = aievec.matmul %lhs, %rhs, %acc : vector<4x4xbf16>, vector<4x4xbf16>
                                          into vector<4x4xf32>
  return %res : vector<4x4xf32>
}

// -----

// Every operand type is individually accepted, but the shapes do not form a
// contraction: lhs is 2x4 while rhs is 2x8, so the inner dimensions (4 vs 2)
// differ and the contraction trait must fail.
func.func @invalidContraction(%lhs : vector<2x4xi16>,
                              %rhs : vector<2x8xi16>,
                              %acc : vector<4x8xi32>) -> vector<4x8xi32> {
  // expected-error @+1 {{op failed to verify that [lhs x rhs = acc] is a valid contraction}}
  %res = aievec.matmul %lhs, %rhs, %acc : vector<2x4xi16>, vector<2x8xi16>
                                          into vector<4x8xi32>
  return %res : vector<4x8xi32>
}

// -----

// lhs and rhs are accepted types, but 2x8xi32 is not among the accepted
// accumulator types, so operand #2 must be rejected.
func.func @invalidAccumulatorType(%lhs : vector<2x4xi16>,
                                  %rhs : vector<4x8xi16>,
                                  %acc : vector<2x8xi32>) -> vector<2x8xi32> {
  // expected-error @+1 {{op operand #2 must be a vector compatible with an accumulator of matrix-multiply and accumulate, but got 'vector<2x8xi32>'}}
  %res = aievec.matmul %lhs, %rhs, %acc : vector<2x4xi16>, vector<4x8xi16>
                                          into vector<2x8xi32>
  return %res : vector<2x8xi32>
}
Loading

0 comments on commit 4dd1644

Please sign in to comment.