From 5524a8d9c5740ff039b6bd0b80a982aee0bc6db5 Mon Sep 17 00:00:00 2001
From: James Lin
Date: Thu, 30 May 2024 13:40:10 -0500
Subject: [PATCH] [aievec] to-llvm flow for aievec.min op (#1519)

---
Reviewer notes (placed after the "---" marker, outside the commit message):
- Line breaks and indentation were rebuilt from the "+"/" " hunk markers;
  leading whitespace is inferred and may not match the tree byte-for-byte.
- C++ template arguments that were missing from the text were put back.
- The bf16 branch of MinOpConversion now tests isBF16() instead of a bare
  16-bit width, and the class gained a doc comment. Hunk headers and the
  diffstat were updated for the 8 extra lines, so AIEVecToLLVM.cpp no longer
  corresponds to the blob hashes on its "index" line.

 .../aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td  |  45 +++++++++
 lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp  | 106 +++++++++++++++++++
 test/Conversion/AIEVecToLLVM/test-min.mlir    |  80 +++++++++++++++
 test/Target/LLVMIR/aievec.mlir                |  47 +++++++++
 4 files changed, 278 insertions(+)
 create mode 100644 test/Conversion/AIEVecToLLVM/test-min.mlir

diff --git a/include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td b/include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td
index f37d949900..7d37e5c714 100644
--- a/include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td
+++ b/include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td
@@ -384,4 +384,49 @@ def VectorMaxLtBf16IntrOp :
     Arguments<(ins VectorOfLengthAndType<[32], [BF16]>:$lhs,
                    VectorOfLengthAndType<[32], [BF16]>:$rhs)>;
 
+// ----- MIN ELEMENT -----
+
+def VectorMinGe8IntrOp :
+    AIEVec2_IntrOp<"vmin.ge8",
+        [TypeIs<"res",
+            LLVM_StructOf<[
+                VectorOfLengthAndType<[64], [I8]>,
+                VectorOfLengthAndType<[2], [I32]>]>
+        >], /*numResults=*/2>,
+    Arguments<(ins VectorOfLengthAndType<[64], [I8]>:$lhs,
+                   VectorOfLengthAndType<[64], [I8]>:$rhs,
+                   I32:$cmp)> ;
+
+def VectorMinGe16IntrOp :
+    AIEVec2_IntrOp<"vmin.ge16",
+        [TypeIs<"res",
+            LLVM_StructOf<[
+                VectorOfLengthAndType<[32], [I16]>,
+                I32]>
+        >], /*numResults=*/2>,
+    Arguments<(ins VectorOfLengthAndType<[32], [I16]>:$lhs,
+                   VectorOfLengthAndType<[32], [I16]>:$rhs,
+                   I32:$cmp)> ;
+
+def VectorMinGe32IntrOp :
+    AIEVec2_IntrOp<"vmin.ge32",
+        [TypeIs<"res",
+            LLVM_StructOf<[
+                VectorOfLengthAndType<[16], [I32]>,
+                I32]>
+        >], /*numResults=*/2>,
+    Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$lhs,
+                   VectorOfLengthAndType<[16], [I32]>:$rhs,
+                   I32:$cmp)> ;
+
+def VectorMinGeBf16IntrOp :
+    AIEVec2_IntrOp<"vmin.gebf16",
+        [TypeIs<"res",
+            LLVM_StructOf<[
+                VectorOfLengthAndType<[32], [BF16]>,
+                I32]>
+        >], /*numResults=*/2>,
+    Arguments<(ins VectorOfLengthAndType<[32], [BF16]>:$lhs,
+                   VectorOfLengthAndType<[32], [BF16]>:$rhs)> ;
+
 #endif // AIE_DIALECT_XLLVM_IR_XLLVMAIE2INTROPS_TD
diff --git a/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp b/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp
index a8a774deec..c3069c2732 100644
--- a/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp
+++ b/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp
@@ -1656,6 +1656,111 @@ class MaxOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::MaxOp> {
   }
 };
 
