[Matmul+Truncf] Enable Matmul+Truncf for shorter shape on Pack-Peel + Objectfifo #822

Merged: 13 commits, Oct 10, 2024. Changes from all commits.
13 changes: 13 additions & 0 deletions build_tools/ci/cpu_comparison/matmul_template/matmul_truncf_MxK_KxN.mlir
@@ -0,0 +1,13 @@
// input ${M}x${K}x${TYPE1}
// input ${K}x${N}x${TYPE1}

func.func @matmul_truncf(%arg0: tensor<${M}x${K}x${TYPE1}>, %arg1: tensor<${K}x${N}x${TYPE1}>) -> tensor<${M}x${N}x${TYPE1}>
{
%cst = arith.constant ${ZERO} : ${TYPE2}
%0 = tensor.empty() : tensor<${M}x${N}x${TYPE2}>
%1 = linalg.fill ins(%cst : ${TYPE2}) outs(%0 : tensor<${M}x${N}x${TYPE2}>) -> tensor<${M}x${N}x${TYPE2}>
%2 = linalg.matmul ins(%arg0, %arg1 : tensor<${M}x${K}x${TYPE1}>, tensor<${K}x${N}x${TYPE1}>)
outs(%1: tensor<${M}x${N}x${TYPE2}>) -> tensor<${M}x${N}x${TYPE2}>
%3 = arith.truncf %2 : tensor<${M}x${N}x${TYPE2}> to tensor<${M}x${N}x${TYPE1}>
return %3: tensor<${M}x${N}x${TYPE1}>
}
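
For reference (not part of the diff): with the parameters the new test below uses (M = N = K = 8, TYPE1 = bf16, TYPE2 = f32, ZERO = 0.0), this template instantiates to roughly:

func.func @matmul_truncf(%arg0: tensor<8x8xbf16>, %arg1: tensor<8x8xbf16>) -> tensor<8x8xbf16>
{
  %cst = arith.constant 0.0 : f32
  %0 = tensor.empty() : tensor<8x8xf32>
  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<8x8xf32>) -> tensor<8x8xf32>
  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<8x8xbf16>, tensor<8x8xbf16>)
                     outs(%1 : tensor<8x8xf32>) -> tensor<8x8xf32>
  %3 = arith.truncf %2 : tensor<8x8xf32> to tensor<8x8xbf16>
  return %3 : tensor<8x8xbf16>
}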
25 changes: 25 additions & 0 deletions build_tools/ci/cpu_comparison/run.py
@@ -758,6 +758,31 @@ def run(self, config):
lower_to_aie_pipeline="objectFifo",
)

# Test(s) of the form matmul(A,B) + truncf(C) where A:MxK, B:KxN and C:MxN
test_name = output_dir / f"test_from_template_matmul_truncf.mlir"
template_name = matmul_template_dir / "matmul_truncf_MxK_KxN.mlir"
generate_matmul_test(test_name, template_name, 8, 8, 8, "bf16", "f32")
identity_mat = np.eye(8, dtype=np.float32)
ones = np.ones(8 * 8, dtype=np.float32).reshape([8, 8])
lhs = ones * 101
rhs = identity_mat * 3
input_args = generate_inputs(test_name, output_dir, 1, {1: lhs, 2: rhs})
aie_vs_baseline(
config,
test_name,
input_args,
ones * 302, # expected output
use_ukernel=False,
tile_pipeline="pack-peel",
lower_to_aie_pipeline="objectFifo",
function_name=None,
seed=1,
rtol=0,
atol=0,
n_repeats=1,
output_type=get_output_type(test_name),
)
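# Editor's note (illustration, not part of the PR): the baseline here is
# (ones * 101) @ (eye * 3) = ones * 303 in f32. 303 is not exactly
# representable in bf16 (its bf16 neighbors are 302 and 304), and the
# truncation to bf16 of the matmul result lands on 302, which is the
# expected output asserted above with rtol=0 and atol=0.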


class SmokeSet(TestSet):
def __init__(self):
@@ -400,6 +400,52 @@ struct ToMinorIdentityTransferReadPattern
}
};

// clang-format off
/// Pattern to linearize arith.truncf: the aievec.srs op it is later converted
/// to is expected, in AIEVecToLLVM, to have a 1-D source and target.
/// Refer: https://github.com/nod-ai/iree-amd-aie/blob/main/compiler/plugins/target/AMD-AIE/aievec/AIEVecToLLVM.cpp#L73-L74
///
/// Example of what this pattern achieves:
/// INPUT
/// %0 = arith.truncf %inp : vector<2x3xf32> to vector<2x3xbf16>
/// OUTPUT
/// %0 = vector.shape_cast %inp : vector<2x3xf32> to vector<6xf32>
/// %1 = arith.truncf %0 : vector<6xf32> to vector<6xbf16>
/// %2 = vector.shape_cast %1 : vector<6xbf16> to vector<2x3xbf16>
// clang-format on
struct FlattenArithTruncFOpPattern : public OpRewritePattern<arith::TruncFOp> {
using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;

LogicalResult matchAndRewrite(arith::TruncFOp op,
PatternRewriter &rewriter) const override {
// Get old shape type.
auto oldShapedType = dyn_cast<VectorType>(op.getType());
if (!oldShapedType) return failure();
// Bail out if it's already linearized.
if (oldShapedType.getRank() == 1) return failure();
// Linearize the shape.
int64_t linearizedSize = oldShapedType.getNumElements();
// Fetch input.
Value origInputOfTruncFOp = op.getIn();
// Form linearized vector shape type for input and output.
VectorType newVectorTypeForInput = VectorType::get(
{linearizedSize},
cast<ShapedType>(origInputOfTruncFOp.getType()).getElementType());
VectorType newVectorTypeForOutput =
VectorType::get({linearizedSize}, oldShapedType.getElementType());
// Shape cast the original input to linearized shape type.
Value newInputVector = rewriter.create<vector::ShapeCastOp>(
op.getLoc(), newVectorTypeForInput, origInputOfTruncFOp);
// Create new base operation with the linearized input/output.
Value newTruncFOp = rewriter.create<arith::TruncFOp>(
op.getLoc(), newVectorTypeForOutput, newInputVector);
// Delinearize the output back to the original type.
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(),
newTruncFOp);
return success();
}
};

// This pattern extracts an implicit transposition of the 2 innermost
// dimensions of `rhs` in a gemm-like contraction op, making it an explicit
// `vector.transpose` op.
@@ -591,16 +637,16 @@ struct CanonicalizeVectorForAIEVecPass
}
{
RewritePatternSet patterns(context);
-      patterns.add<ExtractTransposeFromContractionOp,
-                   ToMinorIdentityTransferReadPattern,
-                   ToMinorIdentityTransferWritePattern,
-                   ConvertLeadingUnitDimInsertToReshapePattern>(context);
+      patterns
+          .add<ExtractTransposeFromContractionOp, FlattenArithTruncFOpPattern,
+               ToMinorIdentityTransferReadPattern,
+               ToMinorIdentityTransferWritePattern,
+               ConvertLeadingUnitDimInsertToReshapePattern>(context);
patterns.add<ConvertSplatTransferReadToBroadcastPattern>(context);
mlir::vector::populateFlattenVectorTransferPatterns(patterns);
mlir::vector::populateVectorBroadcastLoweringPatterns(patterns);
(void)applyPatternsAndFoldGreedily(op, std::move(patterns));
}

{
// These must run after 'populateFlattenVectorTransferPatterns' because
// vector.shape_casts are introduced. Merging into a single pass creates
@@ -154,3 +154,16 @@ func.func @noncontiguous_write(%v: vector<4x8xi8>) {
vector.transfer_write %v, %alloc[%c0, %c0] : vector<4x8xi8>, memref<4x10xi8>
return
}

// -----

// CHECK-LABEL: @arith_truncf(
// CHECK-SAME: %[[INP:.*]]: vector<2x3xf32>)
func.func @arith_truncf(%inp: vector<2x3xf32>) -> vector<2x3xbf16> {
// CHECK: %[[LINEARIZE:.*]] = vector.shape_cast %[[INP]] : vector<2x3xf32> to vector<6xf32>
// CHECK: %[[TRUNCF:.*]] = arith.truncf %[[LINEARIZE]] : vector<6xf32> to vector<6xbf16>
// CHECK: %[[DELINEARIZE:.*]] = vector.shape_cast %[[TRUNCF]] : vector<6xbf16> to vector<2x3xbf16>
// CHECK: return %[[DELINEARIZE]]
%0 = arith.truncf %inp : vector<2x3xf32> to vector<2x3xbf16>
return %0 : vector<2x3xbf16>
}
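
The RUN line of this lit test falls outside the diff context shown here. Assuming the pass keeps its registered name canonicalize-vector-for-aievec, it would look something like:

// RUN: iree-opt --split-input-file --canonicalize-vector-for-aievec %s | FileCheck %s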

This file was deleted. (Per the CMakeLists and Passes changes below, this is AMDAIEFlattenVectorizedOps.cpp, the source of the now-removed flatten-vectorized-ops pass.)

@@ -29,6 +29,19 @@ namespace mlir::iree_compiler::AMDAIE {

namespace {

/// Utility which returns 'true' if the operation needs to be inserted
/// within an `amdaie.core` op.
/// Some ops are surrounded by scf.for loop nests. Place the entire
/// loop nest inside the amdaie.core op here. Currently we look for a
/// subset of ops which we know should be in the core.
/// TODO(newling) improve this design.
Review comment (Collaborator): Yeah, we should look for a way to improve this.
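// Editor's note: compared with the inline check this utility replaces in
// insertCoreOps (below), the list additionally covers arith.extf,
// arith.truncf, vector.transfer_read and vector.transfer_write.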

static bool isCoreComputeOp(Operation *op) {
return isa<linalg::LinalgOp, vector::ContractionOp,
memref::ExtractStridedMetadataOp, func::CallOp, arith::ExtFOp,
arith::TruncFOp, vector::TransferReadOp, vector::TransferWriteOp>(
op);
}

/// Utility to map the parallel mapping attributes to the corresponding
/// induction variables.
void getAttributeMapping(SmallVector<scf::ForallOp> forallOps,
@@ -128,14 +141,7 @@ LogicalResult insertCoreOps(mlir::ModuleOp moduleOp) {
coreOp.setLinkWith(fnDecl->getAttrOfType<StringAttr>("link_with"));
}

-    // Some ops are surrrounded by scf.for loop nests. Place the entire
-    // loop nest inside the amdaie.core op here. Currently look for a
-    // subset of ops which we know should be in the core.
-    // TODO(newling) improve this design.
-    bool insertInCore =
-        isa<linalg::LinalgOp>(op) || isa<vector::ContractionOp>(op) ||
-        isa<memref::ExtractStridedMetadataOp>(op) || isa<func::CallOp>(op);
-    if (insertInCore) {
+    if (isCoreComputeOp(op)) {
// Most distant ancestor of 'op' that's a strict descendant of
// 'forallOp'.
Operation *ancestor = op;
@@ -66,12 +66,31 @@ class AMDAIEInsertLoopsForVectorizationPass
// Return success if the generic op is rewritten, failure otherwise.
LogicalResult maybeRewrite(linalg::GenericOp genericOp,
IRRewriter &rewriter) {
if (isa<linalg::CopyOp, linalg::FillOp>(genericOp)) return failure();

auto iteratorTypes = genericOp.getIteratorTypesArray();
auto numIterators = iteratorTypes.size();

-    // No outer dimensions to tile if fewer than 4 iterators.
-    if (numIterators < 4) return failure();
+    // No outer dimensions to tile if fewer than 3 iterators.
+    if (numIterators < 3) return failure();

// Enable generating loops for vectorization in case of element-wise ops.
// We tile all but the innermost two dimensions currently because they form
// the smallest tiled M x N dimension of the matmul.
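// For example (editor's illustration), a purely parallel op with 4 iterators
// gets tile sizes [1, 1, 0, 0] below: the two outermost dimensions are tiled
// into scf.for loops with step 1, and the innermost two are left untiled for
// vectorization.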
if (llvm::all_of(iteratorTypes, [&](mlir::utils::IteratorType iterator) {
return linalg::isParallelIterator(iterator);
})) {
assert(numIterators >= 2 && "expected at least 2 iterators here");
SmallVector<int64_t> tileSizes(numIterators, 1);
tileSizes[numIterators - 2] = 0;
tileSizes[numIterators - 1] = 0;
auto opts = linalg::LinalgTilingOptions().setTileSizes(tileSizes);
auto tiled = linalg::tileLinalgOp(rewriter, genericOp, opts);
const auto &loops = tiled.value().loops;
assert(!loops.empty() && "expected at least one loop here");
rewriter.replaceOp(genericOp, loops[0]->getResult(0));
return success();
}
// Matmul-like ops have 3 operands.
if (genericOp->getNumOperands() != 3) return failure();

@@ -84,9 +84,18 @@ void AMDAIEVectorizationPass::runOnOperation() {
// edge case so just disable. See
// https://github.com/nod-ai/iree-amd-aie/issues/594 for more info.
if (isElementwise(cast<linalg::LinalgOp>(op))) {
-      op->emitRemark() << "not vectorizing linalg elementwise op";
-      return;
+      // TODO(avarma): Currently vectorization is switched on only for
+      // arith.truncf. Improve this later by trying to bridge the gap
+      // between this pass and vector-to-aievec.
+      for (Operation &innerOps :
+           cast<linalg::GenericOp>(op).getBody()->getOperations()) {
+        if (!isa<arith::TruncFOp, linalg::YieldOp>(innerOps)) {
+          op->emitRemark() << "not vectorizing linalg elementwise op";
+          return;
+        }
+      }
}
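// Editor's note: the elementwise linalg.generic produced for matmul+truncf
// has a body containing only arith.truncf and linalg.yield, so it passes
// this filter and is vectorized.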

// AIE architecture has no vector instructions for 32/64-bit types.
if (!hasOperandWithSmallElementType(op)) return;

@@ -69,7 +69,6 @@ iree_cc_library(
"AMDAIEDmaToCircularDma.cpp"
"AMDAIEDmaUtils.cpp"
"AMDAIEFlattenLogicalObjectFifo.cpp"
"AMDAIEFlattenVectorizedOps.cpp"
"AMDAIEFuseConsumerIntoLoop.cpp"
"AMDAIEFuseFillIntoForall.cpp"
"AMDAIEFusePackIntoLoop.cpp"
@@ -46,7 +46,6 @@ namespace mlir::iree_compiler::AMDAIE {
#define GEN_PASS_DEF_AMDAIEDMALOOPSUBSUMPTION
#define GEN_PASS_DEF_AMDAIEDMATOCIRCULARDMA
#define GEN_PASS_DEF_AMDAIEFLATTENLOGICALOBJECTFIFO
- #define GEN_PASS_DEF_AMDAIEFLATTENVECTORIZEDOPS
#define GEN_PASS_DEF_AMDAIEFUSECONSUMERINTOLOOP
#define GEN_PASS_DEF_AMDAIEFUSEFILLINTOFORALL
#define GEN_PASS_DEF_AMDAIEFUSEPACKINTOLOOP
@@ -158,9 +158,6 @@ std::unique_ptr<Pass> createAMDAIEDmaToCircularDmaPass();
/// Create a pass to flatten the logical objectFifos.
std::unique_ptr<Pass> createAMDAIEFlattenLogicalObjectFifoPass();

- /// Create a pass to flatten vectorized ops.
- std::unique_ptr<Pass> createAMDAIEFlattenVectorizedOpsPass();
-
/// Create a pass to fuse the consumer op into the innermost last scf loop.
std::unique_ptr<Pass> createAMDAIEFuseConsumerIntoLoopPass(
AMDAIEFuseConsumerIntoLoopOptions options = {});
@@ -222,12 +222,6 @@ def AMDAIEFlattenLogicalObjectFifo :
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEFlattenLogicalObjectFifoPass()";
}

- def AMDAIEFlattenVectorizedOps :
-     Pass<"iree-amdaie-flatten-vectorized-ops", "ModuleOp"> {
-   let summary = "Flatten vectorized ops.";
-   let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEFlattenVectorizedOpsPass()";
- }
-
def AMDAIEFuseConsumerIntoLoop :
InterfacePass<"iree-amdaie-fuse-consumer-into-loop", "mlir::FunctionOpInterface"> {
let summary = "Fuse the consumer operation into the innermost last scf loop.";
@@ -34,7 +34,6 @@ iree_lit_test_suite(
"dma_loop_subsumption.mlir"
"dma_to_circular_dma.mlir"
"flatten_logical_objectfifo.mlir"
"flatten_vectorized_ops.mlir"
"fuse_consumer_into_loop_scf_for.mlir"
"fuse_consumer_into_loop_scf_forall.mlir"
"fuse_fill_into_forall.mlir"

This file was deleted. (Per the lit test suite change above, this is the flatten_vectorized_ops.mlir test for the removed pass.)
