Skip to content

Commit

Permalink
[aievec] to-llvm flow for aievec.min op (#1519)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamestcl-amd authored May 30, 2024
1 parent 1b933de commit 5524a8d
Show file tree
Hide file tree
Showing 4 changed files with 270 additions and 0 deletions.
45 changes: 45 additions & 0 deletions include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -384,4 +384,49 @@ def VectorMaxLtBf16IntrOp :
Arguments<(ins VectorOfLengthAndType<[32], [BF16]>:$lhs,
VectorOfLengthAndType<[32], [BF16]>:$rhs)>;

// ----- MIN ELEMENT -----

// Intrinsic op "vmin.ge8": takes two vector<64xi8> operands plus an i32 `cmp`
// operand and produces an LLVM struct of (vector<64xi8>, vector<2xi32>).
// NOTE(review): the second struct member presumably carries a per-lane
// comparison mask and `cmp` presumably selects the comparison mode -- confirm
// against the AIE2 intrinsic definition.
def VectorMinGe8IntrOp :
AIEVec2_IntrOp<"vmin.ge8",
[TypeIs<"res",
LLVM_StructOf<[
VectorOfLengthAndType<[64], [I8]>,
VectorOfLengthAndType<[2], [I32]>]>
>], /*numResults=*/2>,
Arguments<(ins VectorOfLengthAndType<[64], [I8]>:$lhs,
VectorOfLengthAndType<[64], [I8]>:$rhs,
I32:$cmp)> ;

// Intrinsic op "vmin.ge16": takes two vector<32xi16> operands plus an i32
// `cmp` operand and produces an LLVM struct of (vector<32xi16>, i32).
// NOTE(review): the i32 struct member presumably carries a per-lane comparison
// mask and `cmp` presumably selects the comparison mode -- confirm against the
// AIE2 intrinsic definition.
def VectorMinGe16IntrOp :
AIEVec2_IntrOp<"vmin.ge16",
[TypeIs<"res",
LLVM_StructOf<[
VectorOfLengthAndType<[32], [I16]>,
I32]>
>], /*numResults=*/2>,
Arguments<(ins VectorOfLengthAndType<[32], [I16]>:$lhs,
VectorOfLengthAndType<[32], [I16]>:$rhs,
I32:$cmp)> ;

// Intrinsic op "vmin.ge32": takes two vector<16xi32> operands plus an i32
// `cmp` operand and produces an LLVM struct of (vector<16xi32>, i32).
// NOTE(review): the i32 struct member presumably carries a per-lane comparison
// mask and `cmp` presumably selects the comparison mode -- confirm against the
// AIE2 intrinsic definition.
def VectorMinGe32IntrOp :
AIEVec2_IntrOp<"vmin.ge32",
[TypeIs<"res",
LLVM_StructOf<[
VectorOfLengthAndType<[16], [I32]>,
I32]>
>], /*numResults=*/2>,
Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$lhs,
VectorOfLengthAndType<[16], [I32]>:$rhs,
I32:$cmp)> ;

// Intrinsic op "vmin.gebf16": takes two vector<32xbf16> operands and produces
// an LLVM struct of (vector<32xbf16>, i32). Unlike the integer vmin.ge*
// variants in this section, it has no `cmp` operand.
// NOTE(review): the i32 struct member presumably carries a per-lane comparison
// mask -- confirm against the AIE2 intrinsic definition.
def VectorMinGeBf16IntrOp :
AIEVec2_IntrOp<"vmin.gebf16",
[TypeIs<"res",
LLVM_StructOf<[
VectorOfLengthAndType<[32], [BF16]>,
I32]>
>], /*numResults=*/2>,
Arguments<(ins VectorOfLengthAndType<[32], [BF16]>:$lhs,
VectorOfLengthAndType<[32], [BF16]>:$rhs)> ;

#endif // AIE_DIALECT_XLLVM_IR_XLLVMAIE2INTROPS_TD
98 changes: 98 additions & 0 deletions lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1656,6 +1656,103 @@ class MaxOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::MaxOp> {
}
};

