diff --git a/include/aie/Dialect/AIEVec/IR/AIEVecOps.td b/include/aie/Dialect/AIEVec/IR/AIEVecOps.td index 483230bf90..91e238c8d6 100644 --- a/include/aie/Dialect/AIEVec/IR/AIEVecOps.td +++ b/include/aie/Dialect/AIEVec/IR/AIEVecOps.td @@ -14,6 +14,7 @@ #define AIEVEC_OPS include "aie/Dialect/AIEVec/IR/AIEVecTypes.td" +include "aie/Dialect/AIEVec/IR/AIEVecTypeConstraints.td" include "mlir/Interfaces/SideEffectInterfaces.td" // Base class for AIE dialect ops. @@ -727,4 +728,43 @@ def AIEVec_ExtElemOp: let assemblyFormat = "$source `,` $index attr-dict `:` type($source) `,` type($index) `,` type($result)"; let hasVerifier = 0; } + +def AIEVec_MatMulOp: + AIEVec_Op<"matmul", [ + Pure, + AllRanksMatch<["lhs", "rhs", "acc"]>, + AllTypesMatch<["acc", "result"]>, + ShapesCompatibleWithContraction<"lhs", "rhs", "acc">, + IsValidAIE2MatMulShapeAndType<"lhs", "rhs", "acc"> + ]>, + Arguments<(ins AIE2MatMulLHS:$lhs, + AIE2MatMulRHS:$rhs, + AIE2MatMulACC:$acc)>, + Results<(outs AIE2MatMulACC:$result)> { + let summary = "AIEML matrix-multiply and accummulate"; + let description = [{ + AMD AIEv2-specific intrinsic that performs a matrix multiplications + between `lhs` and `rhs`, and accumulates the result in `acc`. + + Currently, this intrinsic supports the following type combinations: + + lhs | rhs | Accumulator + :------------------:|:------------------:|:-----------------: + `vector<4x16xi8>` | `vector<16x8xi4>` | `vector<4x8xi32>` + `vector<4x8xi8>` | `vector<8x8xi8>` | `vector<4x8xi32>` + `vector<4x4xi16>` | `vector<4x8xi8>` | `vector<4x8xi32>` + `vector<4x2xi16>` | `vector<2x8xi16>` | `vector<4x8xi32>` + `vector<4x4xi16>` | `vector<4x4xi8>` | `vector<4x4xi32>` + `vector<2x8xi16>` | `vector<8x8xi8>` | `vector<2x8xi64>` + `vector<4x8xi16>` | `vector<8x4xi8>` | `vector<4x4xi64>` + `vector<2x4xi16>` | `vector<4x8xi16>` | `vector<2x8xi64>` + `vector<4x4xi16>` | `vector<4x4xi16>` | `vector<4x4xi64>` + `vector<4x2xi32>` | `vector<2x4xi16>` | `vector<4x4xi64>` + `vector<4x8xbf16>` | `vector<8x4xbf16>` | `vector<4x4xf32>` + }]; + let assemblyFormat = [{$lhs `,` $rhs `,` $acc attr-dict `:` type($lhs) `,` + type($rhs) `into` type($acc)}]; + let hasVerifier = 0; +} + #endif // AIEVEC_OPS diff --git a/include/aie/Dialect/AIEVec/IR/AIEVecTypeConstraints.td b/include/aie/Dialect/AIEVec/IR/AIEVecTypeConstraints.td new file mode 100644 index 0000000000..391604d6e2 --- /dev/null +++ b/include/aie/Dialect/AIEVec/IR/AIEVecTypeConstraints.td @@ -0,0 +1,163 @@ +//===- AIEVecTypeConstraints.td - AIEVec type constraints--*- tablegen -*-====// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// (c) Copyright 2023 AMD Inc. +// + +//===----------------------------------------------------------------------===// +// Extra type constraint definitions for AIEVec operations. +//===----------------------------------------------------------------------===// + +include "mlir/IR/BuiltinTypes.td" +include "mlir/IR/OpBase.td" + +def I4 : I<4>; + +class TypeShape : + StrFunc<"cast<::mlir::ShapedType>($" # name # ").getShape()">; + +// Notice: This predicate class assumes that the type has been verified to be a +// ShapedType +class VectorOfShape shape> : + CPred.result # " == ArrayRef({" # + !interleave(shape, ", ") # "})">; + +// Notice: This predicate class assumes that the type has been verified to be a +// ShapedType +class VectorOfElementType : + SubstLeaves<"$_self", ElementType<"_self">.result, type.predicate>; + +// Notice: This predicate class assumes that the type has been verified to be a +// ShapedType +class VectorOfShapeAndType shape, Type type> : + Type, + VectorOfElementType]>, + "vector of shape <" # !interleave(shape, "x") # "> and", + "::mlir::VectorType">; + + +// Notice: These predicate definitions assume that the type has been verified to +// be a ShapedType +def AIE2MatMulOperand_4x16xi8 : VectorOfShapeAndType<[4, 16], I8>; +def AIE2MatMulOperand_4x8xi8 : VectorOfShapeAndType<[4, 8], I8>; +def AIE2MatMulOperand_4x4xi16 : VectorOfShapeAndType<[4, 4], I16>; +def AIE2MatMulOperand_4x2xi16 : VectorOfShapeAndType<[4, 2], I16>; +def AIE2MatMulOperand_2x8xi16 : VectorOfShapeAndType<[2, 8], I16>; +def AIE2MatMulOperand_4x8xi16 : VectorOfShapeAndType<[4, 8], I16>; +def AIE2MatMulOperand_2x4xi16 : VectorOfShapeAndType<[2, 4], I16>; +def AIE2MatMulOperand_4x2xi32 : VectorOfShapeAndType<[4, 2], I32>; +def AIE2MatMulOperand_4x8xbf16 : VectorOfShapeAndType<[4, 8], BF16>; + +def AIE2MatMulOperand_16x8xi4 : VectorOfShapeAndType<[16, 8], I4>; +def AIE2MatMulOperand_8x8xi8 : VectorOfShapeAndType<[8, 8], I8>; +def AIE2MatMulOperand_4x4xi8 : VectorOfShapeAndType<[4, 4], I8>; +def AIE2MatMulOperand_8x4xi8 : VectorOfShapeAndType<[8, 4], I8>; +def AIE2MatMulOperand_8x4xbf16 : VectorOfShapeAndType<[8, 4], BF16>; + +def AIE2MatMulOperand_4x8xi32 : VectorOfShapeAndType<[4, 8], I32>; +def AIE2MatMulOperand_4x4xi32 : VectorOfShapeAndType<[4, 4], I32>; +def AIE2MatMulOperand_2x8xi64 : VectorOfShapeAndType<[2, 8], I64>; +def AIE2MatMulOperand_4x4xi64 : VectorOfShapeAndType<[4, 4], I64>; +def AIE2MatMulOperand_4x4xf32 : VectorOfShapeAndType<[4, 4], F32>; + +def AIE2MatMulLHS : + AnyTypeOf<[AIE2MatMulOperand_4x16xi8, + AIE2MatMulOperand_4x8xi8, + AIE2MatMulOperand_4x4xi16, + AIE2MatMulOperand_4x2xi16, + AIE2MatMulOperand_2x8xi16, + AIE2MatMulOperand_4x8xi16, + AIE2MatMulOperand_2x4xi16, + AIE2MatMulOperand_4x2xi32, + AIE2MatMulOperand_4x8xbf16], + "a vector compatible with a lhs operand of matrix-multiply and " + # "accumulate", + "::mlir::VectorType">; + +def AIE2MatMulRHS : + AnyTypeOf<[AIE2MatMulOperand_16x8xi4, + AIE2MatMulOperand_8x8xi8, + AIE2MatMulOperand_4x8xi8, + AIE2MatMulOperand_2x8xi16, + AIE2MatMulOperand_4x4xi8, + AIE2MatMulOperand_8x4xi8, + AIE2MatMulOperand_4x8xi16, + AIE2MatMulOperand_4x4xi16, + AIE2MatMulOperand_2x4xi16, + AIE2MatMulOperand_8x4xbf16], + "a vector compatible with a rhs operand of matrix-multiply and " + # "accumulate", + "::mlir::VectorType">; + +def AIE2MatMulACC : + AnyTypeOf<[AIE2MatMulOperand_4x8xi32, + AIE2MatMulOperand_4x4xi32, + AIE2MatMulOperand_2x8xi64, + AIE2MatMulOperand_4x4xi64, + AIE2MatMulOperand_4x4xf32], + "a vector compatible with an accumulator of matrix-multiply and " + # "accumulate", + "::mlir::VectorType">; + +class ShapeDimsMatch : + CPred.result # "[" # ld # "] == " # + Shape.result # "[" # rd # "]">; + +class ShapesCompatibleWithContraction : + PredOpTrait<"[" # lhs # " x " # rhs # " = " # acc # + "] is a valid contraction", + And<[ShapeDimsMatch, + ShapeDimsMatch, + ShapeDimsMatch]>>; + +class VectorType : StrFunc<"cast($" # name # + ".getType())">; + +class VectorTypesMatch : + And<[SubstLeaves<"$_self", VectorType.result, t1.predicate>, + SubstLeaves<"$_self", VectorType.result, t2.predicate>, + SubstLeaves<"$_self", VectorType.result, t3.predicate>]>; + +class IsValidAIE2MatMulShapeAndType : + PredOpTrait, + VectorTypesMatch, + VectorTypesMatch, + VectorTypesMatch, + VectorTypesMatch, + + VectorTypesMatch, + VectorTypesMatch, + VectorTypesMatch, + VectorTypesMatch, + VectorTypesMatch, + + VectorTypesMatch]>>; diff --git a/test/dialect/AIEVec/invalid.mlir b/test/dialect/AIEVec/invalid.mlir new file mode 100644 index 0000000000..05b274cac2 --- /dev/null +++ b/test/dialect/AIEVec/invalid.mlir @@ -0,0 +1,39 @@ +// RUN: aie-opt %s -split-input-file -verify-diagnostics + +func.func @invalidElementType(%A : vector<4x8xf16>, %B : vector<8x4xf16>, + %C : vector<4x4xf32>) -> vector<4x4xf32> { + // expected-error @+1 {{op operand #0 must be a vector compatible with a lhs operand of matrix-multiply and accumulate, but got 'vector<4x8xf16>'}} + %0 = aievec.matmul %A, %B, %C : vector<4x8xf16>, vector<8x4xf16> + into vector<4x4xf32> + return %0 : vector<4x4xf32> +} + +// ----- + +func.func @invalidShape(%A : vector<4x4xbf16>, %B : vector<4x4xbf16>, + %C : vector<4x4xf32>) -> vector<4x4xf32> { + // expected-error @+1 {{op operand #0 must be a vector compatible with a lhs operand of matrix-multiply and accumulate, but got 'vector<4x4xbf16>'}} + %0 = aievec.matmul %A, %B, %C : vector<4x4xbf16>, vector<4x4xbf16> + into vector<4x4xf32> + return %0 : vector<4x4xf32> +} + +// ----- + +func.func @invalidContraction(%A : vector<2x4xi16>, %B : vector<2x8xi16>, + %C : vector<4x8xi32>) -> vector<4x8xi32> { + // expected-error @+1 {{op failed to verify that [lhs x rhs = acc] is a valid contraction}} + %0 = aievec.matmul %A, %B, %C : vector<2x4xi16>, vector<2x8xi16> + into vector<4x8xi32> + return %0 : vector<4x8xi32> +} + +// ----- + +func.func @invalidAccumulatorType(%A : vector<2x4xi16>, %B : vector<4x8xi16>, + %C : vector<2x8xi32>) -> vector<2x8xi32> { + // expected-error @+1 {{op operand #2 must be a vector compatible with an accumulator of matrix-multiply and accumulate, but got 'vector<2x8xi32>'}} + %0 = aievec.matmul %A, %B, %C : vector<2x4xi16>, vector<4x8xi16> + into vector<2x8xi32> + return %0 : vector<2x8xi32> +} diff --git a/test/dialect/AIEVec/roundtrip.mlir b/test/dialect/AIEVec/roundtrip.mlir new file mode 100644 index 0000000000..28e81ba3cf --- /dev/null +++ b/test/dialect/AIEVec/roundtrip.mlir @@ -0,0 +1,95 @@ +// RUN: aie-opt %s -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: @matmul_bf16 +// CHECK-SAME: %[[A:.*]]: vector<4x8xbf16> +// CHECK-SAME: %[[B:.*]]: vector<8x4xbf16> +// CHECK-SAME: %[[C:.*]]: vector<4x4xf32> +// CHECK: %[[RES:.*]] = aievec.matmul %[[A]], %[[B]], %[[C]] : +// CHECK-SAME: vector<4x8xbf16>, vector<8x4xbf16> into vector<4x4xf32> +// CHECK: return %[[RES]] : vector<4x4xf32> +func.func @matmul_bf16(%A : vector<4x8xbf16>, %B : vector<8x4xbf16>, + %C : vector<4x4xf32>) -> vector<4x4xf32> { + %0 = aievec.matmul %A, %B, %C : vector<4x8xbf16>, vector<8x4xbf16> + into vector<4x4xf32> + return %0 : vector<4x4xf32> +} + +// ----- + +// CHECK-LABEL: @matmul_f32 +// CHECK-SAME: %[[A:.*]]: vector<4x8xf32> +// CHECK-SAME: %[[B:.*]]: vector<8x4xf32> +// CHECK-SAME: %[[C:.*]]: vector<4x4xf32> +// CHECK: %[[RES:.*]] = aievec.matmul %[[A]], %[[B]], %[[C]] : +// CHECK-SAME: vector<4x8xf32>, vector<8x4xf32> into vector<4x4xf32> +// CHECK: return %[[RES]] : vector<4x4xf32> +func.func @matmul_f32(%A : vector<4x8xf32>, %B : vector<8x4xf32>, + %C : vector<4x4xf32>) -> vector<4x4xf32> { + %0 = aievec.matmul %A, %B, %C : vector<4x8xf32>, vector<8x4xf32> + into vector<4x4xf32> + return %0 : vector<4x4xf32> +} + +// ----- + +// CHECK-LABEL: @matmul_i8 +// CHECK-SAME: %[[A:.*]]: vector<4x8xi8> +// CHECK-SAME: %[[B:.*]]: vector<8x8xi8> +// CHECK-SAME: %[[C:.*]]: vector<4x8xi32> +// CHECK: %[[RES:.*]] = aievec.matmul %[[A]], %[[B]], %[[C]] : +// CHECK-SAME: vector<4x8xi8>, vector<8x8xi8> into vector<4x8xi32> +// CHECK: return %[[RES]] : vector<4x8xi32> +func.func @matmul_i8(%A : vector<4x8xi8>, %B : vector<8x8xi8>, + %C : vector<4x8xi32>) -> vector<4x8xi32> { + %0 = aievec.matmul %A, %B, %C : vector<4x8xi8>, vector<8x8xi8> + into vector<4x8xi32> + return %0 : vector<4x8xi32> +} + +// ----- + +// CHECK-LABEL: @matmul_i16a +// CHECK-SAME: %[[A:.*]]: vector<2x4xi16> +// CHECK-SAME: %[[B:.*]]: vector<4x8xi16> +// CHECK-SAME: %[[C:.*]]: vector<2x8xi64> +// CHECK: %[[RES:.*]] = aievec.matmul %[[A]], %[[B]], %[[C]] : +// CHECK-SAME: vector<2x4xi16>, vector<4x8xi16> into vector<2x8xi64> +// CHECK: return %[[RES]] : vector<2x8xi64> +func.func @matmul_i16a(%A : vector<2x4xi16>, %B : vector<4x8xi16>, + %C : vector<2x8xi64>) -> vector<2x8xi64> { + %0 = aievec.matmul %A, %B, %C : vector<2x4xi16>, vector<4x8xi16> + into vector<2x8xi64> + return %0 : vector<2x8xi64> +} + +// ----- + +// CHECK-LABEL: @matmul_i16b +// CHECK-SAME: %[[A:.*]]: vector<4x2xi16> +// CHECK-SAME: %[[B:.*]]: vector<2x8xi16> +// CHECK-SAME: %[[C:.*]]: vector<4x8xi32> +// CHECK: %[[RES:.*]] = aievec.matmul %[[A]], %[[B]], %[[C]] : +// CHECK-SAME: vector<4x2xi16>, vector<2x8xi16> into vector<4x8xi32> +// CHECK: return %[[RES]] : vector<4x8xi32> +func.func @matmul_i16b(%A : vector<4x2xi16>, %B : vector<2x8xi16>, + %C : vector<4x8xi32>) -> vector<4x8xi32> { + %0 = aievec.matmul %A, %B, %C : vector<4x2xi16>, vector<2x8xi16> + into vector<4x8xi32> + return %0 : vector<4x8xi32> +} + +// ----- + +// CHECK-LABEL: @matmul_i32 +// CHECK-SAME: %[[A:.*]]: vector<4x2xi32> +// CHECK-SAME: %[[B:.*]]: vector<2x4xi32> +// CHECK-SAME: %[[C:.*]]: vector<4x4xi64> +// CHECK: %[[RES:.*]] = aievec.matmul %[[A]], %[[B]], %[[C]] : +// CHECK-SAME: vector<4x2xi32>, vector<2x4xi32> into vector<4x4xi64> +// CHECK: return %[[RES]] : vector<4x4xi64> +func.func @matmul_i32(%A : vector<4x2xi32>, %B : vector<2x4xi32>, + %C : vector<4x4xi64>) -> vector<4x4xi64> { + %0 = aievec.matmul %A, %B, %C : vector<4x2xi32>, vector<2x4xi32> + into vector<4x4xi64> + return %0 : vector<4x4xi64> +}