Skip to content

Commit

Permalink
[aievec] Add AIEml intrinsic for bf16/f32 types
Browse files Browse the repository at this point in the history
  • Loading branch information
jsetoain committed Nov 8, 2023
1 parent 9e037e6 commit 7114fb1
Show file tree
Hide file tree
Showing 4 changed files with 371 additions and 2 deletions.
41 changes: 39 additions & 2 deletions include/aie/Dialect/AIEVec/IR/AIEVecOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
#ifndef AIEVEC_OPS
#define AIEVEC_OPS

include "aie/Dialect/AIEVec/IR/AIEVecTypes.td"
include "aie/Dialect/AIE/IR/AIEAttrs.td"

include "aie/Dialect/AIEVec/IR/AIEVecTypes.td"
include "aie/Dialect/AIEVec/IR/AIEVecTypeConstraints.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

// Base class for AIE dialect ops.
Expand Down Expand Up @@ -810,4 +810,41 @@ def AIEVec_BandOp:
let hasVerifier = 0;
}

def AIEVec_MatMulOp:
AIEVec_Op<"matmul", [
Pure,
AllRanksMatch<["lhs", "rhs", "acc"]>,
AllTypesMatch<["acc", "result"]>,
ShapesCompatibleWithContraction<"lhs", "rhs", "acc">,
IsValidAIE2MatMulShapeAndType<"lhs", "rhs", "acc">
]>,
Arguments<(ins AIE2MatMulLHS:$lhs,
AIE2MatMulRHS:$rhs,
AIE2MatMulACC:$acc)>,
Results<(outs AIE2MatMulACC:$result)> {
let summary = "AIEML matrix-multiply and accummulate";
let description = [{
AMD AIEv2-specific intrinsic that performs a matrix multiplications
between `lhs` and `rhs`, and accumulates the result in `acc`.

Currently, this intrinsic supports the following type combinations:

lhs | rhs | Accumulator
:------------------:|:------------------:|:-----------------:
`vector<4x16xi8>` | `vector<16x8xi4>` | `vector<4x8xi32>`
`vector<4x8xi8>` | `vector<8x8xi8>` | `vector<4x8xi32>`
`vector<4x4xi16>` | `vector<4x8xi8>` | `vector<4x8xi32>`
`vector<4x2xi16>` | `vector<2x8xi16>` | `vector<4x8xi32>`
`vector<2x8xi16>` | `vector<8x8xi8>` | `vector<2x8xi64>`
`vector<4x8xi16>` | `vector<8x4xi8>` | `vector<4x4xi64>`
`vector<2x4xi16>` | `vector<4x8xi16>` | `vector<2x8xi64>`
`vector<4x4xi16>` | `vector<4x4xi16>` | `vector<4x4xi64>`
`vector<4x2xi32>` | `vector<2x4xi16>` | `vector<4x4xi64>`
`vector<4x8xbf16>` | `vector<8x4xbf16>` | `vector<4x4xf32>`
}];
let assemblyFormat = [{$lhs `,` $rhs `,` $acc attr-dict `:` type($lhs) `,`
type($rhs) `into` type($acc)}];
let hasVerifier = 0;
}

#endif // AIEVEC_OPS
134 changes: 134 additions & 0 deletions include/aie/Dialect/AIEVec/IR/AIEVecTypeConstraints.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
//===- AIEVecTypeConstraints.td - AIEVec type constraints--*- tablegen -*-====//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// (c) Copyright 2023 AMD Inc.
//

//===----------------------------------------------------------------------===//
// Extra type constraint definitions for AIEVec operations.
//===----------------------------------------------------------------------===//

include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/OpBase.td"

def I4 : I<4>;

class TypeShape<string name> :
StrFunc<"cast<::mlir::ShapedType>($" # name # ").getShape()">;

// Notice: This predicate class assumes that the type has been verified to be a
// ShapedType
class VectorOfShape<list<int> shape> :
CPred<TypeShape<"_self">.result # " == ArrayRef<int64_t>({" #
!interleave(shape, ", ") # "})">;

// Notice: This predicate class assumes that the type has been verified to be a
// ShapedType
class VectorOfElementType<Type type> :
SubstLeaves<"$_self", ElementType<"_self">.result, type.predicate>;

