-
Notifications
You must be signed in to change notification settings - Fork 86
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[aievec] Add AIEml intrinsic for bf16/f32 types
- Loading branch information
Showing
4 changed files
with
417 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
//===- AIEVecTypeConstraints.td - AIEVec type constraints--*- tablegen -*-====// | ||
// | ||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
// (c) Copyright 2023 AMD Inc. | ||
// | ||
|
||
//===----------------------------------------------------------------------===// | ||
// Extra type constraint definitions for AIEVec operations. | ||
//===----------------------------------------------------------------------===// | ||
|
||
include "mlir/IR/BuiltinTypes.td" | ||
include "mlir/IR/OpBase.td" | ||
|
||
def I4 : I<4>; | ||
|
||
class TypeShape<string name> : | ||
StrFunc<"cast<::mlir::ShapedType>($" # name # ").getShape()">; | ||
|
||
// Notice: This predicate class assumes that the type has been verified to be a | ||
// ShapedType | ||
class VectorOfShape<list<int> shape> : | ||
CPred<TypeShape<"_self">.result # " == ArrayRef<int64_t>({" # | ||
!interleave(shape, ", ") # "})">; | ||
|
||
// Notice: This predicate class assumes that the type has been verified to be a | ||
// ShapedType | ||
class VectorOfElementType<Type type> : | ||
SubstLeaves<"$_self", ElementType<"_self">.result, type.predicate>; | ||
|
||
// Notice: This predicate class assumes that the type has been verified to be a | ||
// ShapedType | ||
class VectorOfShapeAndType<list<int> shape, Type type> : | ||
Type<And<[VectorOfShape<shape>, | ||
VectorOfElementType<type>]>, | ||
"vector of shape <" # !interleave(shape, "x") # "> and", | ||
"::mlir::VectorType">; | ||
|
||
|
||
// Notice: These predicate definitions assume that the type has been verified to | ||
// be a ShapedType | ||
def AIE2MatMulOperand_4x16xi8 : VectorOfShapeAndType<[4, 16], I8>; | ||
def AIE2MatMulOperand_4x8xi8 : VectorOfShapeAndType<[4, 8], I8>; | ||
def AIE2MatMulOperand_4x4xi16 : VectorOfShapeAndType<[4, 4], I16>; | ||
def AIE2MatMulOperand_4x2xi16 : VectorOfShapeAndType<[4, 2], I16>; | ||
def AIE2MatMulOperand_2x8xi16 : VectorOfShapeAndType<[2, 8], I16>; | ||
def AIE2MatMulOperand_4x8xi16 : VectorOfShapeAndType<[4, 8], I16>; | ||
def AIE2MatMulOperand_2x4xi16 : VectorOfShapeAndType<[2, 4], I16>; | ||
def AIE2MatMulOperand_4x2xi32 : VectorOfShapeAndType<[4, 2], I32>; | ||
def AIE2MatMulOperand_4x8xbf16 : VectorOfShapeAndType<[4, 8], BF16>; | ||
|
||
def AIE2MatMulOperand_16x8xi4 : VectorOfShapeAndType<[16, 8], I4>; | ||
def AIE2MatMulOperand_8x8xi8 : VectorOfShapeAndType<[8, 8], I8>; | ||
def AIE2MatMulOperand_4x4xi8 : VectorOfShapeAndType<[4, 4], I8>; | ||
def AIE2MatMulOperand_8x4xi8 : VectorOfShapeAndType<[8, 4], I8>; | ||
def AIE2MatMulOperand_8x4xbf16 : VectorOfShapeAndType<[8, 4], BF16>; | ||
|
||
def AIE2MatMulOperand_4x8xi32 : VectorOfShapeAndType<[4, 8], I32>; | ||
def AIE2MatMulOperand_4x4xi32 : VectorOfShapeAndType<[4, 4], I32>; | ||
def AIE2MatMulOperand_2x8xi64 : VectorOfShapeAndType<[2, 8], I64>; | ||
def AIE2MatMulOperand_4x4xi64 : VectorOfShapeAndType<[4, 4], I64>; | ||
def AIE2MatMulOperand_4x4xf32 : VectorOfShapeAndType<[4, 4], F32>; | ||
|
||
def AIE2MatMulLHS : | ||
AnyTypeOf<[AIE2MatMulOperand_4x16xi8, | ||
AIE2MatMulOperand_4x8xi8, | ||
AIE2MatMulOperand_4x4xi16, | ||
AIE2MatMulOperand_4x2xi16, | ||
AIE2MatMulOperand_2x8xi16, | ||
AIE2MatMulOperand_4x8xi16, | ||
AIE2MatMulOperand_2x4xi16, | ||
AIE2MatMulOperand_4x2xi32, | ||
AIE2MatMulOperand_4x8xbf16], | ||
"a vector compatible with a lhs operand of matrix-multiply and " | ||
# "accumulate", | ||
"::mlir::VectorType">; | ||
|
||
def AIE2MatMulRHS : | ||
AnyTypeOf<[AIE2MatMulOperand_16x8xi4, | ||
AIE2MatMulOperand_8x8xi8, | ||
AIE2MatMulOperand_4x8xi8, | ||
AIE2MatMulOperand_2x8xi16, | ||
AIE2MatMulOperand_4x4xi8, | ||
AIE2MatMulOperand_8x4xi8, | ||
AIE2MatMulOperand_4x8xi16, | ||
AIE2MatMulOperand_4x4xi16, | ||
AIE2MatMulOperand_2x4xi16, | ||
AIE2MatMulOperand_8x4xbf16], | ||
"a vector compatible with a rhs operand of matrix-multiply and " | ||
# "accumulate", | ||
"::mlir::VectorType">; | ||
|
||
def AIE2MatMulACC : | ||
AnyTypeOf<[AIE2MatMulOperand_4x8xi32, | ||
AIE2MatMulOperand_4x4xi32, | ||
AIE2MatMulOperand_2x8xi64, | ||
AIE2MatMulOperand_4x4xi64, | ||
AIE2MatMulOperand_4x4xf32], | ||
"a vector compatible with an accumulator of matrix-multiply and " | ||
# "accumulate", | ||
"::mlir::VectorType">; | ||
|
||
class ShapeDimsMatch<string lhs, int ld, string rhs, int rd> : | ||
CPred<Shape<lhs>.result # "[" # ld # "] == " # | ||
Shape<rhs>.result # "[" # rd # "]">; | ||
|
||
class ShapesCompatibleWithContraction<string lhs, string rhs, string acc> : | ||
PredOpTrait<"[" # lhs # " x " # rhs # " = " # acc # | ||
"] is a valid contraction", | ||
And<[ShapeDimsMatch<lhs, 1, rhs, 0>, | ||
ShapeDimsMatch<lhs, 0, acc, 0>, | ||
ShapeDimsMatch<rhs, 1, acc, 1>]>>; | ||
|
||
class VectorType<string name> : StrFunc<"cast<VectorType>($" # name # | ||
".getType())">; | ||
|
||
class VectorTypesMatch<string op1, Type t1, | ||
string op2, Type t2, | ||
string op3, Type t3> : | ||
And<[SubstLeaves<"$_self", VectorType<op1>.result, t1.predicate>, | ||
SubstLeaves<"$_self", VectorType<op2>.result, t2.predicate>, | ||
SubstLeaves<"$_self", VectorType<op3>.result, t3.predicate>]>; | ||
|
||
class IsValidAIE2MatMulShapeAndType<string lhs, string rhs, string acc> : | ||
PredOpTrait<lhs # " x " # rhs # " = " # acc # " is a valid AIE2 " # | ||
"matrix-multiply and accumulate op", | ||
Or<[VectorTypesMatch<lhs, AIE2MatMulOperand_4x16xi8, | ||
rhs, AIE2MatMulOperand_16x8xi4, | ||
acc, AIE2MatMulOperand_4x8xi32>, | ||
VectorTypesMatch<lhs, AIE2MatMulOperand_4x8xi8, | ||
rhs, AIE2MatMulOperand_8x8xi8, | ||
acc, AIE2MatMulOperand_4x8xi32>, | ||
VectorTypesMatch<lhs, AIE2MatMulOperand_4x4xi16, | ||
rhs, AIE2MatMulOperand_4x8xi8, | ||
acc, AIE2MatMulOperand_4x8xi32>, | ||
VectorTypesMatch<lhs, AIE2MatMulOperand_4x2xi16, | ||
rhs, AIE2MatMulOperand_2x8xi16, | ||
acc, AIE2MatMulOperand_4x8xi32>, | ||
VectorTypesMatch<lhs, AIE2MatMulOperand_4x4xi16, | ||
rhs, AIE2MatMulOperand_4x4xi8, | ||
acc, AIE2MatMulOperand_4x4xi32>, | ||
|
||
VectorTypesMatch<lhs, AIE2MatMulOperand_2x8xi16, | ||
rhs, AIE2MatMulOperand_8x8xi8, | ||
acc, AIE2MatMulOperand_2x8xi64>, | ||
VectorTypesMatch<lhs, AIE2MatMulOperand_4x8xi16, | ||
rhs, AIE2MatMulOperand_8x4xi8, | ||
acc, AIE2MatMulOperand_4x4xi64>, | ||
VectorTypesMatch<lhs, AIE2MatMulOperand_2x4xi16, | ||
rhs, AIE2MatMulOperand_4x8xi16, | ||
acc, AIE2MatMulOperand_2x8xi64>, | ||
VectorTypesMatch<lhs, AIE2MatMulOperand_4x4xi16, | ||
rhs, AIE2MatMulOperand_4x4xi16, | ||
acc, AIE2MatMulOperand_4x4xi64>, | ||
VectorTypesMatch<lhs, AIE2MatMulOperand_4x2xi32, | ||
rhs, AIE2MatMulOperand_2x4xi16, | ||
acc, AIE2MatMulOperand_4x4xi64>, | ||
|
||
VectorTypesMatch<lhs, AIE2MatMulOperand_4x8xbf16, | ||
rhs, AIE2MatMulOperand_8x4xbf16, | ||
acc, AIE2MatMulOperand_4x4xf32>]>>; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
// RUN: aie-opt %s -split-input-file -verify-diagnostics | ||
|
||
func.func @invalidElementType(%A : vector<4x8xf16>, %B : vector<8x4xf16>, | ||
%C : vector<4x4xf32>) -> vector<4x4xf32> { | ||
// expected-error @+1 {{op operand #0 must be a vector compatible with a lhs operand of matrix-multiply and accumulate, but got 'vector<4x8xf16>'}} | ||
%0 = aievec.matmul %A, %B, %C : vector<4x8xf16>, vector<8x4xf16> | ||
into vector<4x4xf32> | ||
return %0 : vector<4x4xf32> | ||
} | ||
|
||
// ----- | ||
|
||
func.func @invalidShape(%A : vector<4x4xbf16>, %B : vector<4x4xbf16>, | ||
%C : vector<4x4xf32>) -> vector<4x4xf32> { | ||
// expected-error @+1 {{op operand #0 must be a vector compatible with a lhs operand of matrix-multiply and accumulate, but got 'vector<4x4xbf16>'}} | ||
%0 = aievec.matmul %A, %B, %C : vector<4x4xbf16>, vector<4x4xbf16> | ||
into vector<4x4xf32> | ||
return %0 : vector<4x4xf32> | ||
} | ||
|
||
// ----- | ||
|
||
func.func @invalidContraction(%A : vector<2x4xi16>, %B : vector<2x8xi16>, | ||
%C : vector<4x8xi32>) -> vector<4x8xi32> { | ||
// expected-error @+1 {{op failed to verify that [lhs x rhs = acc] is a valid contraction}} | ||
%0 = aievec.matmul %A, %B, %C : vector<2x4xi16>, vector<2x8xi16> | ||
into vector<4x8xi32> | ||
return %0 : vector<4x8xi32> | ||
} | ||
|
||
// ----- | ||
|
||
func.func @invalidAccumulatorType(%A : vector<2x4xi16>, %B : vector<4x8xi16>, | ||
%C : vector<2x8xi32>) -> vector<2x8xi32> { | ||
// expected-error @+1 {{op operand #2 must be a vector compatible with an accumulator of matrix-multiply and accumulate, but got 'vector<2x8xi32>'}} | ||
%0 = aievec.matmul %A, %B, %C : vector<2x4xi16>, vector<4x8xi16> | ||
into vector<2x8xi32> | ||
return %0 : vector<2x8xi32> | ||
} |
Oops, something went wrong.