From 53d01c015d1be901a3fd76f2c29bddd8ce3527c3 Mon Sep 17 00:00:00 2001 From: jamestcl-amd Date: Wed, 22 May 2024 08:11:41 -0700 Subject: [PATCH] [WIP] converstion pattern from aievec.max to vmax intrinsic --- lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp | 93 ++++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp b/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp index 9c1c3d5bd5..3328c7136f 100644 --- a/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp +++ b/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp @@ -1498,6 +1498,98 @@ class BroadcastOpConversion } }; +class MaxOpConversion : public mlir::ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(aievec::MaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + VectorType resultType = cast(op.getResult().getType()); + Type resultScaTy = resultType.getElementType(); + unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth(); + int resultLanes = getVectorLaneSize(resultType); + int resultVectorSize = resultBitWidth * resultLanes; + + if (resultVectorSize != 512) { + op.emitWarning() << "aievec.max conversion with " << resultVectorSize + << "-bit result is not supported.\n"; + return failure(); + } + + // create xllvm intrinsic + Value maxOp = nullptr; + if (llvm::isa(resultType)) { + // create constant for cmp + auto cmpCst = rewriter.create( + loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); + SmallVector operands{adaptor.getLhs(), adaptor.getRhs(), cmpCst}; + if (resultBitWidth == 8) { + maxOp = rewriter.create( + loc, VectorType::get({64}, rewriter.getI8Type()), + forceCastOperandsToSignature( + rewriter, loc, operands, + {VectorType::get({64}, rewriter.getI8Type()), + VectorType::get({64}, rewriter.getI8Type()), + rewriter.getI32Type()})); + } else if (resultBitWidth == 16) { + maxOp = rewriter.create( + loc, VectorType::get({32}, rewriter.getI16Type()), + forceCastOperandsToSignature( + rewriter, loc, operands, + {VectorType::get({32}, rewriter.getI16Type()), + VectorType::get({32}, rewriter.getI16Type()), + rewriter.getI32Type()})); + } else if (resultBitWidth == 32) { + maxOp = rewriter.create( + loc, VectorType::get({16}, rewriter.getI32Type()), + forceCastOperandsToSignature( + rewriter, loc, operands, + {VectorType::get({16}, rewriter.getI32Type()), + VectorType::get({16}, rewriter.getI32Type()), + rewriter.getI32Type()})); + } + } else { + if (resultBitWidth == 16) { + SmallVector operands{adaptor.getLhs(), adaptor.getRhs()}; + SmallVector castedOperands = forceCastOperandsToSignature( + rewriter, loc, operands, + {VectorType::get({32}, rewriter.getBF16Type()), + VectorType::get({32}, rewriter.getBF16Type())}); + SmallVector resTypes{ + VectorType::get({32}, rewriter.getBF16Type()), + rewriter.getI32Type()}; + // maxOp = rewriter.create( + // loc, resTypes, + // forceCastOperandsToSignature( + // rewriter, loc, operands, + // {VectorType::get({32}, rewriter.getBF16Type()), + // VectorType::get({32}, rewriter.getBF16Type())})); + // maxOp = rewriter.create( + // loc, resTypes, castedOperands[0], castedOperands[1]); + maxOp = rewriter.create( + loc, VectorType::get({32}, rewriter.getBF16Type()), + forceCastOperandsToSignature( + rewriter, loc, operands, + {VectorType::get({32}, rewriter.getBF16Type()), + VectorType::get({32}, rewriter.getBF16Type())})); + } + } + + if (!maxOp) { + op.emitWarning() << "aievec.max conversion is not supported.\n"; + return failure(); + } + + // create truncation op (and bitcast op) + rewriter.replaceOp(op, maxOp); + + return success(); + } +}; + class BroadcastScalarOpConversion : public mlir::ConvertOpToLLVMPattern { public: @@ -1939,6 +2031,7 @@ void populateAIEVecToLLVMConversionPatterns( BroadcastScalarOpConversion, FMAElemOpConversion, MatMulOpConversion, + MaxOpConversion, ShiftOpConversion, ExtractElemOpConversion, FoldAIECastOps>(converter);