// Notice: This predicate class assumes that the type has been verified to be a
// ShapedType
class VectorOfShapeAndType<list<int> shape, Type type> :
Type<And<[VectorOfShape<shape>,
VectorOfElementType<type>]>,
"vector of shape <" # !interleave(shape, "x") # "> and",
"::mlir::VectorType">;

def AIE2MatMulLHS :
AnyTypeOf<[VectorOfShapeAndType<[4, 16], I8>,
VectorOfShapeAndType<[4, 8], I8>,
VectorOfShapeAndType<[4, 4], I16>,
VectorOfShapeAndType<[4, 2], I16>,
VectorOfShapeAndType<[2, 8], I16>,
VectorOfShapeAndType<[4, 8], I16>,
VectorOfShapeAndType<[2, 4], I16>,
VectorOfShapeAndType<[4, 2], I32>,
VectorOfShapeAndType<[4, 8], BF16>],
"a vector compatible with a lhs operand of matrix-multiply and "
# "accumulate",
"::mlir::VectorType">;

def AIE2MatMulRHS :
AnyTypeOf<[VectorOfShapeAndType<[16, 8], I4>,
VectorOfShapeAndType<[8, 8], I8>,
VectorOfShapeAndType<[4, 8], I8>,
VectorOfShapeAndType<[2, 8], I16>,
VectorOfShapeAndType<[8, 4], I8>,
VectorOfShapeAndType<[4, 8], I16>,
VectorOfShapeAndType<[4, 4], I16>,
VectorOfShapeAndType<[2, 4], I16>,
VectorOfShapeAndType<[8, 4], BF16>],
"a vector compatible with a rhs operand of matrix-multiply and "
# "accumulate",
"::mlir::VectorType">;

def AIE2MatMulACC :
AnyTypeOf<[VectorOfShapeAndType<[4, 8], I32>,
VectorOfShapeAndType<[4, 4], I32>,
VectorOfShapeAndType<[2, 8], I64>,
VectorOfShapeAndType<[4, 4], I64>,
VectorOfShapeAndType<[4, 4], F32>],
"a vector compatible with an accumulator of matrix-multiply and "
# "accumulate",
"::mlir::VectorType">;

class ShapeDimsMatch<string lhs, int ld, string rhs, int rd> :
CPred<Shape<lhs>.result # "[" # ld # "] == " #
Shape<rhs>.result # "[" # rd # "]">;

class ShapesCompatibleWithContraction<string lhs, string rhs, string acc> :
PredOpTrait<"[" # lhs # " x " # rhs # " = " # acc #
"] is a valid contraction",
And<[ShapeDimsMatch<lhs, 1, rhs, 0>,
ShapeDimsMatch<lhs, 0, acc, 0>,
ShapeDimsMatch<rhs, 1, acc, 1>]>>;

class VectorType<string name> : StrFunc<"cast<VectorType>($" # name #
".getType())">;

class VectorTypesMatch<string op1, Type t1,
string op2, Type t2,
string op3, Type t3> :
And<[SubstLeaves<"$_self", VectorType<op1>.result, t1.predicate>,
SubstLeaves<"$_self", VectorType<op2>.result, t2.predicate>,
SubstLeaves<"$_self", VectorType<op3>.result, t3.predicate>]>;

