Skip to content

Commit

Permalink
[xllvm] Add support for multi-result intrinsics
Browse files Browse the repository at this point in the history
  • Loading branch information
jsetoain committed May 8, 2024
1 parent a9037bc commit de62a46
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 18 deletions.
39 changes: 39 additions & 0 deletions include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -251,4 +251,43 @@ def VectorExtractElem32I512IntrOp :
I32:$idx,
I32:$sign)>;

// ----- MAX ELEMENT -----
class EnumeratedType<Type ty, int idx> {
Type type = ty;
int index = idx;
}

class EnumerateTypeListFrom<list<Type> tlist, int from = 0> {
list<EnumeratedType> sequence =
!if(!empty(tlist), [],
!listconcat(
[EnumeratedType<!head(tlist), from>],
EnumerateTypeListFrom<!tail(tlist), !add(from, 1)>.sequence
));
}

class LLVM_StructOf<list<Type> structTypes> :
Type<
And<[LLVM_AnyStruct.predicate,
CPred<"cast<::mlir::LLVM::LLVMStructType>($_self).getBody().size() == " # !size(structTypes)>,
And<!foreach(enumTy, EnumerateTypeListFrom<structTypes>.sequence,
SubstLeaves<"$_self",
"cast<::mlir::LLVM::LLVMStructType>($_self).getBody()[" # enumTy.index # "]",
enumTy.type.predicate>
)
>
]>,
"an LLVM struct of {" # !interleave(!foreach(ty, structTypes, ty.summary), "; ") # "}"
>;

def VectorMaxLtBf16IntrOp :
AIEVec2_IntrOp<"vmax.ltbf16",
[TypeIs<"res",
LLVM_StructOf<[
VectorOfLengthAndType<[32], [BF16]>,
I32]>
>], /*numResults=*/2>,
Arguments<(ins VectorOfLengthAndType<[32], [BF16]>:$lhs,
VectorOfLengthAndType<[32], [BF16]>:$rhs)>;

