Skip to content

Commit

Permalink
Implement math.exp() using lookup tables (#577)
Browse files Browse the repository at this point in the history
  • Loading branch information
linay-xsj authored Aug 14, 2023
1 parent 183675d commit d4f4b2d
Show file tree
Hide file tree
Showing 9 changed files with 404 additions and 5 deletions.
3 changes: 2 additions & 1 deletion lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "aie/Dialect/AIEVec/IR/AIEVecOps.h"
#include "aie/Dialect/AIEVec/Pipelines/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
Expand Down Expand Up @@ -2016,7 +2017,7 @@ struct LowerVectorToAIEVec
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect, xilinx::aievec::AIEVecDialect,
arith::ArithDialect, memref::MemRefDialect, scf::SCFDialect,
vector::VectorDialect>();
vector::VectorDialect, emitc::EmitCDialect>();
}

Option<std::string> aieTarget{
Expand Down
58 changes: 58 additions & 0 deletions lib/Dialect/AIEVec/Transforms/VectorToVectorConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // ???
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" // ???
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" // ???
#include "mlir/Dialect/SCF/IR/SCF.h" // ???
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" // ???
Expand Down Expand Up @@ -213,6 +215,48 @@ struct ConvertSplatTransferReadToBroadcastPattern
}
};

//============================================================================//
//============ AIEML canonicalization conversion patterns ===============//
//============================================================================//

struct ComputeExpOpByLUTPattern : public OpConversionPattern<math::ExpOp> {
using OpConversionPattern<math::ExpOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(math::ExpOp expOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
VectorType srcType = dyn_cast<VectorType>(adaptor.getOperand().getType());

if (!srcType) {
return failure();
}

Type scalarType = srcType.getElementType();
unsigned elWidth = scalarType.getIntOrFloatBitWidth();
unsigned laneSize = getVectorLaneSize(srcType);
if (!isa<FloatType>(scalarType) || laneSize != 16 || elWidth != 16)
return failure();

StringRef includeName = "exp_lut.h";
ModuleOp moduleOp = expOp->getParentOfType<mlir::ModuleOp>();
rewriter.setInsertionPointToStart(
&moduleOp.getRegion().getBlocks().front());
rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName, false);

SmallVector<Value> expOperands = {adaptor.getOperand()};

rewriter.setInsertionPoint(expOp);
Type accType = getVectorOpDestType(srcType, /*AIEML =*/true);
auto funcOp = rewriter.create<emitc::CallOp>(
expOp.getLoc(), TypeRange{accType}, "getExpBf16", nullptr, nullptr,
expOperands);
rewriter.replaceOpWithNewOp<aievec::SRSOp>(expOp, srcType,
funcOp.getResult(0));

return success();
}
};