class IsValidAIE2MatMulShapeAndType<string lhs, string rhs, string acc> :
PredOpTrait<lhs # " x " # rhs # " = " # acc # " is a valid AIE2 " #
"matrix-multiply and accumulate op",
Or<[VectorTypesMatch<lhs, VectorOfShapeAndType<[4, 16], I8>,
rhs, VectorOfShapeAndType<[16, 8], I4>,
acc, VectorOfShapeAndType<[4, 8], I32>>,
VectorTypesMatch<lhs, VectorOfShapeAndType<[4, 8], I8>,
rhs, VectorOfShapeAndType<[8, 8], I8>,
acc, VectorOfShapeAndType<[4, 8], I32>>,
VectorTypesMatch<lhs, VectorOfShapeAndType<[4, 4], I16>,
rhs, VectorOfShapeAndType<[4, 8], I8>,
acc, VectorOfShapeAndType<[4, 8], I32>>,
VectorTypesMatch<lhs, VectorOfShapeAndType<[4, 2], I16>,
rhs, VectorOfShapeAndType<[2, 8], I16>,
acc, VectorOfShapeAndType<[4, 8], I32>>,

VectorTypesMatch<lhs, VectorOfShapeAndType<[2, 8], I16>,
rhs, VectorOfShapeAndType<[8, 8], I8>,
acc, VectorOfShapeAndType<[2, 8], I64>>,
VectorTypesMatch<lhs, VectorOfShapeAndType<[4, 8], I16>,
rhs, VectorOfShapeAndType<[8, 4], I8>,
acc, VectorOfShapeAndType<[4, 4], I64>>,
VectorTypesMatch<lhs, VectorOfShapeAndType<[2, 4], I16>,
rhs, VectorOfShapeAndType<[4, 8], I16>,
acc, VectorOfShapeAndType<[2, 8], I64>>,
VectorTypesMatch<lhs, VectorOfShapeAndType<[4, 4], I16>,
rhs, VectorOfShapeAndType<[4, 4], I16>,
acc, VectorOfShapeAndType<[4, 4], I64>>,
VectorTypesMatch<lhs, VectorOfShapeAndType<[4, 2], I32>,
rhs, VectorOfShapeAndType<[2, 4], I16>,
acc, VectorOfShapeAndType<[4, 4], I64>>,

VectorTypesMatch<lhs, VectorOfShapeAndType<[4, 8], BF16>,
rhs, VectorOfShapeAndType<[8, 4], BF16>,
acc, VectorOfShapeAndType<[4, 4], F32>>]>>;
39 changes: 39 additions & 0 deletions test/dialect/AIEVec/invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// RUN: aie-opt %s -split-input-file -verify-diagnostics

func.func @invalidElementType(%A : vector<4x8xf16>, %B : vector<8x4xf16>,
%C : vector<4x4xf32>) -> vector<4x4xf32> {
// expected-error @+1 {{op operand #0 must be a vector compatible with a lhs operand of matrix-multiply and accumulate, but got 'vector<4x8xf16>'}}
%0 = aievec.matmul %A, %B, %C : vector<4x8xf16>, vector<8x4xf16>
into vector<4x4xf32>
return %0 : vector<4x4xf32>
}

// -----

func.func @invalidShape(%A : vector<4x4xbf16>, %B : vector<4x4xbf16>,
%C : vector<4x4xf32>) -> vector<4x4xf32> {
// expected-error @+1 {{op operand #0 must be a vector compatible with a lhs operand of matrix-multiply and accumulate, but got 'vector<4x4xbf16>'}}
%0 = aievec.matmul %A, %B, %C : vector<4x4xbf16>, vector<4x4xbf16>
into vector<4x4xf32>
return %0 : vector<4x4xf32>
}

// -----

func.func @invalidContraction(%A : vector<2x4xi16>, %B : vector<2x8xi16>,
%C : vector<4x8xi32>) -> vector<4x8xi32> {
// expected-error @+1 {{op failed to verify that [lhs x rhs = acc] is a valid contraction}}
%0 = aievec.matmul %A, %B, %C : vector<2x4xi16>, vector<2x8xi16>
into vector<4x8xi32>
return %0 : vector<4x8xi32>
}

// -----

func.func @invalidAccumulatorType(%A : vector<2x4xi16>, %B : vector<4x8xi16>,
%C : vector<2x8xi32>) -> vector<2x8xi32> {
// expected-error @+1 {{op operand #2 must be a vector compatible with an accumulator of matrix-multiply and accumulate, but got 'vector<2x8xi32>'}}
%0 = aievec.matmul %A, %B, %C : vector<2x4xi16>, vector<4x8xi16>
into vector<2x8xi32>
return %0 : vector<2x8xi32>
}
159 changes: 159 additions & 0 deletions test/dialect/AIEVec/roundtrip.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
// RUN: aie-opt %s -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: @matmul_i8i4
// CHECK-SAME: %[[A:.*]]: vector<4x16xi8>
// CHECK-SAME: %[[B:.*]]: vector<16x8xi4>
// CHECK-SAME: %[[C:.*]]: vector<4x8xi32>
// CHECK: %[[RES:.*]] = aievec.matmul %[[A]], %[[B]], %[[C]] :
// CHECK-SAME: vector<4x16xi8>, vector<16x8xi4> into vector<4x8xi32>
// CHECK: return %[[RES]] : vector<4x8xi32>
func.func @matmul_i8i4(%A : vector<4x16xi8>, %B : vector<16x8xi4>,
%C : vector<4x8xi32>) -> vector<4x8xi32> {
%0 = aievec.matmul %A, %B, %C : vector<4x16xi8>, vector<16x8xi4>
into vector<4x8xi32>
return %0 : vector<4x8xi32>
}