+/// Lowers a 512-bit `aievec.min` to the matching AIE2 `vmin.ge*` intrinsic.
+///
+/// Each intrinsic yields a struct; the rewrite keeps only element 0 (read
+/// with `llvm.extractvalue`) as the replacement value. Integer variants also
+/// take an i32 `cmp` operand, which is supplied here as the constant 0.
+/// Any other vector width or element type emits a warning and fails to match.
+class MinOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::MinOp> {
+public:
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(aievec::MinOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    VectorType resultType = cast<VectorType>(op.getResult().getType());
+    Type resultScaTy = resultType.getElementType();
+    unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
+    int resultLanes = getVectorLaneSize(resultType);
+    int resultVectorSize = resultBitWidth * resultLanes;
+
+    // aievec.min op has the AllTypesMatch constraint on lhs/rhs/res
+    if (resultVectorSize != 512) {
+      op.emitWarning() << "aievec.min conversion with " << resultVectorSize
+                       << "-bit result is not supported.\n";
+      return failure();
+    }
+
+    // create xllvm intrinsic
+    Value minOp = nullptr;
+    if (llvm::isa<IntegerType>(resultScaTy)) {
+      // create constant for cmp
+      auto cmpCst = rewriter.create<LLVM::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
+      SmallVector<Value> operands{adaptor.getLhs(), adaptor.getRhs(), cmpCst};
+      if (resultBitWidth == 8) {
+        minOp = rewriter.create<xllvm::VectorMinGe8IntrOp>(
+            loc,
+            mlir::LLVM::LLVMStructType::getLiteral(
+                rewriter.getContext(),
+                {VectorType::get({64}, rewriter.getI8Type()),
+                 VectorType::get({2}, rewriter.getI32Type())}),
+            forceCastOperandsToSignature(
+                rewriter, loc, operands,
+                {VectorType::get({64}, rewriter.getI8Type()),
+                 VectorType::get({64}, rewriter.getI8Type()),
+                 rewriter.getI32Type()}));
+      } else if (resultBitWidth == 16) {
+        minOp = rewriter.create<xllvm::VectorMinGe16IntrOp>(
+            loc,
+            mlir::LLVM::LLVMStructType::getLiteral(
+                rewriter.getContext(),
+                {VectorType::get({32}, rewriter.getI16Type()),
+                 rewriter.getI32Type()}),
+            forceCastOperandsToSignature(
+                rewriter, loc, operands,
+                {VectorType::get({32}, rewriter.getI16Type()),
+                 VectorType::get({32}, rewriter.getI16Type()),
+                 rewriter.getI32Type()}));
+      } else if (resultBitWidth == 32) {
+        minOp = rewriter.create<xllvm::VectorMinGe32IntrOp>(
+            loc,
+            mlir::LLVM::LLVMStructType::getLiteral(
+                rewriter.getContext(),
+                {VectorType::get({16}, rewriter.getI32Type()),
+                 rewriter.getI32Type()}),
+            forceCastOperandsToSignature(
+                rewriter, loc, operands,
+                {VectorType::get({16}, rewriter.getI32Type()),
+                 VectorType::get({16}, rewriter.getI32Type()),
+                 rewriter.getI32Type()}));
+      }
+    } else {
+      // Match bf16 explicitly: a bare 16-bit width check would also accept
+      // f16 and lower it through the bf16 intrinsic.
+      if (resultScaTy.isBF16()) {
+        minOp = rewriter.create<xllvm::VectorMinGeBf16IntrOp>(
+            loc,
+            mlir::LLVM::LLVMStructType::getLiteral(
+                rewriter.getContext(),
+                {VectorType::get({32}, rewriter.getBF16Type()),
+                 rewriter.getI32Type()}),
+            forceCastOperandsToSignature(
+                rewriter, loc, {adaptor.getLhs(), adaptor.getRhs()},
+                {VectorType::get({32}, rewriter.getBF16Type()),
+                 VectorType::get({32}, rewriter.getBF16Type())}));
+      }
+    }
+
+    if (!minOp) {
+      // We have checked the lhs/rhs/res to be 512-bit vectors. Hence, a
+      // possible failure here is due to unsupported element datatype.
+      op.emitWarning() << "aievec.min conversion fails due to unsupported "
+                          "element data type.\n";
+      return failure();
+    }
+
+    // create llvm.extractvalue for the first element in the LLVMStruct
+    rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(op, minOp,
+                                                      /*position=*/0);
+
+    return success();
+  }
+};
+
 class BroadcastScalarOpConversion
     : public mlir::ConvertOpToLLVMPattern<aievec::BroadcastScalarOp> {
 public:
@@ -2098,6 +2203,7 @@ void populateAIEVecToLLVMConversionPatterns(
                FMAElemOpConversion,
                MatMulOpConversion,
                MaxOpConversion,
+               MinOpConversion,
                ShiftOpConversion,
                ExtractElemOpConversion,
                FoldAIECastOps>(converter);
diff --git a/test/Conversion/AIEVecToLLVM/test-min.mlir b/test/Conversion/AIEVecToLLVM/test-min.mlir
new file mode 100644
index 0000000000..4930d639c9
--- /dev/null
+++ b/test/Conversion/AIEVecToLLVM/test-min.mlir
@@ -0,0 +1,80 @@
+// RUN: aie-opt %s -split-input-file -convert-aievec-to-llvm -verify-diagnostics | FileCheck %s
+
+func.func @i8_min(%arg0 : vector<64xi8>) -> vector<64xi8> {
+  %0 = aievec.min %arg0, %arg0 : vector<64xi8>
+  return %0 : vector<64xi8>
+}
+
+// CHECK-LABEL: @i8_min
+// CHECK-SAME: %[[ARG0:.*]]: vector<64xi8>
+// CHECK: %[[CST:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK-NEXT: %[[VMIN:.*]] = "xllvm.intr.aie2.vmin.ge8"(
+// CHECK-SAME: %[[ARG0]], %[[ARG0]], %[[CST]]) :
+// CHECK-SAME: (vector<64xi8>, vector<64xi8>, i32) -> !llvm.struct<(vector<64xi8>, vector<2xi32>)>
+// CHECK-NEXT: %[[RES:.*]] = llvm.extractvalue %[[VMIN]][0] : !llvm.struct<(vector<64xi8>, vector<2xi32>)>
+// CHECK-NEXT: return %[[RES]] : vector<64xi8>
+
+// -----
+
+func.func @i16_min(%arg0 : vector<32xi16>) -> vector<32xi16> {
+  %0 = aievec.min %arg0, %arg0 : vector<32xi16>
+  return %0 : vector<32xi16>
+}
+
+// CHECK-LABEL: @i16_min
+// CHECK-SAME: %[[ARG0:.*]]: vector<32xi16>
+// CHECK: %[[CST:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK-NEXT: %[[VMIN:.*]] = "xllvm.intr.aie2.vmin.ge16"(
+// CHECK-SAME: %[[ARG0]], %[[ARG0]], %[[CST]]) :
+// CHECK-SAME: (vector<32xi16>, vector<32xi16>, i32) -> !llvm.struct<(vector<32xi16>, i32)>
+// CHECK-NEXT: %[[RES:.*]] = llvm.extractvalue %[[VMIN]][0] : !llvm.struct<(vector<32xi16>, i32)>
+// CHECK-NEXT: return %[[RES]] : vector<32xi16>
+
+// -----
+
+func.func @i32_min(%arg0 : vector<16xi32>) -> vector<16xi32> {
+  %0 = aievec.min %arg0, %arg0 : vector<16xi32>
+  return %0 : vector<16xi32>
+}
+
+// CHECK-LABEL: @i32_min
+// CHECK-SAME: %[[ARG0:.*]]: vector<16xi32>
+// CHECK: %[[CST:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK-NEXT: %[[VMIN:.*]] = "xllvm.intr.aie2.vmin.ge32"(
+// CHECK-SAME: %[[ARG0]], %[[ARG0]], %[[CST]]) :
+// CHECK-SAME: (vector<16xi32>, vector<16xi32>, i32) -> !llvm.struct<(vector<16xi32>, i32)>
+// CHECK-NEXT: %[[RES:.*]] = llvm.extractvalue %[[VMIN]][0] : !llvm.struct<(vector<16xi32>, i32)>
+// CHECK-NEXT: return %[[RES]] : vector<16xi32>
+
+// -----
+
+func.func @bf16_min(%arg0 : vector<32xbf16>) -> vector<32xbf16> {
+  %0 = aievec.min %arg0, %arg0 : vector<32xbf16>
+  return %0 : vector<32xbf16>
+}
+
+// CHECK-LABEL: @bf16_min
+// CHECK-SAME: %[[ARG0:.*]]: vector<32xbf16>
+// CHECK-NEXT: %[[VMIN:.*]] = "xllvm.intr.aie2.vmin.gebf16"(
+// CHECK-SAME: %[[ARG0]], %[[ARG0]]) :
+// CHECK-SAME: (vector<32xbf16>, vector<32xbf16>) -> !llvm.struct<(vector<32xbf16>, i32)>
+// CHECK-NEXT: %[[RES:.*]] = llvm.extractvalue %[[VMIN]][0] : !llvm.struct<(vector<32xbf16>, i32)>
+// CHECK-NEXT: return %[[RES]] : vector<32xbf16>
+
+// -----
+
+func.func @invalid_i4_min(%arg0 : vector<128xi4>) -> vector<128xi4> {
+  // expected-warning @+2 {{aievec.min conversion fails due to unsupported element data type.}}
+  // expected-error @+1 {{failed to legalize operation 'aievec.min' that was explicitly marked illegal}}
+  %0 = aievec.min %arg0, %arg0 : vector<128xi4>
+  return %0 : vector<128xi4>
+}
+
+// -----
+
+func.func @invalid_i8_min(%arg0 : vector<128xi8>) -> vector<128xi8> {
+  // expected-warning @+2 {{aievec.min conversion with 1024-bit result is not supported.}}
+  // expected-error @+1 {{failed to legalize operation 'aievec.min' that was explicitly marked illegal}}
+  %0 = aievec.min %arg0, %arg0 : vector<128xi8>
+  return %0 : vector<128xi8>
+}
diff --git a/test/Target/LLVMIR/aievec.mlir b/test/Target/LLVMIR/aievec.mlir
index 6b56d37dd5..304c8b7b69 100644
--- a/test/Target/LLVMIR/aievec.mlir
+++ b/test/Target/LLVMIR/aievec.mlir
@@ -462,7 +462,54 @@ llvm.func @vmax_ltbf16(%lhs: vector<32xbf16>, %rhs: vector<32xbf16>) -> vector<3
   llvm.return %1 : vector<32xbf16>
 }
 
+// -- MIN ELEMENT --
+
+// CHECK-LABEL: <64 x i8> @vmin_ge8
+llvm.func @vmin_ge8(%lhs: vector<64xi8>, %rhs: vector<64xi8>, %pred: i32) -> vector<64xi8> {
+  // CHECK: call { <64 x i8>, <2 x i32> } @llvm.aie2.vmin.ge8(
+  // CHECK-SAME: <64 x i8> %{{[0-9]+}}, <64 x i8> %{{[0-9]+}}, i32 %{{[0-9]+}})
+  %0 = "xllvm.intr.aie2.vmin.ge8"(%lhs, %rhs, %pred) :
+       (vector<64xi8>, vector<64xi8>, i32) -> !llvm.struct<(vector<64xi8>, vector<2xi32>)>
+  %1 = llvm.extractvalue %0[0] : !llvm.struct<(vector<64xi8>, vector<2xi32>)>
+  llvm.return %1 : vector<64xi8>
+}
+
+// CHECK-LABEL: <32 x i16> @vmin_ge16
+llvm.func @vmin_ge16(%lhs: vector<32xi16>, %rhs: vector<32xi16>, %pred: i32) -> vector<32xi16> {
+  // CHECK: call { <32 x i16>, i32 } @llvm.aie2.vmin.ge16(
+  // CHECK-SAME: <32 x i16> %{{[0-9]+}}, <32 x i16> %{{[0-9]+}}, i32 %{{[0-9]+}})
+  %0 = "xllvm.intr.aie2.vmin.ge16"(%lhs, %rhs, %pred) :
+       (vector<32xi16>, vector<32xi16>, i32) -> !llvm.struct<(vector<32xi16>, i32)>
+  %1 = llvm.extractvalue %0[0] : !llvm.struct<(vector<32xi16>, i32)>
+  llvm.return %1 : vector<32xi16>
+}
+
+// CHECK-LABEL: <16 x i32> @vmin_ge32
+llvm.func @vmin_ge32(%lhs: vector<16xi32>, %rhs: vector<16xi32>, %pred: i32) -> vector<16xi32> {
+  // CHECK: call { <16 x i32>, i32 } @llvm.aie2.vmin.ge32(
+  // CHECK-SAME: <16 x i32> %{{[0-9]+}}, <16 x i32> %{{[0-9]+}}, i32 %{{[0-9]+}})
+  %0 = "xllvm.intr.aie2.vmin.ge32"(%lhs, %rhs, %pred) :
+       (vector<16xi32>, vector<16xi32>, i32) -> !llvm.struct<(vector<16xi32>, i32)>
+  %1 = llvm.extractvalue %0[0] : !llvm.struct<(vector<16xi32>, i32)>
+  llvm.return %1 : vector<16xi32>
+}
+
+// CHECK-LABEL: <32 x bfloat> @vmin_gebf16
+llvm.func @vmin_gebf16(%lhs: vector<32xbf16>, %rhs: vector<32xbf16>) -> vector<32xbf16> {
+  // CHECK: call { <32 x bfloat>, i32 } @llvm.aie2.vmin.gebf16(
+  // CHECK-SAME: <32 x bfloat> %{{[0-9]+}}, <32 x bfloat> %{{[0-9]+}})
+  %0 = "xllvm.intr.aie2.vmin.gebf16"(%lhs, %rhs) :
+       (vector<32xbf16>, vector<32xbf16>) -> !llvm.struct<(vector<32xbf16>, i32)>
+  %1 = llvm.extractvalue %0[0] : !llvm.struct<(vector<32xbf16>, i32)>
+  llvm.return %1 : vector<32xbf16>
+}
+
 // CHECK-LABEL: declare { <64 x i8>, <2 x i32> } @llvm.aie2.vmax.lt8(<64 x i8>, <64 x i8>, i32)
 // CHECK-LABEL: declare { <32 x i16>, i32 } @llvm.aie2.vmax.lt16(<32 x i16>, <32 x i16>, i32)
 // CHECK-LABEL: declare { <16 x i32>, i32 } @llvm.aie2.vmax.lt32(<16 x i32>, <16 x i32>, i32)
 // CHECK-LABEL: declare { <32 x bfloat>, i32 } @llvm.aie2.vmax.ltbf16(<32 x bfloat>, <32 x bfloat>)
+
+// CHECK-LABEL: declare { <64 x i8>, <2 x i32> } @llvm.aie2.vmin.ge8(<64 x i8>, <64 x i8>, i32)
+// CHECK-LABEL: declare { <32 x i16>, i32 } @llvm.aie2.vmin.ge16(<32 x i16>, <32 x i16>, i32)
+// CHECK-LABEL: declare { <16 x i32>, i32 } @llvm.aie2.vmin.ge32(<16 x i32>, <16 x i32>, i32)
+// CHECK-LABEL: declare { <32 x bfloat>, i32 } @llvm.aie2.vmin.gebf16(<32 x bfloat>, <32 x bfloat>)