Skip to content

Commit

Permalink
[aievec] Add AIEml intrinsic for bf16/f32 types
Browse files Browse the repository at this point in the history
  • Loading branch information
jsetoain committed Oct 20, 2023
1 parent 4c853c4 commit c349f56
Show file tree
Hide file tree
Showing 3 changed files with 260 additions and 0 deletions.
36 changes: 36 additions & 0 deletions include/aie/Dialect/AIEVec/IR/AIEVecOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define AIEVEC_OPS

include "aie/Dialect/AIEVec/IR/AIEVecTypes.td"
include "aie/Dialect/AIEVec/IR/AIEVecTypeConstraints.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

// Base class for AIE dialect ops.
Expand Down Expand Up @@ -727,4 +728,39 @@ def AIEVec_ExtElemOp:
let assemblyFormat = "$source `,` $index attr-dict `:` type($source) `,` type($index) `,` type($result)";
let hasVerifier = 0;
}

// Matrix-multiply-and-accumulate intrinsic for AIE2 (AIE-ML).
//
// Besides the per-operand type constraints, the trait list requires that:
//   - `lhs`, `rhs` and `acc` have the same rank;
//   - `lhs` and `rhs` have the same element type;
//   - `result` has exactly the same type as `acc` (which is why the assembly
//     format only spells out the type of `acc`);
//   - the operand shapes satisfy `ShapesCompatibleWithContraction`;
//   - the element type of `acc` satisfies `IsValidAIE2AccumulatorType` with
//     respect to the element type of `lhs`.
def AIEVec_MatMulOp:
  AIEVec_Op<"matmul", [
    Pure,
    AllRanksMatch<["lhs", "rhs", "acc"]>,
    AllElementTypesMatch<["lhs", "rhs"]>,
    AllTypesMatch<["acc", "result"]>,
    ShapesCompatibleWithContraction<"lhs", "rhs", "acc">,
    IsValidAIE2AccumulatorType<"acc", "lhs">
  ]>,
  Arguments<(ins AIE2MatMulLHS:$lhs,
                 AIE2MatMulRHS:$rhs,
                 AIE2MatMulACC:$acc)>,
  Results<(outs AIE2MatMulACC:$result)> {
  let summary = "AIEML matrix-multiply and accumulate";
  let description = [{
    AMD AIEv2-specific intrinsic that performs a matrix multiplication
    between `lhs` and `rhs`, and accumulates the result in `acc`.

    Currently, this intrinsic supports the following type combinations:

         lhs             |         rhs         |    Accumulator
    :-------------------:|:-------------------:|:------------------:
     `vector<4x8xbf16>`  | `vector<8x4xbf16>`  | `vector<4x4xf32>`
     `vector<4x8xf32>`   | `vector<8x4xf32>`   | `vector<4x4xf32>`
     `vector<4x8xi8>`    | `vector<8x8xi8>`    | `vector<4x8xi32>`
     `vector<2x4xi16>`   | `vector<4x8xi16>`   | `vector<2x8xi64>`
     `vector<4x2xi16>`   | `vector<2x8xi16>`   | `vector<4x8xi32>`
     `vector<4x2xi32>`   | `vector<2x4xi32>`   | `vector<4x4xi64>`
  }];
  let assemblyFormat = [{$lhs `,` $rhs `,` $acc attr-dict `:` type($lhs) `,`
                         type($rhs) `into` type($acc)}];
  // All checks are expressed through traits and type constraints above.
  let hasVerifier = 0;
}

#endif // AIEVEC_OPS
129 changes: 129 additions & 0 deletions include/aie/Dialect/AIEVec/IR/AIEVecTypeConstraints.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
//===- AIEVecTypeConstraints.td - AIEVec type constraints--*- tablegen -*-====//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// (c) Copyright 2023 AMD Inc.
//

//===----------------------------------------------------------------------===//
// Extra type constraint definitions for AIEVec operations.
//===----------------------------------------------------------------------===//

include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/OpBase.td"

// Builds the C++ expression that yields the shape of `$<name>`, where
// `$<name>` is substituted with a type castable to `::mlir::ShapedType`.
// `cast` is fully qualified so the generated code compiles regardless of which
// `using` directives are in effect where the generated file is included.
class TypeShape<string name> :
  StrFunc<"::llvm::cast<::mlir::ShapedType>($" # name # ").getShape()">;

// Predicate that holds when the shape of `$_self` is exactly `shape`.
// `ArrayRef` is fully qualified so the generated code does not rely on a
// `using` directive at the include site.
// Notice: This predicate class assumes that the type has been verified to be a
// ShapedType
class VectorOfShape<list<int> shape> :
  CPred<TypeShape<"_self">.result # " == ::llvm::ArrayRef<int64_t>({" #
        !interleave(shape, ", ") # "})">;

