Skip to content

Commit

Permalink
[WIP] conversion pattern from aievec.max to vmax intrinsic
Browse files Browse the repository at this point in the history
  • Loading branch information
jamestcl-amd committed May 22, 2024
1 parent ab7d538 commit 53d01c0
Showing 1 changed file with 93 additions and 0 deletions.
93 changes: 93 additions & 0 deletions lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1498,6 +1498,98 @@ class BroadcastOpConversion
}
};

/// Conversion pattern that rewrites `aievec.max` into an xllvm vector-max
/// intrinsic op.
///
/// Supported cases (the result vector must be exactly 512 bits wide):
///   * integer elements,  8 bits -> xllvm::VectorMaxLt8IntrOp   (64 lanes)
///   * integer elements, 16 bits -> xllvm::VectorMaxLt16IntrOp  (32 lanes)
///   * integer elements, 32 bits -> xllvm::VectorMaxLt32IntrOp  (16 lanes)
///   * non-integer elements, 16 bits -> xllvm::VectorMaxLtBf16IntrOp
///
/// Any other combination emits a warning on the op and fails to match.
class MaxOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::MaxOp> {
public:
  using ConvertOpToLLVMPattern<aievec::MaxOp>::ConvertOpToLLVMPattern;

  /// Replaces `op` with the matching intrinsic.
  ///
  /// \param op       The `aievec.max` operation being converted.
  /// \param adaptor  Provides the (already converted) lhs/rhs operands.
  /// \param rewriter Rewriter used to create the new ops and replace `op`.
  /// \returns success() if `op` was replaced; failure() (after emitting a
  ///          warning) if the result type is not one of the supported cases.
  LogicalResult
  matchAndRewrite(aievec::MaxOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    VectorType resultType = cast<VectorType>(op.getResult().getType());
    Type resultScaTy = resultType.getElementType();
    unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
    int resultLanes = getVectorLaneSize(resultType);
    int resultVectorSize = resultBitWidth * resultLanes;

    if (resultVectorSize != 512) {
      op.emitWarning() << "aievec.max conversion with " << resultVectorSize
                       << "-bit result is not supported.\n";
      return failure();
    }

    // Create the xllvm intrinsic. `maxOp` stays null when no intrinsic
    // matches the element type / bit width.
    Value maxOp = nullptr;
    // The dispatch must look at the *element* type: the result type itself is
    // a VectorType and can never be an IntegerType.
    if (llvm::isa<IntegerType>(resultScaTy)) {
      // The integer intrinsics take a third i32 `cmp` operand; it is fixed
      // to 0 here.
      auto cmpCst = rewriter.create<LLVM::ConstantOp>(
          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
      SmallVector<Value> operands{adaptor.getLhs(), adaptor.getRhs(), cmpCst};
      if (resultBitWidth == 8) {
        maxOp = rewriter.create<xllvm::VectorMaxLt8IntrOp>(
            loc, VectorType::get({64}, rewriter.getI8Type()),
            forceCastOperandsToSignature(
                rewriter, loc, operands,
                {VectorType::get({64}, rewriter.getI8Type()),
                 VectorType::get({64}, rewriter.getI8Type()),
                 rewriter.getI32Type()}));
      } else if (resultBitWidth == 16) {
        maxOp = rewriter.create<xllvm::VectorMaxLt16IntrOp>(
            loc, VectorType::get({32}, rewriter.getI16Type()),
            forceCastOperandsToSignature(
                rewriter, loc, operands,
                {VectorType::get({32}, rewriter.getI16Type()),
                 VectorType::get({32}, rewriter.getI16Type()),
                 rewriter.getI32Type()}));
      } else if (resultBitWidth == 32) {
        maxOp = rewriter.create<xllvm::VectorMaxLt32IntrOp>(
            loc, VectorType::get({16}, rewriter.getI32Type()),
            forceCastOperandsToSignature(
                rewriter, loc, operands,
                {VectorType::get({16}, rewriter.getI32Type()),
                 VectorType::get({16}, rewriter.getI32Type()),
                 rewriter.getI32Type()}));
      }
    } else {
      // NOTE(review): any 16-bit non-integer element type is routed to the
      // bf16 intrinsic; presumably only bf16 reaches this point -- confirm
      // whether f16 needs to be rejected explicitly.
      if (resultBitWidth == 16) {
        SmallVector<Value> operands{adaptor.getLhs(), adaptor.getRhs()};
        // Cast the operands exactly once, directly as the intrinsic's
        // arguments, so no unused cast ops are left behind.
        maxOp = rewriter.create<xllvm::VectorMaxLtBf16IntrOp>(
            loc, VectorType::get({32}, rewriter.getBF16Type()),
            forceCastOperandsToSignature(
                rewriter, loc, operands,
                {VectorType::get({32}, rewriter.getBF16Type()),
                 VectorType::get({32}, rewriter.getBF16Type())}));
      }
    }

    if (!maxOp) {
      op.emitWarning() << "aievec.max conversion is not supported.\n";
      return failure();
    }

    // Replace the original op with the intrinsic's result.
    rewriter.replaceOp(op, maxOp);

    return success();
  }
};

class BroadcastScalarOpConversion
: public mlir::ConvertOpToLLVMPattern<aievec::BroadcastScalarOp> {
public:
Expand Down Expand Up @@ -1939,6 +2031,7 @@ void populateAIEVecToLLVMConversionPatterns(
BroadcastScalarOpConversion,
FMAElemOpConversion,
MatMulOpConversion,
MaxOpConversion,
ShiftOpConversion,
ExtractElemOpConversion,
FoldAIECastOps>(converter);
Expand Down

0 comments on commit 53d01c0

Please sign in to comment.