//============================================================================//
//================ Common AIE canonicalization configuration =================//
//============================================================================//
Expand Down Expand Up @@ -253,6 +297,19 @@ populateAIEv1CanonicalizeConversionPatterns(RewritePatternSet &patterns) {
//============================================================================//

static void configureAIEMLCanonicalizeLegalizations(ConversionTarget &target) {
target.addLegalDialect<emitc::EmitCDialect>();
target.addDynamicallyLegalOp<math::ExpOp>([](math::ExpOp expOp) {
VectorType srcType = dyn_cast<VectorType>(expOp.getOperand().getType());
if (!srcType) {
return true;
}
Type scalarType = srcType.getElementType();
unsigned elWidth = scalarType.getIntOrFloatBitWidth();
unsigned laneSize = getVectorLaneSize(srcType);
if (!isa<FloatType>(scalarType) || laneSize != 16 || elWidth != 16)
return true;
return false;
});
target.addDynamicallyLegalOp<vector::TransferReadOp>(
[](vector::TransferReadOp op) {
return !op.getPermutationMap().isConstant() &&
Expand All @@ -262,6 +319,7 @@ static void configureAIEMLCanonicalizeLegalizations(ConversionTarget &target) {

static void
populateAIEMLCanonicalizeConversionPatterns(RewritePatternSet &patterns) {
patterns.add<ComputeExpOpByLUTPattern>(patterns.getContext());
patterns.add<SplitUnalignedTransferReadPattern>(patterns.getContext(), 128,
1024, 256);
}
Expand Down
2 changes: 2 additions & 0 deletions lib/Targets/AIETargets.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/Attributes.h"
Expand Down Expand Up @@ -84,6 +85,7 @@ static void registerDialects(DialectRegistry &registry) {
registry.insert<memref::MemRefDialect>();
registry.insert<VectorDialect>();
registry.insert<LLVM::LLVMDialect>();
registry.insert<emitc::EmitCDialect>();
}

// Output the buffer map for the given buffer operations, with the given offset.
Expand Down
8 changes: 4 additions & 4 deletions lib/Targets/AIEVecToCpp/TranslateAIEVecToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2008,7 +2008,7 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) {
raw_ostream &os = emitter.ostream();
Operation &op = *callOp.getOperation();

if (failed(emitter.emitAssignPrefix(op)))
if (failed(emitter.emitAssignPrefix(op, /*isAcc*/ true)))
return failure();
os << callOp.getCallee();

Expand Down Expand Up @@ -2070,10 +2070,10 @@ static LogicalResult printOperation(CppEmitter &emitter,
raw_ostream &os = emitter.ostream();

os << "#include ";
if (includeOp.getIsStandardIncludeAttrName())
os << "<" << includeOp.getIncludeAttrName() << ">";
if (includeOp.getIsStandardInclude())
os << "<" << includeOp.getInclude() << ">";
else
os << "\"" << includeOp.getIncludeAttrName() << "\"";
os << "\"" << includeOp.getInclude() << "\"";

return success();
}
Expand Down
21 changes: 21 additions & 0 deletions test/unit_tests/aievec_tests/bf16_exp_lut/bf16_exp_lut.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Copyright (C) 2023, Advanced Micro Devices, Inc.

// REQUIRES: valid_xchess_license
// RUN: aie-opt %s -affine-super-vectorize="virtual-vector-size=16" --convert-vector-to-aievec="aie-target=aieml" -lower-affine | aie-translate -aieml=true --aievec-to-cpp -o dut.cc
// RUN: xchesscc_wrapper aie2 -f -g +s +w work +o work -I%S -I %aietools/include -D__AIEARCH__=20 -D__AIENGINE__ -I. %S/testbench.cc dut.cc
// RUN: mkdir -p data
// RUN: xca_udm_dbg --aiearch aie-ml -qf -T -P %aietools/data/aie_ml/lib/ -t "%S/../profiling.tcl ./work/a.out" >& xca_udm_dbg.stdout
// RUN: FileCheck --input-file=./xca_udm_dbg.stdout %s
// CHECK: TEST PASSED

module {
func.func @dut(%arg0: memref<1024xbf16>, %arg1: memref<1024xbf16>) {
affine.for %arg3 = 0 to 1024 {
%0 = affine.load %arg0[%arg3] : memref<1024xbf16>
%1 = math.exp %0 : bf16
affine.store %1, %arg1[%arg3] : memref<1024xbf16>
}
return
}
}
3 changes: 3 additions & 0 deletions test/unit_tests/aievec_tests/bf16_exp_lut/defines.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#pragma once
constexpr unsigned const IN0_SIZE = 1024;
constexpr unsigned const OUT0_SIZE = 1024;
14 changes: 14 additions & 0 deletions test/unit_tests/aievec_tests/bf16_exp_lut/dut.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#include "exp_lut.h"
void dut(bfloat16 *restrict v1, bfloat16 *restrict v2) {
size_t v3 = 0;
size_t v4 = 1024;
size_t v5 = 16;
for (size_t v6 = v3; v6 < v4; v6 += v5)
chess_prepare_for_pipelining chess_loop_range(64, 64) {
v16bfloat16 v7 = *(v16bfloat16 *)(v1 + v6);
v16accfloat v8 = getExpBf16(v7);
v16bfloat16 v9 = to_v16bfloat16(v8);
*(v16bfloat16 *)(v2 + v6) = v9;
}
return;
}
Loading

0 comments on commit d4f4b2d

Please sign in to comment.