// Applies the predicate of `type` to the element type of `$_self` instead of
// to `$_self` itself.
// Notice: the type is expected to have been verified as a ShapedType before
// this predicate is evaluated.
class VectorOfElementType<Type type> :
  SubstLeaves<
    "$_self",
    ElementType<"_self">.result,
    type.predicate>;

// Type constraint satisfied by a vector whose shape is exactly `shape` and
// whose element type satisfies `type`. The summary names both the shape and
// the element type so that diagnostics built from it read as a full phrase.
// Notice: This predicate class assumes that the type has been verified to be a
// ShapedType
class VectorOfShapeAndType<list<int> shape, Type type> :
  Type<And<[VectorOfShape<shape>,
            VectorOfElementType<type>]>,
       "vector of shape <" # !interleave(shape, "x") # "> of " #
       type.summary # " values",
       "::mlir::VectorType">;


// Per-configuration operand constraints for the AIE2 matmul intrinsic. Each
// group lists the LHS, RHS and accumulator vector types of one configuration.
// Notice: as with the classes they instantiate, these definitions expect the
// type to have been verified as a ShapedType beforehand.

// bf16 x bf16 -> f32
def AIE2BF16MatMulLHSOperand : VectorOfShapeAndType<[4, 8], BF16>;
def AIE2BF16MatMulRHSOperand : VectorOfShapeAndType<[8, 4], BF16>;
def AIE2BF16MatMulACCOperand : VectorOfShapeAndType<[4, 4], F32>;

// f32 x f32 -> f32
def AIE2F32MatMulLHSOperand  : VectorOfShapeAndType<[4, 8], F32>;
def AIE2F32MatMulRHSOperand  : VectorOfShapeAndType<[8, 4], F32>;
def AIE2F32MatMulACCOperand  : VectorOfShapeAndType<[4, 4], F32>;

// i8 x i8 -> i32
def AIE2I8MatMulLHSOperand   : VectorOfShapeAndType<[4, 8], I8>;
def AIE2I8MatMulRHSOperand   : VectorOfShapeAndType<[8, 8], I8>;
def AIE2I8MatMulACCOperand   : VectorOfShapeAndType<[4, 8], I32>;

// i16 x i16 -> i64 (variant "a")
def AIE2I16aMatMulLHSOperand : VectorOfShapeAndType<[2, 4], I16>;
def AIE2I16aMatMulRHSOperand : VectorOfShapeAndType<[4, 8], I16>;
def AIE2I16aMatMulACCOperand : VectorOfShapeAndType<[2, 8], I64>;

// i16 x i16 -> i32 (variant "b")
def AIE2I16bMatMulLHSOperand : VectorOfShapeAndType<[4, 2], I16>;
def AIE2I16bMatMulRHSOperand : VectorOfShapeAndType<[2, 8], I16>;
def AIE2I16bMatMulACCOperand : VectorOfShapeAndType<[4, 8], I32>;

// i32 x i32 -> i64
def AIE2I32MatMulLHSOperand  : VectorOfShapeAndType<[4, 2], I32>;
def AIE2I32MatMulRHSOperand  : VectorOfShapeAndType<[2, 4], I32>;
def AIE2I32MatMulACCOperand  : VectorOfShapeAndType<[4, 4], I64>;

// Any vector type accepted as the left-hand-side operand of the matmul
// intrinsic: a vector matching one of the per-configuration LHS constraints.
def AIE2MatMulLHS :
  Type<
    And<[
      IsVectorTypePred,
      Or<!foreach(operandTy,
                  [AIE2BF16MatMulLHSOperand,
                   AIE2F32MatMulLHSOperand,
                   AIE2I8MatMulLHSOperand,
                   AIE2I16aMatMulLHSOperand,
                   AIE2I16bMatMulLHSOperand,
                   AIE2I32MatMulLHSOperand],
                  operandTy.predicate)>
    ]>,
    "a vector compatible with a lhs operand of "
      # "matrix-multiply and accumulate",
    "::mlir::VectorType">;

// Any vector type accepted as the right-hand-side operand of the matmul
// intrinsic: a vector matching one of the per-configuration RHS constraints.
def AIE2MatMulRHS :
  Type<
    And<[
      IsVectorTypePred,
      Or<!foreach(operandTy,
                  [AIE2BF16MatMulRHSOperand,
                   AIE2F32MatMulRHSOperand,
                   AIE2I8MatMulRHSOperand,
                   AIE2I16aMatMulRHSOperand,
                   AIE2I16bMatMulRHSOperand,
                   AIE2I32MatMulRHSOperand],
                  operandTy.predicate)>
    ]>,
    "a vector compatible with a rhs operand of "
      # "matrix-multiply and accumulate",
    "::mlir::VectorType">;