// -----

// CHECK-LABEL: @matmul_i8i8
// CHECK-SAME: %[[A:.*]]: vector<4x8xi8>
// CHECK-SAME: %[[B:.*]]: vector<8x8xi8>
// CHECK-SAME: %[[C:.*]]: vector<4x8xi32>
// CHECK: %[[RES:.*]] = aievec.matmul %[[A]], %[[B]], %[[C]] :
// CHECK-SAME: vector<4x8xi8>, vector<8x8xi8> into vector<4x8xi32>
// CHECK: return %[[RES]] : vector<4x8xi32>
func.func @matmul_i8i8(%A : vector<4x8xi8>, %B : vector<8x8xi8>,
%C : vector<4x8xi32>) -> vector<4x8xi32> {
%0 = aievec.matmul %A, %B, %C : vector<4x8xi8>, vector<8x8xi8>
into vector<4x8xi32>
return %0 : vector<4x8xi32>
}

// -----

// CHECK-LABEL: @matmul_i16i8a
// CHECK-SAME: %[[A:.*]]: vector<4x4xi16>
// CHECK-SAME: %[[B:.*]]: vector<4x8xi8>
// CHECK-SAME: %[[C:.*]]: vector<4x8xi32>
// CHECK: %[[RES:.*]] = aievec.matmul %[[A]], %[[B]], %[[C]] :
// CHECK-SAME: vector<4x4xi16>, vector<4x8xi8> into vector<4x8xi32>
// CHECK: return %[[RES]] : vector<4x8xi32>
func.func @matmul_i16i8a(%A : vector<4x4xi16>, %B : vector<4x8xi8>,
%C : vector<4x8xi32>) -> vector<4x8xi32> {
%0 = aievec.matmul %A, %B, %C : vector<4x4xi16>, vector<4x8xi8>
into vector<4x8xi32>
return %0 : vector<4x8xi32>
}

// -----

// CHECK-LABEL: @matmul_i16i16a
// CHECK-SAME: %[[A:.*]]: vector<4x2xi16>
// CHECK-SAME: %[[B:.*]]: vector<2x8xi16>
// CHECK-SAME: %[[C:.*]]: vector<4x8xi32>
// CHECK: %[[RES:.*]] = aievec.matmul %[[A]], %[[B]], %[[C]] :
// CHECK-SAME: vector<4x2xi16>, vector<2x8xi16> into vector<4x8xi32>
// CHECK: return %[[RES]] : vector<4x8xi32>
func.func @matmul_i16i16a(%A : vector<4x2xi16>, %B : vector<2x8xi16>,
%C : vector<4x8xi32>) -> vector<4x8xi32> {
%0 = aievec.matmul %A, %B, %C : vector<4x2xi16>, vector<2x8xi16>
into vector<4x8xi32>
return %0 : vector<4x8xi32>
}

// -----

// CHECK-LABEL: @matmul_i16i8b
// CHECK-SAME: %[[A:.*]]: vector<2x8xi16>
// CHECK-SAME: %[[B:.*]]: vector<8x8xi8>
// CHECK-SAME: %[[C:.*]]: vector<2x8xi64>
// CHECK: %[[RES:.*]] = aievec.matmul %[[A]], %[[B]], %[[C]] :
// CHECK-SAME: vector<2x8xi16>, vector<8x8xi8> into vector<2x8xi64>
// CHECK: return %[[RES]] : vector<2x8xi64>
func.func @matmul_i16i8b(%A : vector<2x8xi16>, %B : vector<8x8xi8>,
%C : vector<2x8xi64>) -> vector<2x8xi64> {
%0 = aievec.matmul %A, %B, %C : vector<2x8xi16>, vector<8x8xi8>
into vector<2x8xi64>
return %0 : vector<2x8xi64>
}

// -----

// CHECK-LABEL: @matmul_i16i8c
// CHECK-SAME: %[[A:.*]]: vector<4x8xi16>
// CHECK-SAME: %[[B:.*]]: vector<8x4xi8>
// CHECK-SAME: %[[C:.*]]: vector<4x4xi64>
// CHECK: %[[RES:.*]] = aievec.matmul %[[A]], %[[B]], %[[C]] :
// CHECK-SAME: vector<4x8xi16>, vector<8x4xi8> into vector<4x4xi64>
// CHECK: return %[[RES]] : vector<4x4xi64>
func.func @matmul_i16i8c(%A : vector<4x8xi16>, %B : vector<8x4xi8>,
%C : vector<4x4xi64>) -> vector<4x4xi64> {
%0 = aievec.matmul %A, %B, %C : vector<4x8xi16>, vector<8x4xi8>
into vector<4x4xi64>
return %0 : vector<4x4xi64>
}

