Skip to content

Commit

Permalink
finalize the conversion pattern from aievec.max to vmax xllvm intrinsic
Browse files Browse the repository at this point in the history
  • Loading branch information
jamestcl-amd committed May 23, 2024
1 parent b221018 commit f441153
Showing 1 changed file with 25 additions and 24 deletions.
49 changes: 25 additions & 24 deletions lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1521,30 +1521,42 @@ class MaxOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::MaxOp> {

// create xllvm intrinsic
Value maxOp = nullptr;
if (llvm::isa<IntegerType>(resultType)) {
if (llvm::isa<IntegerType>(resultScaTy)) {
// create constant for cmp
auto cmpCst = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
SmallVector<Value> operands{adaptor.getLhs(), adaptor.getRhs(), cmpCst};
if (resultBitWidth == 8) {
maxOp = rewriter.create<xllvm::VectorMaxLt8IntrOp>(
loc, VectorType::get({64}, rewriter.getI8Type()),
loc,
mlir::LLVM::LLVMStructType::getLiteral(
rewriter.getContext(),
{VectorType::get({64}, rewriter.getI8Type()),
rewriter.getI32Type()}),
forceCastOperandsToSignature(
rewriter, loc, operands,
{VectorType::get({64}, rewriter.getI8Type()),
VectorType::get({64}, rewriter.getI8Type()),
rewriter.getI32Type()}));
} else if (resultBitWidth == 16) {
maxOp = rewriter.create<xllvm::VectorMaxLt16IntrOp>(
loc, VectorType::get({32}, rewriter.getI16Type()),
loc,
mlir::LLVM::LLVMStructType::getLiteral(
rewriter.getContext(),
{VectorType::get({32}, rewriter.getI16Type()),
rewriter.getI32Type()}),
forceCastOperandsToSignature(
rewriter, loc, operands,
{VectorType::get({32}, rewriter.getI16Type()),
VectorType::get({32}, rewriter.getI16Type()),
rewriter.getI32Type()}));
} else if (resultBitWidth == 32) {
maxOp = rewriter.create<xllvm::VectorMaxLt32IntrOp>(
loc, VectorType::get({16}, rewriter.getI32Type()),
loc,
mlir::LLVM::LLVMStructType::getLiteral(
rewriter.getContext(),
{VectorType::get({16}, rewriter.getI32Type()),
rewriter.getI32Type()}),
forceCastOperandsToSignature(
rewriter, loc, operands,
{VectorType::get({16}, rewriter.getI32Type()),
Expand All @@ -1553,26 +1565,14 @@ class MaxOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::MaxOp> {
}
} else {
if (resultBitWidth == 16) {
SmallVector<Value> operands{adaptor.getLhs(), adaptor.getRhs()};
SmallVector<Value> castedOperands = forceCastOperandsToSignature(
rewriter, loc, operands,
{VectorType::get({32}, rewriter.getBF16Type()),
VectorType::get({32}, rewriter.getBF16Type())});
SmallVector<Type> resTypes{
VectorType::get({32}, rewriter.getBF16Type()),
rewriter.getI32Type()};
// maxOp = rewriter.create<xllvm::VectorMaxLtBf16IntrOp>(
// loc, resTypes,
// forceCastOperandsToSignature(
// rewriter, loc, operands,
// {VectorType::get({32}, rewriter.getBF16Type()),
// VectorType::get({32}, rewriter.getBF16Type())}));
// maxOp = rewriter.create<xllvm::VectorMaxLtBf16IntrOp>(
// loc, resTypes, castedOperands[0], castedOperands[1]);
maxOp = rewriter.create<xllvm::VectorMaxLtBf16IntrOp>(
loc, VectorType::get({32}, rewriter.getBF16Type()),
loc,
mlir::LLVM::LLVMStructType::getLiteral(
rewriter.getContext(),
{VectorType::get({32}, rewriter.getBF16Type()),
rewriter.getI32Type()}),
forceCastOperandsToSignature(
rewriter, loc, operands,
rewriter, loc, {adaptor.getLhs(), adaptor.getRhs()},
{VectorType::get({32}, rewriter.getBF16Type()),
VectorType::get({32}, rewriter.getBF16Type())}));
}
Expand All @@ -1583,8 +1583,9 @@ class MaxOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::MaxOp> {
return failure();
}

// create truncation op (and bitcast op)
rewriter.replaceOp(op, maxOp);
// create llvm.extractvalue for the first element in the LLVMStruct
rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(op, maxOp,
/*position=*/0);

return success();
}
Expand Down

0 comments on commit f441153

Please sign in to comment.