From de62a46a054b178af710e8b9babec2f8489b93ba Mon Sep 17 00:00:00 2001
From: Javier Setoain
Date: Thu, 2 May 2024 01:43:17 -0600
Subject: [PATCH] [xllvm] Add support for multi-result intrinsics

Some AIE2 intrinsics, e.g. `llvm.aie2.vmax.ltbf16`, return several values
packed into a single LLVM struct.

* XLLVMAIE2IntrOps.td: add the `LLVM_StructOf<[...]>` type constraint. It
  checks that a type is an LLVM struct whose body has the expected number
  of elements and that each element satisfies the constraint at the same
  position in the list. Use it to define `VectorMaxLtBf16IntrOp`.
* XLLVMOps.cpp: remove the "0 or 1 results" assertion from
  `createExternalLLVMIntrinsicCall`. When the op has more than one result,
  the converted result types are packed into a literal LLVM struct type
  that is used as `resTy`.
* test/Target/LLVMIR/aievec.mlir: run with `-split-input-file`, rename the
  section banners from `// ----- X -----` to `// -- X --` (presumably so
  that they cannot be taken for the `// -----` split marker), and add a
  translation test for `vmax.ltbf16`.
* test/dialect/XLLVM/invalid.mlir: new verifier test for a struct result
  whose first element has the wrong vector length.
---
 .../aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td  | 39 +++++++++++++++++
 lib/Dialect/XLLVM/XLLVMOps.cpp                | 12 ++++--
 test/Target/LLVMIR/aievec.mlir                | 42 ++++++++++++-------
 test/dialect/XLLVM/invalid.mlir               | 10 +++++
 4 files changed, 85 insertions(+), 18 deletions(-)
 create mode 100644 test/dialect/XLLVM/invalid.mlir

diff --git a/include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td b/include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td
index 500c2b9f52..7109e8c98b 100644
--- a/include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td
+++ b/include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td
@@ -251,4 +251,43 @@ def VectorExtractElem32I512IntrOp :
                    I32:$idx,
                    I32:$sign)>;
 
+// ----- MAX ELEMENT -----
+class EnumeratedType<Type ty, int idx> {
+  Type type = ty;
+  int index = idx;
+}
+
+class EnumerateTypeListFrom<list<Type> tlist, int from = 0> {
+  list<EnumeratedType> sequence =
+    !if(!empty(tlist), [],
+        !listconcat(
+            [EnumeratedType<!head(tlist), from>],
+            EnumerateTypeListFrom<!tail(tlist), !add(from, 1)>.sequence
+        ));
+}
+
+class LLVM_StructOf<list<Type> structTypes> :
+    Type<
+        And<[LLVM_AnyStruct.predicate,
+             CPred<"cast<::mlir::LLVM::LLVMStructType>($_self).getBody().size() == " # !size(structTypes)>,
+             And<!foreach(enumTy, EnumerateTypeListFrom<structTypes>.sequence,
+                    SubstLeaves<"$_self",
+                                "cast<::mlir::LLVM::LLVMStructType>($_self).getBody()[" # enumTy.index # "]",
+                                enumTy.type.predicate>
+                 )
+             >
+        ]>,
+        "an LLVM struct of {" # !interleave(!foreach(ty, structTypes, ty.summary), "; ") # "}"
+    >;
+
+def VectorMaxLtBf16IntrOp :
+    AIEVec2_IntrOp<"vmax.ltbf16",
+        [TypeIs<"res",
+            LLVM_StructOf<[
+                VectorOfLengthAndType<[32], [BF16]>,
+                I32]>
+        >], /*numResults=*/2>,
+    Arguments<(ins VectorOfLengthAndType<[32], [BF16]>:$lhs,
+                   VectorOfLengthAndType<[32], [BF16]>:$rhs)>;
+
 #endif // AIE_DIALECT_XLLVM_IR_XLLVMAIE2INTROPS_TD