// -----

// CHECK-LABEL: @matmul_i16i16b
// CHECK-SAME: %[[A:.*]]: vector<2x4xi16>
// CHECK-SAME: %[[B:.*]]: vector<4x8xi16>
// CHECK-SAME: %[[C:.*]]: vector<2x8xi64>
// CHECK: %[[RES:.*]] = aievec.matmul %[[A]], %[[B]], %[[C]] :
// CHECK-SAME: vector<2x4xi16>, vector<4x8xi16> into vector<2x8xi64>
// CHECK: return %[[RES]] : vector<2x8xi64>
func.func @matmul_i16i16b(%A : vector<2x4xi16>, %B : vector<4x8xi16>,
%C : vector<2x8xi64>) -> vector<2x8xi64> {
%0 = aievec.matmul %A, %B, %C : vector<2x4xi16>, vector<4x8xi16>
into vector<2x8xi64>
return %0 : vector<2x8xi64>
}

// -----

// CHECK-LABEL: @matmul_i16i16c
// CHECK-SAME: %[[A:[a-zA-Z0-9]+]]: vector<4x4xi16>
// CHECK-SAME: %[[B:[a-zA-Z0-9]+]]: vector<4x4xi16>
// CHECK-SAME: %[[C:[a-zA-Z0-9]+]]: vector<4x4xi64>
// CHECK: %[[RES:.*]] = aievec.matmul %[[A]], %[[B]], %[[C]] :
// CHECK-SAME: vector<4x4xi16>, vector<4x4xi16> into vector<4x4xi64>
// CHECK: return %[[RES]] : vector<4x4xi64>
func.func @matmul_i16i16c(%A : vector<4x4xi16>, %B : vector<4x4xi16>,
%C : vector<4x4xi64>) -> vector<4x4xi64> {
%0 = aievec.matmul %A, %B, %C : vector<4x4xi16>, vector<4x4xi16>
into vector<4x4xi64>
return %0 : vector<4x4xi64>
}

// -----

// CHECK-LABEL: @matmul_i32i16
// CHECK-SAME: %[[A:.*]]: vector<4x2xi32>
// CHECK-SAME: %[[B:.*]]: vector<2x4xi16>
// CHECK-SAME: %[[C:.*]]: vector<4x4xi64>
// CHECK: %[[RES:.*]] = aievec.matmul %[[A]], %[[B]], %[[C]] :
// CHECK-SAME: vector<4x2xi32>, vector<2x4xi16> into vector<4x4xi64>
// CHECK: return %[[RES]] : vector<4x4xi64>
func.func @matmul_i32i16(%A : vector<4x2xi32>, %B : vector<2x4xi16>,
%C : vector<4x4xi64>) -> vector<4x4xi64> {
%0 = aievec.matmul %A, %B, %C : vector<4x2xi32>, vector<2x4xi16>
into vector<4x4xi64>
return %0 : vector<4x4xi64>
}

// -----

// CHECK-LABEL: @matmul_bf16
// CHECK-SAME: %[[A:.*]]: vector<4x8xbf16>
// CHECK-SAME: %[[B:.*]]: vector<8x4xbf16>
// CHECK-SAME: %[[C:.*]]: vector<4x4xf32>
// CHECK: %[[RES:.*]] = aievec.matmul %[[A]], %[[B]], %[[C]] :
// CHECK-SAME: vector<4x8xbf16>, vector<8x4xbf16> into vector<4x4xf32>
// CHECK: return %[[RES]] : vector<4x4xf32>
func.func @matmul_bf16(%A : vector<4x8xbf16>, %B : vector<8x4xbf16>,
%C : vector<4x4xf32>) -> vector<4x4xf32> {
%0 = aievec.matmul %A, %B, %C : vector<4x8xbf16>, vector<8x4xbf16>
into vector<4x4xf32>
return %0 : vector<4x4xf32>
}

0 comments on commit 7114fb1

Please sign in to comment.