#endif // AIE_DIALECT_XLLVM_IR_XLLVMAIE2INTROPS_TD
12 changes: 8 additions & 4 deletions lib/Dialect/XLLVM/XLLVMOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,16 @@ getNamedIntrinsicDeclaration(llvm::Module *M, llvm::StringRef fullName,
llvm::CallInst *createExternalLLVMIntrinsicCall(
llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
Operation *intrOp, llvm::StringRef intrinsicName) {
// We support 0 or 1 results
assert(intrOp->getNumResults() <= 1 &&
"external multi-result intrinsics not supported");
llvm::Type *resTy = nullptr;
if (intrOp->getNumResults())
unsigned numResults = intrOp->getNumResults();
if (numResults == 1)
resTy = moduleTranslation.convertType(*(intrOp->getResultTypes().begin()));
else if (numResults > 1) {
SmallVector<llvm::Type *> resTys;
for (auto ty : intrOp->getResultTypes())
resTys.push_back(moduleTranslation.convertType(ty));
resTy = llvm::StructType::get(builder.getContext(), resTys);
}
auto operands = moduleTranslation.lookupValues(intrOp->getOperands());
SmallVector<llvm::Type *> types;
for (auto op : operands)
Expand Down
42 changes: 28 additions & 14 deletions test/Target/LLVMIR/aievec.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: aie-translate %s -mlir-to-llvmir | FileCheck %s
// RUN: aie-translate %s -mlir-to-llvmir -split-input-file | FileCheck %s

// ----- MAC -----
// -- MAC --

// CHECK-LABEL: define <16 x i64> @mac_conf_acc32
llvm.func @mac_conf_acc32(%A : vector<64xi8>,
Expand Down Expand Up @@ -30,7 +30,7 @@ llvm.func @mac_conf_bf16(%A : vector<32xbf16>,
llvm.return %0 : vector<8xi64>
}

// ----- MSC -----
// -- MSC --

// CHECK-LABEL: define <8 x i64> @msc_conf_bf16
llvm.func @msc_conf_bf16(%A : vector<32xbf16>,
Expand All @@ -46,7 +46,7 @@ llvm.func @msc_conf_bf16(%A : vector<32xbf16>,
llvm.return %0 : vector<8xi64>
}

// ----- MUL -----
// -- MUL --

// CHECK-LABEL: define <16 x i64> @mul_conf_acc32
llvm.func @mul_conf_acc32(%A : vector<64xi8>,
Expand Down Expand Up @@ -87,7 +87,7 @@ llvm.func @mul_conf_bf16(%A : vector<32xbf16>,
llvm.return %0 : vector<8xi64>
}

// ----- SET -----
// -- SET --

// CHECK-LABEL: define <16 x i32> @vector_set_128b_into_512b
llvm.func @vector_set_128b_into_512b(%v : vector<4xi32>) -> vector<16xi32> {
Expand All @@ -105,7 +105,7 @@ llvm.func @vector_set_256b_into_512b(%v : vector<8xi32>) -> vector<16xi32> {
llvm.return %1 : vector<16xi32>
}

// ----- SRS -----
// -- SRS --

// CHECK-LABEL: define <32 x i16> @srs_512b_v32_acc32
llvm.func @srs_512b_v32_acc32(%v : vector<16xi64>, %shft : i32, %sign : i32) -> vector<32xi16> {
Expand Down Expand Up @@ -142,7 +142,7 @@ llvm.func @srs_256b_v16_accfloat(%v : vector<8xi64>) -> vector<16xbf16> {
llvm.return %0 : vector<16xbf16>
}

// ----- BROADCAST -----
// -- BROADCAST --

// CHECK-LABEL: define <64 x i8> @vbroadcast8_i512
llvm.func @vbroadcast8_i512(%val : i32) -> vector<64xi8> {
Expand Down Expand Up @@ -184,7 +184,7 @@ llvm.func @vbroadcastfloat_i512(%val : f32) -> vector<16xf32> {
llvm.return %0 : vector<16xf32>
}

// ----- EXT -----
// -- EXT --

// CHECK-LABEL: define <8 x i32> @ext_i256_i512
llvm.func @ext_i256_i512(%v : vector<16xi32>, %idx : i32) -> vector<8xi32> {
Expand All @@ -195,7 +195,7 @@ llvm.func @ext_i256_i512(%v : vector<16xi32>, %idx : i32) -> vector<8xi32> {
llvm.return %1 : vector<8xi32>
}

// ----- CONCAT -----
// -- CONCAT --

// CHECK-LABEL: define <16 x i32> @concat_i512_i256
llvm.func @concat_i512_i256(%a : vector<8xi32>, %b : vector<8xi32>) -> vector<16xi32> {
Expand Down Expand Up @@ -226,7 +226,7 @@ llvm.func @concat_i1024_i512(%a : vector<16xi32>, %b : vector<16xi32>) -> vector
llvm.return %0 : vector<32xi32>
}

// ----- SHUFFLE -----
// -- SHUFFLE --

// CHECK-LABEL: define <16 x i32> @shuffle_i512
llvm.func @shuffle_i512(%a : vector<16xi32>, %b : vector<16xi32>, %mode : i32) -> vector<16xi32> {
Expand All @@ -237,7 +237,7 @@ llvm.func @shuffle_i512(%a : vector<16xi32>, %b : vector<16xi32>, %mode : i32) -
llvm.return %0 : vector<16xi32>
}

// ----- UNDEF -----
// -- UNDEF --

// CHECK-LABEL: define <16 x i32> @undef_v16i32
llvm.func @undef_v16i32() -> vector<16xi32> {
Expand All @@ -246,7 +246,7 @@ llvm.func @undef_v16i32() -> vector<16xi32> {
llvm.return %0 : vector<16xi32>
}

// ----- UPD -----
// -- UPD --

// CHECK-LABEL: define <32 x bfloat> @upd_bf512_bf256
llvm.func @upd_bf512_bf256(%a : vector<32xbf16>, %b : vector<16xbf16>, %idx : i32) -> vector<32xbf16> {
Expand All @@ -256,7 +256,7 @@ llvm.func @upd_bf512_bf256(%a : vector<32xbf16>, %b : vector<16xbf16>, %idx : i3
llvm.return %0 : vector<32xbf16>
}

// ----- SHIFT -----
// -- SHIFT --

// CHECK-LABEL: define <16 x i32> @vshift_i512_i512
llvm.func @vshift_i512_i512(%a : vector<16xi32>, %b : vector<16xi32>, %step : i32, %shift : i32) -> vector<16xi32> {
Expand All @@ -276,7 +276,7 @@ llvm.func @vshift_bf512_bf512(%a : vector<32xbf16>, %b : vector<32xbf16>, %step
llvm.return %0 : vector<32xbf16>
}

// ----- EXTRACT ELEMENT -----
// -- EXTRACT ELEMENT --

// CHECK-LABEL: define i32 @vextract_elem8_i512
llvm.func @vextract_elem8_i512(%a : vector<64xi8>, %idx : i32, %sign : i32) -> i32 {
Expand All @@ -301,3 +301,17 @@ llvm.func @vextract_elem32_i512(%a : vector<16xi32>, %idx : i32, %sign : i32) ->
%0 = "xllvm.intr.aie2.vextract.elem32.I512"(%a, %idx, %sign) : (vector<16xi32>, i32, i32) -> i32
llvm.return %0 : i32
}

// -----

// CHECK-LABEL: <32 x bfloat> @vmax_ltbf16
llvm.func @vmax_ltbf16(%lhs: vector<32xbf16>, %rhs: vector<32xbf16>) -> vector<32xbf16> {
// CHECK: call { <32 x bfloat>, i32 } @llvm.aie2.vmax.ltbf16(
// CHECK-SAME: <32 x bfloat> %{{[0-9]+}}, <32 x bfloat> %{{[0-9]+}})
%0 = "xllvm.intr.aie2.vmax.ltbf16"(%lhs, %rhs) :
(vector<32xbf16>, vector<32xbf16>) -> !llvm.struct<(vector<32xbf16>, i32)>
%1 = llvm.extractvalue %0[0] : !llvm.struct<(vector<32xbf16>, i32)>
llvm.return %1 : vector<32xbf16>
}

// CHECK-LABEL: declare { <32 x bfloat>, i32 } @llvm.aie2.vmax.ltbf16(<32 x bfloat>, <32 x bfloat>)
10 changes: 10 additions & 0 deletions test/dialect/XLLVM/invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: aie-opt %s -split-input-file -verify-diagnostics

func.func @invalidStructType(%A : vector<32xbf16>, %B : vector<32xbf16>)
-> vector<16xbf16> {
// expected-error @+1 {{'res' is an LLVM struct of {vector of bfloat16 type values of length 32; 32-bit signless integer}}
%rs = "xllvm.intr.aie2.vmax.ltbf16"(%A, %B) :
(vector<32xbf16>, vector<32xbf16>) -> !llvm.struct<(vector<16xbf16>, i32)>
%rv = llvm.extractvalue %rs[0] : !llvm.struct<(vector<16xbf16>, i32)>
return %rv : vector<16xbf16>
}

0 comments on commit de62a46

Please sign in to comment.