-
Notifications
You must be signed in to change notification settings - Fork 86
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[aievec] Add AIEml intrinsic for bf16/f32 types
- Loading branch information
Showing
3 changed files
with
260 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
//===- AIEVecTypeConstraints.td - AIEVec type constraints--*- tablegen -*-====// | ||
// | ||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
// (c) Copyright 2023 AMD Inc. | ||
// | ||
|
||
//===----------------------------------------------------------------------===// | ||
// Extra type constraint definitions for AIEVec operations. | ||
//===----------------------------------------------------------------------===// | ||
|
||
include "mlir/IR/BuiltinTypes.td" | ||
include "mlir/IR/OpBase.td" | ||
|
||
class TypeShape<string name> : | ||
StrFunc<"cast<::mlir::ShapedType>($" # name # ").getShape()">; | ||
|
||
// Notice: This predicate class assumes that the type has been verified to be a | ||
// ShapedType | ||
class VectorOfShape<list<int> shape> : | ||
CPred<TypeShape<"_self">.result # " == ArrayRef<int64_t>({" # | ||
!interleave(shape, ", ") # "})">; | ||
|
||
// Notice: This predicate class assumes that the type has been verified to be a | ||
// ShapedType | ||
class VectorOfElementType<Type type> : | ||
SubstLeaves<"$_self", ElementType<"_self">.result, type.predicate>; | ||
|
||
// Notice: This predicate class assumes that the type has been verified to be a | ||
// ShapedType | ||
class VectorOfShapeAndType<list<int> shape, Type type> : | ||
Type<And<[VectorOfShape<shape>, | ||
VectorOfElementType<type>]>, | ||
"vector of shape <" # !interleave(shape, "x") # "> and", | ||
"::mlir::VectorType">; | ||
|
||
|
||
// Notice: These predicate definitions assume that the type has been verified to | ||
// be a ShapedType | ||
def AIE2BF16MatMulLHSOperand : VectorOfShapeAndType<[4, 8], BF16>; | ||
def AIE2BF16MatMulRHSOperand : VectorOfShapeAndType<[8, 4], BF16>; | ||
def AIE2BF16MatMulACCOperand : VectorOfShapeAndType<[4, 4], F32>; | ||
|
||
def AIE2F32MatMulLHSOperand : VectorOfShapeAndType<[4, 8], F32>; | ||
def AIE2F32MatMulRHSOperand : VectorOfShapeAndType<[8, 4], F32>; | ||
def AIE2F32MatMulACCOperand : VectorOfShapeAndType<[4, 4], F32>; | ||
|
||
def AIE2I8MatMulLHSOperand : VectorOfShapeAndType<[4, 8], I8>; | ||
def AIE2I8MatMulRHSOperand : VectorOfShapeAndType<[8, 8], I8>; | ||
def AIE2I8MatMulACCOperand : VectorOfShapeAndType<[4, 8], I32>; | ||
|
||
def AIE2I16aMatMulLHSOperand : VectorOfShapeAndType<[2, 4], I16>; | ||
def AIE2I16aMatMulRHSOperand : VectorOfShapeAndType<[4, 8], I16>; | ||
def AIE2I16aMatMulACCOperand : VectorOfShapeAndType<[2, 8], I64>; | ||
|
||
def AIE2I16bMatMulLHSOperand : VectorOfShapeAndType<[4, 2], I16>; | ||
def AIE2I16bMatMulRHSOperand : VectorOfShapeAndType<[2, 8], I16>; | ||
def AIE2I16bMatMulACCOperand : VectorOfShapeAndType<[4, 8], I32>; | ||
|
||
def AIE2I32MatMulLHSOperand : VectorOfShapeAndType<[4, 2], I32>; | ||
def AIE2I32MatMulRHSOperand : VectorOfShapeAndType<[2, 4], I32>; | ||
def AIE2I32MatMulACCOperand : VectorOfShapeAndType<[4, 4], I64>; | ||
|
||
def AIE2MatMulLHS : | ||
Type<And<[IsVectorTypePred, | ||
Or<!foreach(pred, [AIE2BF16MatMulLHSOperand, | ||
AIE2F32MatMulLHSOperand, | ||
AIE2I8MatMulLHSOperand, | ||
AIE2I16aMatMulLHSOperand, | ||
AIE2I16bMatMulLHSOperand, | ||
AIE2I32MatMulLHSOperand], | ||
pred.predicate)>]>, | ||
"a vector compatible with a lhs operand of matrix-multiply and " | ||
# "accumulate", | ||
"::mlir::VectorType">; | ||
|
||
def AIE2MatMulRHS : | ||
Type<And<[IsVectorTypePred, | ||
Or<!foreach(pred, [AIE2BF16MatMulRHSOperand, | ||
AIE2F32MatMulRHSOperand, | ||
AIE2I8MatMulRHSOperand, | ||
AIE2I16aMatMulRHSOperand, | ||
AIE2I16bMatMulRHSOperand, | ||
AIE2I32MatMulRHSOperand], | ||
pred.predicate)>]>, | ||
"a vector compatible with a rhs operand of matrix-multiply and " | ||
# "accumulate", | ||
"::mlir::VectorType">; | ||
|
||
def AIE2MatMulACC : | ||
Type<And<[IsVectorTypePred, | ||
Or<!foreach(pred, [AIE2BF16MatMulACCOperand, | ||
AIE2F32MatMulACCOperand, | ||
AIE2I8MatMulACCOperand, | ||
AIE2I16aMatMulACCOperand, | ||
AIE2I16bMatMulACCOperand, | ||
AIE2I32MatMulACCOperand], | ||
pred.predicate)>]>, | ||
"a vector compatible with an accumulator operand of matrix-multiply and" | ||
# " accumulate", | ||
"::mlir::VectorType">; | ||
|
||
class ShapeDimsMatch<string lhs, int ld, string rhs, int rd> : | ||
CPred<Shape<lhs>.result # "[" # ld # "] == " # | ||
Shape<rhs>.result # "[" # rd # "]">; | ||
|
||
class ShapesCompatibleWithContraction<string lhs, string rhs, string acc> : | ||
PredOpTrait<"[" # lhs # " x " # rhs # " = " # acc # | ||
"] is a valid contraction", | ||
And<[ShapeDimsMatch<lhs, 1, rhs, 0>, | ||
ShapeDimsMatch<lhs, 0, acc, 0>, | ||
ShapeDimsMatch<rhs, 1, acc, 1>]>>; | ||
|
||
class VectorElementTypesMatch<string op1, Type t1, string op2, Type t2> : | ||
And<[SubstLeaves<"$_self", ElementType<op1>.result, t1.predicate>, | ||
SubstLeaves<"$_self", ElementType<op2>.result, t2.predicate>]>; | ||
|
||
class IsValidAIE2AccumulatorType<string acc, string operand> : | ||
PredOpTrait<acc # " element type is a valid accumulator type for the element" | ||
# " type of " # operand, | ||
Or<[VectorElementTypesMatch<operand, BF16, acc, F32>, | ||
VectorElementTypesMatch<operand, F32, acc, F32>, | ||
VectorElementTypesMatch<operand, I8, acc, I32>, | ||
VectorElementTypesMatch<operand, I16, acc, I64>, | ||
VectorElementTypesMatch<operand, I16, acc, I32>, | ||
VectorElementTypesMatch<operand, I32, acc, I64>]> | ||
>; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
// RUN: aie-opt %s -split-input-file -verify-diagnostics | FileCheck %s | ||
|
||
// CHECK-LABEL: @matmul_bf16 | ||
// CHECK-SAME: %[[A:.*]]: vector<4x8xbf16> | ||
// CHECK-SAME: %[[B:.*]]: vector<8x4xbf16> | ||
// CHECK-SAME: %[[C:.*]]: vector<4x4xf32> | ||
// CHECK: %[[RES:.*]] = aievec.matmul %[[A]], %[[B]], %[[C]] : | ||
// CHECK-SAME: vector<4x8xbf16>, vector<8x4xbf16> into vector<4x4xf32> | ||
// CHECK: return %[[RES]] : vector<4x4xf32> | ||
func.func @matmul_bf16(%A : vector<4x8xbf16>, %B : vector<8x4xbf16>, | ||
%C : vector<4x4xf32>) -> vector<4x4xf32> { | ||
%0 = aievec.matmul %A, %B, %C : vector<4x8xbf16>, vector<8x4xbf16> | ||
into vector<4x4xf32> | ||
return %0 : vector<4x4xf32> | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: @matmul_f32 | ||
// CHECK-SAME: %[[A:.*]]: vector<4x8xf32> | ||
// CHECK-SAME: %[[B:.*]]: vector<8x4xf32> | ||
// CHECK-SAME: %[[C:.*]]: vector<4x4xf32> | ||
// CHECK: %[[RES:.*]] = aievec.matmul %[[A]], %[[B]], %[[C]] : | ||
// CHECK-SAME: vector<4x8xf32>, vector<8x4xf32> into vector<4x4xf32> | ||
// CHECK: return %[[RES]] : vector<4x4xf32> | ||
func.func @matmul_f32(%A : vector<4x8xf32>, %B : vector<8x4xf32>, | ||
%C : vector<4x4xf32>) -> vector<4x4xf32> { | ||
%0 = aievec.matmul %A, %B, %C : vector<4x8xf32>, vector<8x4xf32> | ||
into vector<4x4xf32> | ||
return %0 : vector<4x4xf32> | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: @matmul_i8 | ||
// CHECK-SAME: %[[A:.*]]: vector<4x8xi8> | ||
// CHECK-SAME: %[[B:.*]]: vector<8x8xi8> | ||
// CHECK-SAME: %[[C:.*]]: vector<4x8xi32> | ||
// CHECK: %[[RES:.*]] = aievec.matmul %[[A]], %[[B]], %[[C]] : | ||
// CHECK-SAME: vector<4x8xi8>, vector<8x8xi8> into vector<4x8xi32> | ||
// CHECK: return %[[RES]] : vector<4x8xi32> | ||
func.func @matmul_i8(%A : vector<4x8xi8>, %B : vector<8x8xi8>, | ||
%C : vector<4x8xi32>) -> vector<4x8xi32> { | ||
%0 = aievec.matmul %A, %B, %C : vector<4x8xi8>, vector<8x8xi8> | ||
into vector<4x8xi32> | ||
return %0 : vector<4x8xi32> | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: @matmul_i16a | ||
// CHECK-SAME: %[[A:.*]]: vector<2x4xi16> | ||
// CHECK-SAME: %[[B:.*]]: vector<4x8xi16> | ||
// CHECK-SAME: %[[C:.*]]: vector<2x8xi64> | ||
// CHECK: %[[RES:.*]] = aievec.matmul %[[A]], %[[B]], %[[C]] : | ||
// CHECK-SAME: vector<2x4xi16>, vector<4x8xi16> into vector<2x8xi64> | ||
// CHECK: return %[[RES]] : vector<2x8xi64> | ||
func.func @matmul_i16a(%A : vector<2x4xi16>, %B : vector<4x8xi16>, | ||
%C : vector<2x8xi64>) -> vector<2x8xi64> { | ||
%0 = aievec.matmul %A, %B, %C : vector<2x4xi16>, vector<4x8xi16> | ||
into vector<2x8xi64> | ||
return %0 : vector<2x8xi64> | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: @matmul_i16b | ||
// CHECK-SAME: %[[A:.*]]: vector<4x2xi16> | ||
// CHECK-SAME: %[[B:.*]]: vector<2x8xi16> | ||
// CHECK-SAME: %[[C:.*]]: vector<4x8xi32> | ||
// CHECK: %[[RES:.*]] = aievec.matmul %[[A]], %[[B]], %[[C]] : | ||
// CHECK-SAME: vector<4x2xi16>, vector<2x8xi16> into vector<4x8xi32> | ||
// CHECK: return %[[RES]] : vector<4x8xi32> | ||
func.func @matmul_i16b(%A : vector<4x2xi16>, %B : vector<2x8xi16>, | ||
%C : vector<4x8xi32>) -> vector<4x8xi32> { | ||
%0 = aievec.matmul %A, %B, %C : vector<4x2xi16>, vector<2x8xi16> | ||
into vector<4x8xi32> | ||
return %0 : vector<4x8xi32> | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: @matmul_i32 | ||
// CHECK-SAME: %[[A:.*]]: vector<4x2xi32> | ||
// CHECK-SAME: %[[B:.*]]: vector<2x4xi32> | ||
// CHECK-SAME: %[[C:.*]]: vector<4x4xi64> | ||
// CHECK: %[[RES:.*]] = aievec.matmul %[[A]], %[[B]], %[[C]] : | ||
// CHECK-SAME: vector<4x2xi32>, vector<2x4xi32> into vector<4x4xi64> | ||
// CHECK: return %[[RES]] : vector<4x4xi64> | ||
func.func @matmul_i32(%A : vector<4x2xi32>, %B : vector<2x4xi32>, | ||
%C : vector<4x4xi64>) -> vector<4x4xi64> { | ||
%0 = aievec.matmul %A, %B, %C : vector<4x2xi32>, vector<2x4xi32> | ||
into vector<4x4xi64> | ||
return %0 : vector<4x4xi64> | ||
} |