// Any vector type accepted as the accumulator (and result) of the matmul
// intrinsic: a vector matching one of the per-configuration ACC constraints.
def AIE2MatMulACC :
  Type<
    And<[
      IsVectorTypePred,
      Or<!foreach(operandTy,
                  [AIE2BF16MatMulACCOperand,
                   AIE2F32MatMulACCOperand,
                   AIE2I8MatMulACCOperand,
                   AIE2I16aMatMulACCOperand,
                   AIE2I16bMatMulACCOperand,
                   AIE2I32MatMulACCOperand],
                  operandTy.predicate)>
    ]>,
    "a vector compatible with an accumulator operand of "
      # "matrix-multiply and accumulate",
    "::mlir::VectorType">;

// Predicate that holds when dimension `ld` of the shape of `$<lhs>` equals
// dimension `rd` of the shape of `$<rhs>`.
class ShapeDimsMatch<string lhs, int ld, string rhs, int rd> :
  CPred<Shape<lhs>.result # "[" # ld # "]"
        # " == "
        # Shape<rhs>.result # "[" # rd # "]">;

// Op trait checking that `lhs`, `rhs` and `acc` have shapes of the form
// MxK, KxN and MxN respectively:
//   lhs[1] == rhs[0], lhs[0] == acc[0], rhs[1] == acc[1].
class ShapesCompatibleWithContraction<string lhs, string rhs, string acc> :
  PredOpTrait<
    "[" # lhs # " x " # rhs # " = " # acc # "] is a valid contraction",
    And<[
      ShapeDimsMatch<lhs, 1, rhs, 0>,
      ShapeDimsMatch<lhs, 0, acc, 0>,
      ShapeDimsMatch<rhs, 1, acc, 1>
    ]>>;

// Predicate that holds when the element type of `$<op1>` satisfies `t1` and
// the element type of `$<op2>` satisfies `t2`.
class VectorElementTypesMatch<string op1, Type t1, string op2, Type t2> :
  And<[
    SubstLeaves<"$_self", ElementType<op1>.result, t1.predicate>,
    SubstLeaves<"$_self", ElementType<op2>.result, t2.predicate>
  ]>;

// Op trait checking that the element types of `operand` and `acc` form one of
// the listed (operand element type, accumulator element type) pairs.
class IsValidAIE2AccumulatorType<string acc, string operand> :
  PredOpTrait<
    acc # " element type is a valid accumulator type for the "
        # "element type of " # operand,
    Or<[
      VectorElementTypesMatch<operand, BF16, acc, F32>,
      VectorElementTypesMatch<operand, F32,  acc, F32>,
      VectorElementTypesMatch<operand, I8,   acc, I32>,
      VectorElementTypesMatch<operand, I16,  acc, I64>,
      VectorElementTypesMatch<operand, I16,  acc, I32>,
      VectorElementTypesMatch<operand, I32,  acc, I64>
    ]>>;
95 changes: 95 additions & 0 deletions test/dialect/AIEVec/roundtrip.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// RUN: aie-opt %s -split-input-file -verify-diagnostics | FileCheck %s

// Round-trip of the bf16 x bf16 -> f32 configuration.
// CHECK-LABEL: @matmul_bf16
// CHECK-SAME: %[[A:.*]]: vector<4x8xbf16>
// CHECK-SAME: %[[B:.*]]: vector<8x4xbf16>
// CHECK-SAME: %[[C:.*]]: vector<4x4xf32>
// CHECK: %[[RES:.*]] = aievec.matmul %[[A]], %[[B]], %[[C]] :
// CHECK-SAME: vector<4x8xbf16>, vector<8x4xbf16> into vector<4x4xf32>
// CHECK: return %[[RES]] : vector<4x4xf32>
func.func @matmul_bf16(%lhs : vector<4x8xbf16>,
                       %rhs : vector<8x4xbf16>,
                       %acc : vector<4x4xf32>) -> vector<4x4xf32> {
  %res = aievec.matmul %lhs, %rhs, %acc
           : vector<4x8xbf16>, vector<8x4xbf16> into vector<4x4xf32>
  return %res : vector<4x4xf32>
}

// -----