/// Lowers `aievec.min` to the matching XLLVM AIE2 `vmin.ge*` intrinsic op.
///
/// Only 512-bit vectors are handled:
///   - integer elements of width 8, 16 or 32 map to `vmin.ge8`, `vmin.ge16`
///     or `vmin.ge32`, each receiving an extra i32 `cmp` operand fixed to the
///     constant 0;
///   - bf16 elements map to `vmin.gebf16`, which takes no `cmp` operand.
///
/// Every intrinsic returns an LLVM struct; the original op is replaced by an
/// `llvm.extractvalue` of struct element 0. Any other vector size or element
/// type emits a warning on the op and fails the match.
class MinOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::MinOp> {
public:
  using ConvertOpToLLVMPattern<aievec::MinOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(aievec::MinOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    VectorType resultType = cast<VectorType>(op.getResult().getType());
    Type resultScaTy = resultType.getElementType();
    unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
    int resultLanes = getVectorLaneSize(resultType);
    int resultVectorSize = resultBitWidth * resultLanes;

    // aievec.min op has the AllTypesMatch constraint on lhs/rhs/res, so
    // checking the result type is enough to validate the operands as well.
    if (resultVectorSize != 512) {
      op.emitWarning() << "aievec.min conversion with " << resultVectorSize
                       << "-bit result is not supported.\n";
      return failure();
    }

    // Create the xllvm intrinsic; `minOp` stays null when no intrinsic matches
    // the element type.
    Value minOp = nullptr;
    if (llvm::isa<IntegerType>(resultScaTy)) {
      // The integer intrinsics take an additional i32 `cmp` operand, which is
      // always the constant 0 here.
      auto cmpCst = rewriter.create<LLVM::ConstantOp>(
          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
      SmallVector<Value> operands{adaptor.getLhs(), adaptor.getRhs(), cmpCst};
      if (resultBitWidth == 8) {
        minOp = rewriter.create<xllvm::VectorMinGe8IntrOp>(
            loc,
            mlir::LLVM::LLVMStructType::getLiteral(
                rewriter.getContext(),
                {VectorType::get({64}, rewriter.getI8Type()),
                 VectorType::get({2}, rewriter.getI32Type())}),
            forceCastOperandsToSignature(
                rewriter, loc, operands,
                {VectorType::get({64}, rewriter.getI8Type()),
                 VectorType::get({64}, rewriter.getI8Type()),
                 rewriter.getI32Type()}));
      } else if (resultBitWidth == 16) {
        minOp = rewriter.create<xllvm::VectorMinGe16IntrOp>(
            loc,
            mlir::LLVM::LLVMStructType::getLiteral(
                rewriter.getContext(),
                {VectorType::get({32}, rewriter.getI16Type()),
                 rewriter.getI32Type()}),
            forceCastOperandsToSignature(
                rewriter, loc, operands,
                {VectorType::get({32}, rewriter.getI16Type()),
                 VectorType::get({32}, rewriter.getI16Type()),
                 rewriter.getI32Type()}));
      } else if (resultBitWidth == 32) {
        minOp = rewriter.create<xllvm::VectorMinGe32IntrOp>(
            loc,
            mlir::LLVM::LLVMStructType::getLiteral(
                rewriter.getContext(),
                {VectorType::get({16}, rewriter.getI32Type()),
                 rewriter.getI32Type()}),
            forceCastOperandsToSignature(
                rewriter, loc, operands,
                {VectorType::get({16}, rewriter.getI32Type()),
                 VectorType::get({16}, rewriter.getI32Type()),
                 rewriter.getI32Type()}));
      }
    } else if (resultScaTy.isBF16()) {
      // Only bf16 has a floating-point intrinsic here. Testing the bit width
      // alone would also accept other 16-bit float types (e.g. f16) and
      // force-cast their bits to bf16, so the element type is checked
      // explicitly.
      minOp = rewriter.create<xllvm::VectorMinGeBf16IntrOp>(
          loc,
          mlir::LLVM::LLVMStructType::getLiteral(
              rewriter.getContext(),
              {VectorType::get({32}, rewriter.getBF16Type()),
               rewriter.getI32Type()}),
          forceCastOperandsToSignature(
              rewriter, loc, {adaptor.getLhs(), adaptor.getRhs()},
              {VectorType::get({32}, rewriter.getBF16Type()),
               VectorType::get({32}, rewriter.getBF16Type())}));
    }

    if (!minOp) {
      // We have checked the lhs/rhs/res to be 512-bit vectors. Hence, a
      // possible failure here is due to unsupported element datatype.
      op.emitWarning() << "aievec.min conversion fails due to unsupported "
                          "element data type.\n";
      return failure();
    }

    // Replace the op with llvm.extractvalue of the first element in the
    // LLVMStruct returned by the intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(op, minOp,
                                                      /*position=*/0);

    return success();
  }
};