diff --git a/lib/Dialect/XLLVM/XLLVMOps.cpp b/lib/Dialect/XLLVM/XLLVMOps.cpp
index 31356dcb81..5da545bdcb 100644
--- a/lib/Dialect/XLLVM/XLLVMOps.cpp
+++ b/lib/Dialect/XLLVM/XLLVMOps.cpp
@@ -46,12 +46,16 @@ getNamedIntrinsicDeclaration(llvm::Module *M, llvm::StringRef fullName,
 llvm::CallInst *createExternalLLVMIntrinsicCall(
     llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
     Operation *intrOp, llvm::StringRef intrinsicName) {
-  // We support 0 or 1 results
-  assert(intrOp->getNumResults() <= 1 &&
-         "external multi-result intrinsics not supported");
   llvm::Type *resTy = nullptr;
-  if (intrOp->getNumResults())
+  unsigned numResults = intrOp->getNumResults();
+  if (numResults == 1)
     resTy = moduleTranslation.convertType(*(intrOp->getResultTypes().begin()));
+  else if (numResults > 1) {
+    SmallVector<llvm::Type *> resTys;
+    for (auto ty : intrOp->getResultTypes())
+      resTys.push_back(moduleTranslation.convertType(ty));
+    resTy = llvm::StructType::get(builder.getContext(), resTys);
+  }
   auto operands = moduleTranslation.lookupValues(intrOp->getOperands());
   SmallVector<llvm::Type *> types;
   for (auto op : operands)
diff --git a/test/Target/LLVMIR/aievec.mlir b/test/Target/LLVMIR/aievec.mlir
index 875c07d0ac..b3c5c6b77d 100644
--- a/test/Target/LLVMIR/aievec.mlir
+++ b/test/Target/LLVMIR/aievec.mlir
@@ -1,6 +1,6 @@
-// RUN: aie-translate %s -mlir-to-llvmir | FileCheck %s
+// RUN: aie-translate %s -mlir-to-llvmir -split-input-file | FileCheck %s
 
-// ----- MAC -----
+// -- MAC --
 
 // CHECK-LABEL: define <16 x i64> @mac_conf_acc32
 llvm.func @mac_conf_acc32(%A : vector<64xi8>,
@@ -30,7 +30,7 @@ llvm.func @mac_conf_bf16(%A : vector<32xbf16>,
     llvm.return %0 : vector<8xi64>
 }
 
-// ----- MSC -----
+// -- MSC --
 
 // CHECK-LABEL: define <8 x i64> @msc_conf_bf16
 llvm.func @msc_conf_bf16(%A : vector<32xbf16>,
@@ -46,7 +46,7 @@ llvm.func @msc_conf_bf16(%A : vector<32xbf16>,
     llvm.return %0 : vector<8xi64>
 }
 
-// ----- MUL -----
+// -- MUL --
 
 // CHECK-LABEL: define <16 x i64> @mul_conf_acc32
 llvm.func @mul_conf_acc32(%A : vector<64xi8>,
@@ -87,7 +87,7 @@ llvm.func @mul_conf_bf16(%A : vector<32xbf16>,
     llvm.return %0 : vector<8xi64>
 }
 
-// ----- SET -----
+// -- SET --
 
 // CHECK-LABEL: define <16 x i32> @vector_set_128b_into_512b
 llvm.func @vector_set_128b_into_512b(%v : vector<4xi32>) -> vector<16xi32> {
@@ -105,7 +105,7 @@ llvm.func @vector_set_256b_into_512b(%v : vector<8xi32>) -> vector<16xi32> {
     llvm.return %1 : vector<16xi32>
 }
 
-// ----- SRS -----
+// -- SRS --
 
 // CHECK-LABEL: define <32 x i16> @srs_512b_v32_acc32
 llvm.func @srs_512b_v32_acc32(%v : vector<16xi64>, %shft : i32, %sign : i32) -> vector<32xi16> {
@@ -142,7 +142,7 @@ llvm.func @srs_256b_v16_accfloat(%v : vector<8xi64>) -> vector<16xbf16> {
     llvm.return %0 : vector<16xbf16>
 }
 
-// ----- BROADCAST -----
+// -- BROADCAST --
 
 // CHECK-LABEL: define <64 x i8> @vbroadcast8_i512
 llvm.func @vbroadcast8_i512(%val : i32) -> vector<64xi8> {
@@ -184,7 +184,7 @@ llvm.func @vbroadcastfloat_i512(%val : f32) -> vector<16xf32> {
     llvm.return %0 : vector<16xf32>
 }
 
-// ----- EXT -----
+// -- EXT --
 
 // CHECK-LABEL: define <8 x i32> @ext_i256_i512
 llvm.func @ext_i256_i512(%v : vector<16xi32>, %idx : i32) -> vector<8xi32> {
@@ -195,7 +195,7 @@ llvm.func @ext_i256_i512(%v : vector<16xi32>, %idx : i32) -> vector<8xi32> {
     llvm.return %1 : vector<8xi32>
 }
 
-// ----- CONCAT -----
+// -- CONCAT --
 
 // CHECK-LABEL: define <16 x i32> @concat_i512_i256
 llvm.func @concat_i512_i256(%a : vector<8xi32>, %b : vector<8xi32>) -> vector<16xi32> {
@@ -226,7 +226,7 @@ llvm.func @concat_i1024_i512(%a : vector<16xi32>, %b : vector<16xi32>) -> vector
     llvm.return %0 : vector<32xi32>
 }
 
-// ----- SHUFFLE -----
+// -- SHUFFLE --
 
 // CHECK-LABEL: define <16 x i32> @shuffle_i512
 llvm.func @shuffle_i512(%a : vector<16xi32>, %b : vector<16xi32>, %mode : i32) -> vector<16xi32> {
@@ -237,7 +237,7 @@ llvm.func @shuffle_i512(%a : vector<16xi32>, %b : vector<16xi32>, %mode : i32) -
     llvm.return %0 : vector<16xi32>
 }
 
-// ----- UNDEF -----
+// -- UNDEF --
 
 // CHECK-LABEL: define <16 x i32> @undef_v16i32
 llvm.func @undef_v16i32() -> vector<16xi32> {
@@ -246,7 +246,7 @@ llvm.func @undef_v16i32() -> vector<16xi32> {
     llvm.return %0 : vector<16xi32>
 }
 
-// ----- UPD -----
+// -- UPD --
 
 // CHECK-LABEL: define <32 x bfloat> @upd_bf512_bf256
 llvm.func @upd_bf512_bf256(%a : vector<32xbf16>, %b : vector<16xbf16>, %idx : i32) -> vector<32xbf16> {
@@ -256,7 +256,7 @@ llvm.func @upd_bf512_bf256(%a : vector<32xbf16>, %b : vector<16xbf16>, %idx : i3
     llvm.return %0 : vector<32xbf16>
 }
 
-// ----- SHIFT -----
+// -- SHIFT --
 
 // CHECK-LABEL: define <16 x i32> @vshift_i512_i512
 llvm.func @vshift_i512_i512(%a : vector<16xi32>, %b : vector<16xi32>, %step : i32, %shift : i32) -> vector<16xi32> {
@@ -276,7 +276,7 @@ llvm.func @vshift_bf512_bf512(%a : vector<32xbf16>, %b : vector<32xbf16>, %step
     llvm.return %0 : vector<32xbf16>
 }
 
-// ----- EXTRACT ELEMENT -----
+// -- EXTRACT ELEMENT --
 
 // CHECK-LABEL: define i32 @vextract_elem8_i512
 llvm.func @vextract_elem8_i512(%a : vector<64xi8>, %idx : i32, %sign : i32) -> i32 {
@@ -301,3 +301,17 @@ llvm.func @vextract_elem32_i512(%a : vector<16xi32>, %idx : i32, %sign : i32) ->
     %0 = "xllvm.intr.aie2.vextract.elem32.I512"(%a, %idx, %sign) : (vector<16xi32>, i32, i32) -> i32
     llvm.return %0 : i32
 }
+
+// -----
+
+// CHECK-LABEL: define <32 x bfloat> @vmax_ltbf16
+llvm.func @vmax_ltbf16(%lhs: vector<32xbf16>, %rhs: vector<32xbf16>) -> vector<32xbf16> {
+  // CHECK: call { <32 x bfloat>, i32 } @llvm.aie2.vmax.ltbf16(
+  // CHECK-SAME: <32 x bfloat> %{{[0-9]+}}, <32 x bfloat> %{{[0-9]+}})
+  %0 = "xllvm.intr.aie2.vmax.ltbf16"(%lhs, %rhs) :
+       (vector<32xbf16>, vector<32xbf16>) -> !llvm.struct<(vector<32xbf16>, i32)>
+  %1 = llvm.extractvalue %0[0] : !llvm.struct<(vector<32xbf16>, i32)>
+  llvm.return %1 : vector<32xbf16>
+}
+
+// CHECK-LABEL: declare { <32 x bfloat>, i32 } @llvm.aie2.vmax.ltbf16(<32 x bfloat>, <32 x bfloat>)
diff --git a/test/dialect/XLLVM/invalid.mlir b/test/dialect/XLLVM/invalid.mlir
new file mode 100644
index 0000000000..7a146148f1
--- /dev/null
+++ b/test/dialect/XLLVM/invalid.mlir
@@ -0,0 +1,10 @@
+// RUN: aie-opt %s -split-input-file -verify-diagnostics
+
+func.func @invalidStructType(%A : vector<32xbf16>, %B : vector<32xbf16>)
+                            -> vector<16xbf16> {
+  // expected-error @+1 {{'res' is an LLVM struct of {vector of bfloat16 type values of length 32; 32-bit signless integer}}
+  %rs = "xllvm.intr.aie2.vmax.ltbf16"(%A, %B) :
+        (vector<32xbf16>, vector<32xbf16>) -> !llvm.struct<(vector<16xbf16>, i32)>
+  %rv = llvm.extractvalue %rs[0] : !llvm.struct<(vector<16xbf16>, i32)>
+  return %rv : vector<16xbf16>
+}