// Round-trip of the f32 x f32 -> f32 configuration.
// CHECK-LABEL: @matmul_f32
// CHECK-SAME: %[[A:.*]]: vector<4x8xf32>
// CHECK-SAME: %[[B:.*]]: vector<8x4xf32>
// CHECK-SAME: %[[C:.*]]: vector<4x4xf32>
// CHECK: %[[RES:.*]] = aievec.matmul %[[A]], %[[B]], %[[C]] :
// CHECK-SAME: vector<4x8xf32>, vector<8x4xf32> into vector<4x4xf32>
// CHECK: return %[[RES]] : vector<4x4xf32>
func.func @matmul_f32(%lhs : vector<4x8xf32>,
                      %rhs : vector<8x4xf32>,
                      %acc : vector<4x4xf32>) -> vector<4x4xf32> {
  %res = aievec.matmul %lhs, %rhs, %acc
           : vector<4x8xf32>, vector<8x4xf32> into vector<4x4xf32>
  return %res : vector<4x4xf32>
}

// -----

// Round-trip of the i8 x i8 -> i32 configuration.
// CHECK-LABEL: @matmul_i8
// CHECK-SAME: %[[A:.*]]: vector<4x8xi8>
// CHECK-SAME: %[[B:.*]]: vector<8x8xi8>
// CHECK-SAME: %[[C:.*]]: vector<4x8xi32>
// CHECK: %[[RES:.*]] = aievec.matmul %[[A]], %[[B]], %[[C]] :
// CHECK-SAME: vector<4x8xi8>, vector<8x8xi8> into vector<4x8xi32>
// CHECK: return %[[RES]] : vector<4x8xi32>
func.func @matmul_i8(%lhs : vector<4x8xi8>,
                     %rhs : vector<8x8xi8>,
                     %acc : vector<4x8xi32>) -> vector<4x8xi32> {
  %res = aievec.matmul %lhs, %rhs, %acc
           : vector<4x8xi8>, vector<8x8xi8> into vector<4x8xi32>
  return %res : vector<4x8xi32>
}

// -----

// Round-trip of the i16 x i16 -> i64 configuration.
// CHECK-LABEL: @matmul_i16a
// CHECK-SAME: %[[A:.*]]: vector<2x4xi16>
// CHECK-SAME: %[[B:.*]]: vector<4x8xi16>
// CHECK-SAME: %[[C:.*]]: vector<2x8xi64>
// CHECK: %[[RES:.*]] = aievec.matmul %[[A]], %[[B]], %[[C]] :
// CHECK-SAME: vector<2x4xi16>, vector<4x8xi16> into vector<2x8xi64>
// CHECK: return %[[RES]] : vector<2x8xi64>
func.func @matmul_i16a(%lhs : vector<2x4xi16>,
                       %rhs : vector<4x8xi16>,
                       %acc : vector<2x8xi64>) -> vector<2x8xi64> {
  %res = aievec.matmul %lhs, %rhs, %acc
           : vector<2x4xi16>, vector<4x8xi16> into vector<2x8xi64>
  return %res : vector<2x8xi64>
}

// -----

// Round-trip of the i16 x i16 -> i32 configuration.
// CHECK-LABEL: @matmul_i16b
// CHECK-SAME: %[[A:.*]]: vector<4x2xi16>
// CHECK-SAME: %[[B:.*]]: vector<2x8xi16>
// CHECK-SAME: %[[C:.*]]: vector<4x8xi32>
// CHECK: %[[RES:.*]] = aievec.matmul %[[A]], %[[B]], %[[C]] :
// CHECK-SAME: vector<4x2xi16>, vector<2x8xi16> into vector<4x8xi32>
// CHECK: return %[[RES]] : vector<4x8xi32>
func.func @matmul_i16b(%lhs : vector<4x2xi16>,
                       %rhs : vector<2x8xi16>,
                       %acc : vector<4x8xi32>) -> vector<4x8xi32> {
  %res = aievec.matmul %lhs, %rhs, %acc
           : vector<4x2xi16>, vector<2x8xi16> into vector<4x8xi32>
  return %res : vector<4x8xi32>
}

// -----

// Round-trip of the i32 x i32 -> i64 configuration.
// CHECK-LABEL: @matmul_i32
// CHECK-SAME: %[[A:.*]]: vector<4x2xi32>
// CHECK-SAME: %[[B:.*]]: vector<2x4xi32>
// CHECK-SAME: %[[C:.*]]: vector<4x4xi64>
// CHECK: %[[RES:.*]] = aievec.matmul %[[A]], %[[B]], %[[C]] :
// CHECK-SAME: vector<4x2xi32>, vector<2x4xi32> into vector<4x4xi64>
// CHECK: return %[[RES]] : vector<4x4xi64>
func.func @matmul_i32(%lhs : vector<4x2xi32>,
                      %rhs : vector<2x4xi32>,
                      %acc : vector<4x4xi64>) -> vector<4x4xi64> {
  %res = aievec.matmul %lhs, %rhs, %acc
           : vector<4x2xi32>, vector<2x4xi32> into vector<4x4xi64>
  return %res : vector<4x4xi64>
}

0 comments on commit c349f56

Please sign in to comment.