class BroadcastScalarOpConversion
: public mlir::ConvertOpToLLVMPattern<aievec::BroadcastScalarOp> {
public:
Expand Down Expand Up @@ -2098,6 +2195,7 @@ void populateAIEVecToLLVMConversionPatterns(
FMAElemOpConversion,
MatMulOpConversion,
MaxOpConversion,
MinOpConversion,
ShiftOpConversion,
ExtractElemOpConversion,
FoldAIECastOps>(converter);
Expand Down
80 changes: 80 additions & 0 deletions test/Conversion/AIEVecToLLVM/test-min.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// RUN: aie-opt %s -split-input-file -convert-aievec-to-llvm -verify-diagnostics | FileCheck %s

// Lowering case: aievec.min on a 512-bit vector<64xi8>, with the same
// argument used for both operands.
func.func @i8_min(%arg0 : vector<64xi8>) -> vector<64xi8> {
%0 = aievec.min %arg0, %arg0 : vector<64xi8>
return %0 : vector<64xi8>
}

// CHECK-LABEL: @i8_min
// CHECK-SAME: %[[ARG0:.*]]: vector<64xi8>
// CHECK: %[[CST:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: %[[VMIN:.*]] = "xllvm.intr.aie2.vmin.ge8"(
// CHECK-SAME: %[[ARG0]], %[[ARG0]], %[[CST]]) :
// CHECK-SAME: (vector<64xi8>, vector<64xi8>, i32) -> !llvm.struct<(vector<64xi8>, vector<2xi32>)>
// CHECK-NEXT: %[[RES:.*]] = llvm.extractvalue %[[VMIN]][0] : !llvm.struct<(vector<64xi8>, vector<2xi32>)>
// CHECK-NEXT: return %[[RES]] : vector<64xi8>

// -----

// Lowering case: aievec.min on a 512-bit vector<32xi16>, with the same
// argument used for both operands.
func.func @i16_min(%arg0 : vector<32xi16>) -> vector<32xi16> {
%0 = aievec.min %arg0, %arg0 : vector<32xi16>
return %0 : vector<32xi16>
}

// CHECK-LABEL: @i16_min
// CHECK-SAME: %[[ARG0:.*]]: vector<32xi16>
// CHECK: %[[CST:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: %[[VMIN:.*]] = "xllvm.intr.aie2.vmin.ge16"(
// CHECK-SAME: %[[ARG0]], %[[ARG0]], %[[CST]]) :
// CHECK-SAME: (vector<32xi16>, vector<32xi16>, i32) -> !llvm.struct<(vector<32xi16>, i32)>
// CHECK-NEXT: %[[RES:.*]] = llvm.extractvalue %[[VMIN]][0] : !llvm.struct<(vector<32xi16>, i32)>
// CHECK-NEXT: return %[[RES]] : vector<32xi16>

// -----

// Lowering case: aievec.min on a 512-bit vector<16xi32>, with the same
// argument used for both operands.
func.func @i32_min(%arg0 : vector<16xi32>) -> vector<16xi32> {
%0 = aievec.min %arg0, %arg0 : vector<16xi32>
return %0 : vector<16xi32>
}

// CHECK-LABEL: @i32_min
// CHECK-SAME: %[[ARG0:.*]]: vector<16xi32>
// CHECK: %[[CST:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: %[[VMIN:.*]] = "xllvm.intr.aie2.vmin.ge32"(
// CHECK-SAME: %[[ARG0]], %[[ARG0]], %[[CST]]) :
// CHECK-SAME: (vector<16xi32>, vector<16xi32>, i32) -> !llvm.struct<(vector<16xi32>, i32)>
// CHECK-NEXT: %[[RES:.*]] = llvm.extractvalue %[[VMIN]][0] : !llvm.struct<(vector<16xi32>, i32)>
// CHECK-NEXT: return %[[RES]] : vector<16xi32>

// -----

// Lowering case: aievec.min on a 512-bit vector<32xbf16>, with the same
// argument used for both operands. No i32 constant is expected before the
// intrinsic call in this case.
func.func @bf16_min(%arg0 : vector<32xbf16>) -> vector<32xbf16> {
%0 = aievec.min %arg0, %arg0 : vector<32xbf16>
return %0 : vector<32xbf16>
}

// CHECK-LABEL: @bf16_min
// CHECK-SAME: %[[ARG0:.*]]: vector<32xbf16>
// CHECK-NEXT: %[[VMIN:.*]] = "xllvm.intr.aie2.vmin.gebf16"(
// CHECK-SAME: %[[ARG0]], %[[ARG0]]) :
// CHECK-SAME: (vector<32xbf16>, vector<32xbf16>) -> !llvm.struct<(vector<32xbf16>, i32)>
// CHECK-NEXT: %[[RES:.*]] = llvm.extractvalue %[[VMIN]][0] : !llvm.struct<(vector<32xbf16>, i32)>
// CHECK-NEXT: return %[[RES]] : vector<32xbf16>

// -----

// Negative case: vector<128xi4> is 512 bits wide but has an element type the
// conversion does not support, so legalization of aievec.min must fail.
func.func @invalid_i4_min(%arg0 : vector<128xi4>) -> vector<128xi4> {
// expected-warning @+2 {{aievec.min conversion fails due to unsupported element data type.}}
// expected-error @+1 {{failed to legalize operation 'aievec.min' that was explicitly marked illegal}}
%0 = aievec.min %arg0, %arg0 : vector<128xi4>
return %0 : vector<128xi4>
}

// -----

// Negative case: vector<128xi8> is 1024 bits wide, which the conversion
// rejects, so legalization of aievec.min must fail.
func.func @invalid_i8_min(%arg0 : vector<128xi8>) -> vector<128xi8> {
// expected-warning @+2 {{aievec.min conversion with 1024-bit result is not supported.}}
// expected-error @+1 {{failed to legalize operation 'aievec.min' that was explicitly marked illegal}}
%0 = aievec.min %arg0, %arg0 : vector<128xi8>
return %0 : vector<128xi8>
}
47 changes: 47 additions & 0 deletions test/Target/LLVMIR/aievec.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,54 @@ llvm.func @vmax_ltbf16(%lhs: vector<32xbf16>, %rhs: vector<32xbf16>) -> vector<3
llvm.return %1 : vector<32xbf16>
}

// -- MIN ELEMENT --

// CHECK-LABEL: <64 x i8> @vmin_ge8
// Translation case: the xllvm vmin.ge8 op on vector<64xi8> operands must
// become a call to @llvm.aie2.vmin.ge8; element 0 of the returned struct is
// the function result.
llvm.func @vmin_ge8(%lhs: vector<64xi8>, %rhs: vector<64xi8>, %pred: i32) -> vector<64xi8> {
// CHECK: call { <64 x i8>, <2 x i32> } @llvm.aie2.vmin.ge8(
// CHECK-SAME: <64 x i8> %{{[0-9]+}}, <64 x i8> %{{[0-9]+}}, i32 %{{[0-9]+}})
%0 = "xllvm.intr.aie2.vmin.ge8"(%lhs, %rhs, %pred) :
(vector<64xi8>, vector<64xi8>, i32) -> !llvm.struct<(vector<64xi8>, vector<2xi32>)>
%1 = llvm.extractvalue %0[0] : !llvm.struct<(vector<64xi8>, vector<2xi32>)>
llvm.return %1 : vector<64xi8>
}

// CHECK-LABEL: <32 x i16> @vmin_ge16
// Translation case: the xllvm vmin.ge16 op on vector<32xi16> operands must
// become a call to @llvm.aie2.vmin.ge16; element 0 of the returned struct is
// the function result.
llvm.func @vmin_ge16(%lhs: vector<32xi16>, %rhs: vector<32xi16>, %pred: i32) -> vector<32xi16> {
// CHECK: call { <32 x i16>, i32 } @llvm.aie2.vmin.ge16(
// CHECK-SAME: <32 x i16> %{{[0-9]+}}, <32 x i16> %{{[0-9]+}}, i32 %{{[0-9]+}})
%0 = "xllvm.intr.aie2.vmin.ge16"(%lhs, %rhs, %pred) :
(vector<32xi16>, vector<32xi16>, i32) -> !llvm.struct<(vector<32xi16>, i32)>
%1 = llvm.extractvalue %0[0] : !llvm.struct<(vector<32xi16>, i32)>
llvm.return %1 : vector<32xi16>
}

// CHECK-LABEL: <16 x i32> @vmin_ge32
// Translation case: the xllvm vmin.ge32 op on vector<16xi32> operands must
// become a call to @llvm.aie2.vmin.ge32; element 0 of the returned struct is
// the function result.
llvm.func @vmin_ge32(%lhs: vector<16xi32>, %rhs: vector<16xi32>, %pred: i32) -> vector<16xi32> {
// CHECK: call { <16 x i32>, i32 } @llvm.aie2.vmin.ge32(
// CHECK-SAME: <16 x i32> %{{[0-9]+}}, <16 x i32> %{{[0-9]+}}, i32 %{{[0-9]+}})
%0 = "xllvm.intr.aie2.vmin.ge32"(%lhs, %rhs, %pred) :
(vector<16xi32>, vector<16xi32>, i32) -> !llvm.struct<(vector<16xi32>, i32)>
%1 = llvm.extractvalue %0[0] : !llvm.struct<(vector<16xi32>, i32)>
llvm.return %1 : vector<16xi32>
}

// CHECK-LABEL: <32 x bfloat> @vmin_gebf16
// Translation case: the xllvm vmin.gebf16 op on vector<32xbf16> operands must
// become a call to @llvm.aie2.vmin.gebf16 (two vector arguments only);
// element 0 of the returned struct is the function result.
llvm.func @vmin_gebf16(%lhs: vector<32xbf16>, %rhs: vector<32xbf16>) -> vector<32xbf16> {
// CHECK: call { <32 x bfloat>, i32 } @llvm.aie2.vmin.gebf16(
// CHECK-SAME: <32 x bfloat> %{{[0-9]+}}, <32 x bfloat> %{{[0-9]+}})
%0 = "xllvm.intr.aie2.vmin.gebf16"(%lhs, %rhs) :
(vector<32xbf16>, vector<32xbf16>) -> !llvm.struct<(vector<32xbf16>, i32)>
%1 = llvm.extractvalue %0[0] : !llvm.struct<(vector<32xbf16>, i32)>
llvm.return %1 : vector<32xbf16>
}

// CHECK-LABEL: declare { <64 x i8>, <2 x i32> } @llvm.aie2.vmax.lt8(<64 x i8>, <64 x i8>, i32)
// CHECK-LABEL: declare { <32 x i16>, i32 } @llvm.aie2.vmax.lt16(<32 x i16>, <32 x i16>, i32)
// CHECK-LABEL: declare { <16 x i32>, i32 } @llvm.aie2.vmax.lt32(<16 x i32>, <16 x i32>, i32)
// CHECK-LABEL: declare { <32 x bfloat>, i32 } @llvm.aie2.vmax.ltbf16(<32 x bfloat>, <32 x bfloat>)

// CHECK-LABEL: declare { <64 x i8>, <2 x i32> } @llvm.aie2.vmin.ge8(<64 x i8>, <64 x i8>, i32)
// CHECK-LABEL: declare { <32 x i16>, i32 } @llvm.aie2.vmin.ge16(<32 x i16>, <32 x i16>, i32)
// CHECK-LABEL: declare { <16 x i32>, i32 } @llvm.aie2.vmin.ge32(<16 x i32>, <16 x i32>, i32)
// CHECK-LABEL: declare { <32 x bfloat>, i32 } @llvm.aie2.vmin.gebf16(<32 x bfloat>, <32 x bfloat>)

0 comments on commit 5524a8d

Please sign in to comment.