Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DNS] Test batch mmt4d #14542

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1236,6 +1236,10 @@ getDefaultDistributionTileSizes(TilingInterface op) {
}

// Returns true if `op` packs the LHS of a (batch) matmul: either a rank-3
// source packed along its two innermost dims [1, 2] (batch case), or a
// rank-2 source packed along both dims [0, 1].
static bool isPackMatmulLHS(tensor::PackOp op) {
  ArrayRef<int64_t> innerDimsPos = op.getInnerDimsPos();
  // A matmul LHS pack always introduces exactly two inner tiles.
  if (innerDimsPos.size() != 2)
    return false;
  int64_t sourceRank = op.getSourceRank();
  // Batch matmul LHS: leading batch dim untouched, M and K dims tiled.
  if (sourceRank == 3)
    return innerDimsPos[0] == 1 && innerDimsPos[1] == 2;
  // Plain matmul LHS: both M and K dims tiled.
  return sourceRank == 2 && innerDimsPos[0] == 0 && innerDimsPos[1] == 1;
}
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,8 @@ void addMmt4dTilingExpertPassPipeline(OpPassManager &passManager,
OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();

if (enableMicrokernels) {
nestedModulePM.addNestedPass<func::FuncOp>(
createDecomposeBatchMmt4DOpsPass());
nestedModulePM.addPass(
createLLVMCPULowerToUKernelsPass(clSkipIntermediateRoundings));
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
#include "mlir/IR/Types.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/Support/Debug.h"

namespace mlir {
namespace iree_compiler {
namespace IREE {
Expand Down Expand Up @@ -232,6 +234,100 @@ struct SetMatmulEncoding : public OpRewritePattern<linalg::MatmulOp> {
}
};

/// Rewrites a `linalg.batch_matmul` with tensor semantics into the same op
/// acting on data-tiled operands: each operand is padded and wrapped in an
/// `iree_linalg_ext.set_encoding` op, the batch matmul is recreated over the
/// encoded values, and the encoded result is converted back to the original
/// shape via `unset_encoding` + an extract-slice of the original sizes.
struct SetBatchMatmulEncoding : public OpRewritePattern<linalg::BatchMatmulOp> {
  SetBatchMatmulEncoding(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<linalg::BatchMatmulOp>(context, benefit) {}

  LogicalResult matchAndRewrite(linalg::BatchMatmulOp matmulOp,
                                PatternRewriter &rewriter) const override {
    if (!matmulOp.hasTensorSemantics())
      return failure();
    auto inputs = matmulOp.getDpsInputOperands();
    auto outputs = matmulOp.getDpsInitOperands();
    // Skip ops that were already rewritten: an operand that carries an
    // encoding means this pattern has run on it before.
    auto carriesEncoding = [](OpOperand *operand) -> bool {
      auto type = llvm::dyn_cast<RankedTensorType>(operand->get().getType());
      return type && type.getEncoding();
    };
    if (llvm::any_of(inputs, carriesEncoding) ||
        llvm::any_of(outputs, carriesEncoding)) {
      return failure();
    }

    Value origLhs = inputs[0]->get();
    Value origRhs = inputs[1]->get();
    Value origOut = outputs[0]->get();

    // Element type of `v` when it is a ranked tensor, null Type otherwise.
    auto getElemType = [](Value v) -> Type {
      auto tensorType = llvm::dyn_cast<RankedTensorType>(v.getType());
      return tensorType ? tensorType.getElementType() : Type{};
    };
    Type lhsElemType = getElemType(origLhs);
    Type rhsElemType = getElemType(origRhs);
    Type outElemType = getElemType(origOut);
    if (!lhsElemType || !rhsElemType || !outElemType) {
      return failure();
    }

    // Map the (lhs, rhs, result) element-type triple to the encoding user
    // that identifies this batch-matmul variant.
    LinalgExt::EncodingUser user;
    if (lhsElemType.isF32() && rhsElemType.isF32() && outElemType.isF32()) {
      user = LinalgExt::EncodingUser::BATCH_MATMUL_F32F32F32;
    } else if (lhsElemType.isF16() && rhsElemType.isF16() &&
               outElemType.isF32()) {
      user = LinalgExt::EncodingUser::BATCH_MATMUL_F16F16F32;
    } else if (lhsElemType.isF16() && rhsElemType.isF16() &&
               outElemType.isF16()) {
      user = LinalgExt::EncodingUser::BATCH_MATMUL_F16F16F16;
    } else if (lhsElemType.isBF16() && rhsElemType.isBF16() &&
               outElemType.isF32()) {
      user = LinalgExt::EncodingUser::BATCH_MATMUL_BF16BF16F32;
    } else if (lhsElemType.isBF16() && rhsElemType.isBF16() &&
               outElemType.isBF16()) {
      user = LinalgExt::EncodingUser::BATCH_MATMUL_BF16BF16BF16;
    } else if (lhsElemType.isSignlessInteger(8) &&
               rhsElemType.isSignlessInteger(8) &&
               outElemType.isSignlessInteger(32)) {
      user = LinalgExt::EncodingUser::BATCH_MATMUL_I8I8I32;
    } else {
      return rewriter.notifyMatchFailure(
          matmulOp,
          "unhandled combination of (lhs, rhs, result) element types");
    }

    Location loc = matmulOp.getLoc();

    // Pad each operand and attach its role-specific encoding.
    Value encodedLhs = padAndSetEncoding(rewriter, loc, origLhs, user,
                                         LinalgExt::EncodingRole::LHS);
    Value encodedRhs = padAndSetEncoding(rewriter, loc, origRhs, user,
                                         LinalgExt::EncodingRole::RHS);
    Value encodedOut = padAndSetEncoding(rewriter, loc, origOut, user,
                                         LinalgExt::EncodingRole::RESULT);

    // Recreate the batch matmul over the encoded operands.
    auto encodedMatmulOp = rewriter.create<linalg::BatchMatmulOp>(
        loc, encodedOut.getType(), ValueRange{encodedLhs, encodedRhs},
        encodedOut);
    Value encodedResult = encodedMatmulOp.getResult(0);

    // Sizes are computed by original output size.
    FailureOr<SmallVector<OpFoldResult>> origOutSizes =
        LinalgExt::getDims(rewriter, loc, origOut);
    if (failed(origOutSizes)) {
      return rewriter.notifyMatchFailure(matmulOp,
                                         "failed to get shape of result");
    }

    // Strip the encoding and slice back to the original (unpadded) shape.
    Value replacement = unsetEncodingAndExtractSlice(
        rewriter, loc, encodedResult, origOutSizes.value());

    rewriter.replaceOp(matmulOp, replacement);
    return success();
  }
};

/// Pattern to fold a `linalg.fill` -> `iree_linalg_ext.set_encoding`
/// operation into a `linalg.fill` of the encoded type.
struct FoldFillWithSetEncoding
Expand Down Expand Up @@ -272,6 +368,7 @@ void SetEncodingPass::runOnOperation() {
{
RewritePatternSet patterns(context);
patterns.insert<SetMatmulEncoding>(context);
patterns.insert<SetBatchMatmulEncoding>(context);
linalg::FillOp::getCanonicalizationPatterns(patterns, context);
patterns.insert<FoldFillWithSetEncoding>(context);
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
Expand Down
26 changes: 26 additions & 0 deletions e2e_bench/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
To reproduce benchmarks:

```sh
cd e2e_bench

# Fetch models
./fetch.sh

# Checkout baseline commit: 40794933d45fdbb05d631c9612dc91cc343d1efe
# Build baseline IREE tools (iree-compile, iree-opt, iree-benchmark-module) and
# make sure they can be found in PATH.

# Run baseline benchmarks
cd baseline
./bench_baseline.sh
cd ..

# Checkout data-tiling commit 4cc440bc3599207828585f4b51b685a1585fe431
# Build IREE tools with data-tiling changes.

# Run batch_matmul data-tiling benchmarks
cd dt_and_uk
./bench_dt_and_uk.sh
cd ..
```
18 changes: 18 additions & 0 deletions e2e_bench/baseline/bench_baseline.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#!/bin/bash

# Runs the baseline (no data-tiling) benchmark suite at 1, 4 and 8 threads,
# logging each run to runN.log.
#
# The script will find iree tools in PATH. To reproduce baseline benchmarks,
# please build tools at 40794933d45fdbb05d631c9612dc91cc343d1efe.

# Abort on the first failing benchmark run. `pipefail` is required because
# piping through `tee` would otherwise mask run.sh's exit status.
set -euo pipefail

# run.sh looks for the *.mlirbc models (and their .run_flag files) here.
export MODEL_DIR=..

export IREE_BENCHMARK_MODULE="iree-benchmark-module"
export TRACE_MODE=0

THREADS=1 ../run.sh | tee run1.log
THREADS=4 ../run.sh | tee run4.log
THREADS=8 ../run.sh | tee run8.log

# Uncomment to collect Tracy traces instead of plain benchmark numbers:
# export IREE_BENCHMARK_MODULE="iree-traced-benchmark-module"
# export TRACE_MODE=1
#
# THREADS=1 ../run.sh | tee traced_run1.log
18 changes: 18 additions & 0 deletions e2e_bench/dt_and_uk/bench_dt_and_uk.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#!/bin/bash

# Runs the data-tiling + microkernel benchmark suite at 1, 4 and 8 threads,
# logging each run to runN.log.
#
# The script will find iree tools in PATH. To reproduce data-tiling benchmarks,
# please build tools at 4cc440bc3599207828585f4b51b685a1585fe431
#
# Abort on the first failing benchmark run. `pipefail` is required because
# piping through `tee` would otherwise mask run.sh's exit status.
set -euo pipefail

# run.sh looks for the *.mlirbc models (and their .run_flag files) here.
export MODEL_DIR=..

export IREE_BENCHMARK_MODULE="iree-benchmark-module"
export TRACE_MODE=0

THREADS=1 ../run.sh | tee run1.log
THREADS=4 ../run.sh | tee run4.log
THREADS=8 ../run.sh | tee run8.log

# Uncomment to collect Tracy traces instead of plain benchmark numbers:
# export IREE_BENCHMARK_MODULE="iree-traced-benchmark-module"
# export TRACE_MODE=1
#
# THREADS=1 ../run.sh | tee traced_run1.log
39 changes: 39 additions & 0 deletions e2e_bench/fetch.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/bin/bash

# Fetches the benchmark models in parallel and writes a ".run_flag" file next
# to each one with the iree-benchmark-module flags (entry function + inputs).
#
# Fix over the original: a bare `wait` with no operands always returns 0, so
# failed downloads were silently ignored. Each wget PID is now waited on
# individually, and `set -e` aborts the script on the first failure.
set -euo pipefail

# PIDs of the background downloads, so each exit status can be checked.
declare -a DOWNLOAD_PIDS=()

# wget -O EfficientNetV2SPT.mlirbc https://storage.googleapis.com/iree-model-artifacts/pytorch/torch_models_20230321.784_1679461251/EFFICIENTNET_V2_S/batch_1/linalg.mlir &
# cat<<EOF > EfficientNetV2SPT.mlirbc.run_flag
# --function=forward
# --input=1x3x384x384xf32=0
# EOF

wget -O BertLargeTF_Batch1.mlirbc https://storage.googleapis.com/iree-model-artifacts/tensorflow/manual/BertLargeTF_2023-05-07.timestamp_1683504734.mlirbc &
DOWNLOAD_PIDS+=($!)
cat<<EOF > BertLargeTF_Batch1.mlirbc.run_flag
--function=serving_default
--input=1x384xi32=0
--input=1x384xi32=0
--input=1x384xi32=0
EOF

wget -O BertLargeTF_Batch32.mlirbc https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975/BERT_LARGE_FP32_TF_384XI32_BATCH32/stablehlo.mlirbc &
DOWNLOAD_PIDS+=($!)
cat<<EOF > BertLargeTF_Batch32.mlirbc.run_flag
--function=forward
--input=32x384xi32=0
--input=32x384xi32=0
--input=32x384xi32=0
EOF

wget -O T5LargeTF_Batch1.mlirbc https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975/T5_LARGE_FP32_TF_512XI32_BATCH1/stablehlo.mlirbc &
DOWNLOAD_PIDS+=($!)
cat<<EOF > T5LargeTF_Batch1.mlirbc.run_flag
--function=forward
--input=1x512xi32=0
--input=1x512xi32=0
EOF

wget -O T5LargeTF_Batch32.mlirbc https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975/T5_LARGE_FP32_TF_512XI32_BATCH32/stablehlo.mlirbc &
DOWNLOAD_PIDS+=($!)
cat<<EOF > T5LargeTF_Batch32.mlirbc.run_flag
--function=forward
--input=32x512xi32=0
--input=32x512xi32=0
EOF

# Wait on each download individually so a failed wget fails the script.
for pid in "${DOWNLOAD_PIDS[@]}"; do
  wait "${pid}"
done
75 changes: 75 additions & 0 deletions e2e_bench/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#!/bin/bash

# Compiles each benchmark model with data-tiling + microkernel flags and runs
# iree-benchmark-module on the result under numactl.
#
# Configuration (environment variables):
#   IREE_OPT, IREE_COMPILE, IREE_BENCHMARK_MODULE, IREE_TRACY:
#       tool names (or paths), resolved through PATH.
#   TRACE_MODE:  1 to capture a Tracy trace (1 repetition), 0 to benchmark (5).
#   THREADS:     1 uses the synchronous local-sync device; >1 uses local-task
#                with THREADS topology groups.
#   PREFIX:      prefix prepended to every generated artifact file name.
#   MODEL_DIR:   directory containing the *.mlirbc models and .run_flag files.
#   COMP_FLAGS:  extra flags for the second iree-compile invocation.

set -xeuo pipefail

IREE_OPT="$(which "${IREE_OPT:-iree-opt}")"
# Honor IREE_COMPILE, falling back to the legacy IREE_COMPILER spelling. The
# original only read IREE_COMPILER, so setting IREE_COMPILE had no effect.
IREE_COMPILE="$(which "${IREE_COMPILE:-${IREE_COMPILER:-iree-compile}}")"
IREE_BENCHMARK_MODULE="$(which "${IREE_BENCHMARK_MODULE:-iree-benchmark-module}")"
IREE_TRACY="$(which "${IREE_TRACY:-iree-tracy-capture}")"
TRACE_MODE="${TRACE_MODE:-0}"
THREADS="${THREADS:-1}"
PREFIX="${PREFIX:-}"
MODEL_DIR="${MODEL_DIR:-.}"
COMP_FLAGS="${COMP_FLAGS:-}"

# Iterate with a glob instead of parsing `ls` output, which is fragile.
# for MODEL_PATH in "${MODEL_DIR}"/*.mlirbc; do
for MODEL_PATH in "${MODEL_DIR}"/BertLargeTF_Batch32.mlirbc; do
  MODEL_FILE="$(basename "${MODEL_PATH}")"
  echo ">>>> ${MODEL_FILE} <<<<"

  # Stage 1: lower the input only to the "preprocessing" phase so the
  # linalg-level IR can be inspected.
  "${IREE_COMPILE}" \
    "${MODEL_PATH}" \
    -o "${PREFIX}${MODEL_FILE}.linalg.mlir" \
    --iree-hal-target-backends=llvm-cpu \
    --iree-input-type=auto \
    --iree-llvmcpu-target-triple=x86_64-unknown-linux-gnu \
    --iree-llvmcpu-target-cpu=cascadelake \
    --iree-flow-enable-data-tiling \
    --iree-llvmcpu-enable-microkernels \
    --compile-to="preprocessing"

  # Re-emit with debug locations for correlating IR with traces.
  "${IREE_OPT}" --mlir-print-debuginfo "${PREFIX}${MODEL_FILE}.linalg.mlir" > "${PREFIX}${MODEL_FILE}.debug.mlir"

  # Stage 2: full compilation to a .vmfb. The IR after dispatch outlining is
  # captured from stderr into ${PREFIX}${MODEL_FILE}.dump.
  # COMP_FLAGS stays unquoted on purpose: it may hold several flags.
  "${IREE_COMPILE}" \
    "${PREFIX}${MODEL_FILE}.debug.mlir" \
    -o "${PREFIX}${MODEL_FILE}.vmfb" \
    ${COMP_FLAGS} \
    --iree-hal-target-backends=llvm-cpu \
    --iree-input-type=auto \
    --iree-llvmcpu-target-triple=x86_64-unknown-linux-gnu \
    --iree-llvmcpu-target-cpu=cascadelake \
    --iree-flow-enable-data-tiling \
    --iree-llvmcpu-enable-microkernels \
    --mlir-print-ir-after=iree-flow-outline-dispatch-regions \
    --mlir-elide-elementsattrs-if-larger=4 2> "${PREFIX}${MODEL_FILE}.dump"

  if (( THREADS == 1 )); then
    declare -a THREAD_ARGS=(
      "--device=local-sync"
    )
  else
    declare -a THREAD_ARGS=(
      "--device=local-task"
      "--task_topology_max_group_count=${THREADS}"
    )
  fi

  # The .run_flag file holds one benchmark flag per line; read it into an
  # array without unquoted word-splitting.
  readarray -t RUN_ARGS < "${MODEL_PATH}.run_flag"

  if (( TRACE_MODE == 1 )); then
    # Start the Tracy capture in the background; a single repetition is
    # enough when tracing.
    "${IREE_TRACY}" -f -o "${PREFIX}${MODEL_FILE}".tracy >/dev/null &
    REPETITIONS=1
  else
    REPETITIONS=5
  fi

  TRACY_NO_EXIT="${TRACE_MODE}" numactl --cpubind=0 --membind=0 -- \
    "${IREE_BENCHMARK_MODULE}" \
    --device_allocator=caching \
    --benchmark_repetitions="${REPETITIONS}" \
    --module="${PREFIX}${MODEL_FILE}.vmfb" \
    "${THREAD_ARGS[@]}" \
    "${RUN_ARGS[@]}"

  # Let the background Tracy capture (if any) finish writing its file.
  wait
done
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,12 @@ static FailureOr<SmallVector<Value>> lowerUpperBoundTileSizeOpToConstants(
results[innerDimsPos[i]] =
rewriter.create<arith::ConstantIndexOp>(loc, tileSize);
}
// For the dims that have no inner tiles, use 1 as tile size to avoid padding.
for (unsigned i = 0; i < results.size(); ++i) {
if (!results[i]) {
results[i] = rewriter.create<arith::ConstantIndexOp>(loc, 1);
}
}
return results;
}

Expand Down
Loading