diff --git a/aie_runtime_lib/AIE/tanh.cpp b/aie_runtime_lib/AIE/tanh.cpp new file mode 100644 index 0000000000..61e3653b1a --- /dev/null +++ b/aie_runtime_lib/AIE/tanh.cpp @@ -0,0 +1,12 @@ +//===--- tanh.cpp - tanh lookup tables ---===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// (c) Copyright 2023 Advanced Micro Devices, Inc. +// +// +//===----------------------------------------------------------------------===// +// These are tanh lookup tables for bfloat16 type +//===----------------------------------------------------------------------===// diff --git a/aie_runtime_lib/AIE/tanh.h b/aie_runtime_lib/AIE/tanh.h new file mode 100644 index 0000000000..aa8c858156 --- /dev/null +++ b/aie_runtime_lib/AIE/tanh.h @@ -0,0 +1,19 @@ +//===- tanh.h - get hyperbolic tangent values based on linear approximation +//-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// (c) Copyright 2023 Advanced Micro Devices, Inc. +// +// +//===----------------------------------------------------------------------===// +// This is the implementation for computing hyperbolic tangent values based on +// linear approximation +//===----------------------------------------------------------------------===// + +#ifndef __TANH__ +#define __TANH__ + +#endif //__TANH__ diff --git a/aie_runtime_lib/AIE2/tanh.cpp b/aie_runtime_lib/AIE2/tanh.cpp new file mode 100644 index 0000000000..e2c600ff67 --- /dev/null +++ b/aie_runtime_lib/AIE2/tanh.cpp @@ -0,0 +1,148 @@ +//===--- tanh.cpp - tanh lookup tables ---===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// (c) Copyright 2023 Advanced Micro Devices, Inc. +// +// +//===----------------------------------------------------------------------===// +// These are tanh lookup tables for bfloat16 type +//===----------------------------------------------------------------------===// + +// Divides into 32 segments between [-4,4], bank size: (32*2*2*4)*2=1k, one +// lut=512B +float chess_storage(% chess_alignof(v32int8)) tanh_lut_ab[128] = { + 0.00000000000000000000000000000000, -1.00000000000000000000000000000000, + 0.00283813476562500000000000000000, -0.98828125000000000000000000000000, + 0.00000000000000000000000000000000, -1.00000000000000000000000000000000, + 0.00283813476562500000000000000000, -0.98828125000000000000000000000000, + 0.00509643554687500000000000000000, -0.98046875000000000000000000000000, + 0.00750732421875000000000000000000, -0.97265625000000000000000000000000, + 0.00509643554687500000000000000000, -0.98046875000000000000000000000000, + 0.00750732421875000000000000000000, -0.97265625000000000000000000000000, + 0.01269531250000000000000000000000, -0.95703125000000000000000000000000, + 0.02124023437500000000000000000000, -0.93359375000000000000000000000000, + 0.01269531250000000000000000000000, -0.95703125000000000000000000000000, + 0.02124023437500000000000000000000, -0.93359375000000000000000000000000, + 0.03540039062500000000000000000000, -0.89843750000000000000000000000000, + 0.05639648437500000000000000000000, -0.85156250000000000000000000000000, + 0.03540039062500000000000000000000, -0.89843750000000000000000000000000, + 0.05639648437500000000000000000000, -0.85156250000000000000000000000000, + 0.09179687500000000000000000000000, -0.78125000000000000000000000000000, + 0.14550781250000000000000000000000, -0.68750000000000000000000000000000, + 0.09179687500000000000000000000000, -0.78125000000000000000000000000000, + 0.14550781250000000000000000000000, -0.68750000000000000000000000000000, 
+ 0.22949218750000000000000000000000, -0.56250000000000000000000000000000, + 0.34765625000000000000000000000000, -0.41601562500000000000000000000000, + 0.22949218750000000000000000000000, -0.56250000000000000000000000000000, + 0.34765625000000000000000000000000, -0.41601562500000000000000000000000, + 0.50390625000000000000000000000000, -0.25976562500000000000000000000000, + 0.69140625000000000000000000000000, -0.11962890625000000000000000000000, + 0.50390625000000000000000000000000, -0.25976562500000000000000000000000, + 0.69140625000000000000000000000000, -0.11962890625000000000000000000000, + 0.86718750000000000000000000000000, -0.03076171875000000000000000000000, + 1.00000000000000000000000000000000, 0.00000000000000000000000000000000, + 0.86718750000000000000000000000000, -0.03076171875000000000000000000000, + 1.00000000000000000000000000000000, 0.00000000000000000000000000000000, + 1.00000000000000000000000000000000, 0.00000000000000000000000000000000, + 0.86718750000000000000000000000000, 0.03076171875000000000000000000000, + 1.00000000000000000000000000000000, 0.00000000000000000000000000000000, + 0.86718750000000000000000000000000, 0.03076171875000000000000000000000, + 0.69140625000000000000000000000000, 0.11962890625000000000000000000000, + 0.50390625000000000000000000000000, 0.25976562500000000000000000000000, + 0.69140625000000000000000000000000, 0.11962890625000000000000000000000, + 0.50390625000000000000000000000000, 0.25976562500000000000000000000000, + 0.34765625000000000000000000000000, 0.41601562500000000000000000000000, + 0.22949218750000000000000000000000, 0.56250000000000000000000000000000, + 0.34765625000000000000000000000000, 0.41601562500000000000000000000000, + 0.22949218750000000000000000000000, 0.56250000000000000000000000000000, + 0.14550781250000000000000000000000, 0.68750000000000000000000000000000, + 0.09179687500000000000000000000000, 0.78125000000000000000000000000000, + 0.14550781250000000000000000000000, 
0.68750000000000000000000000000000, + 0.09179687500000000000000000000000, 0.78125000000000000000000000000000, + 0.05639648437500000000000000000000, 0.85156250000000000000000000000000, + 0.03540039062500000000000000000000, 0.89843750000000000000000000000000, + 0.05639648437500000000000000000000, 0.85156250000000000000000000000000, + 0.03540039062500000000000000000000, 0.89843750000000000000000000000000, + 0.02124023437500000000000000000000, 0.93359375000000000000000000000000, + 0.01269531250000000000000000000000, 0.95703125000000000000000000000000, + 0.02124023437500000000000000000000, 0.93359375000000000000000000000000, + 0.01269531250000000000000000000000, 0.95703125000000000000000000000000, + 0.00750732421875000000000000000000, 0.97265625000000000000000000000000, + 0.00509643554687500000000000000000, 0.98046875000000000000000000000000, + 0.00750732421875000000000000000000, 0.97265625000000000000000000000000, + 0.00509643554687500000000000000000, 0.98046875000000000000000000000000, + 0.00283813476562500000000000000000, 0.98828125000000000000000000000000, + 0.00000000000000000000000000000000, 1.00000000000000000000000000000000, + 0.00283813476562500000000000000000, 0.98828125000000000000000000000000, + 0.00000000000000000000000000000000, 1.00000000000000000000000000000000, +}; + +float chess_storage(% chess_alignof(v32int8)) tanh_lut_cd[128] = { + 0.00000000000000000000000000000000, -1.00000000000000000000000000000000, + 0.00283813476562500000000000000000, -0.98828125000000000000000000000000, + 0.00000000000000000000000000000000, -1.00000000000000000000000000000000, + 0.00283813476562500000000000000000, -0.98828125000000000000000000000000, + 0.00509643554687500000000000000000, -0.98046875000000000000000000000000, + 0.00750732421875000000000000000000, -0.97265625000000000000000000000000, + 0.00509643554687500000000000000000, -0.98046875000000000000000000000000, + 0.00750732421875000000000000000000, -0.97265625000000000000000000000000, + 
0.01269531250000000000000000000000, -0.95703125000000000000000000000000, + 0.02124023437500000000000000000000, -0.93359375000000000000000000000000, + 0.01269531250000000000000000000000, -0.95703125000000000000000000000000, + 0.02124023437500000000000000000000, -0.93359375000000000000000000000000, + 0.03540039062500000000000000000000, -0.89843750000000000000000000000000, + 0.05639648437500000000000000000000, -0.85156250000000000000000000000000, + 0.03540039062500000000000000000000, -0.89843750000000000000000000000000, + 0.05639648437500000000000000000000, -0.85156250000000000000000000000000, + 0.09179687500000000000000000000000, -0.78125000000000000000000000000000, + 0.14550781250000000000000000000000, -0.68750000000000000000000000000000, + 0.09179687500000000000000000000000, -0.78125000000000000000000000000000, + 0.14550781250000000000000000000000, -0.68750000000000000000000000000000, + 0.22949218750000000000000000000000, -0.56250000000000000000000000000000, + 0.34765625000000000000000000000000, -0.41601562500000000000000000000000, + 0.22949218750000000000000000000000, -0.56250000000000000000000000000000, + 0.34765625000000000000000000000000, -0.41601562500000000000000000000000, + 0.50390625000000000000000000000000, -0.25976562500000000000000000000000, + 0.69140625000000000000000000000000, -0.11962890625000000000000000000000, + 0.50390625000000000000000000000000, -0.25976562500000000000000000000000, + 0.69140625000000000000000000000000, -0.11962890625000000000000000000000, + 0.86718750000000000000000000000000, -0.03076171875000000000000000000000, + 1.00000000000000000000000000000000, 0.00000000000000000000000000000000, + 0.86718750000000000000000000000000, -0.03076171875000000000000000000000, + 1.00000000000000000000000000000000, 0.00000000000000000000000000000000, + 1.00000000000000000000000000000000, 0.00000000000000000000000000000000, + 0.86718750000000000000000000000000, 0.03076171875000000000000000000000, + 1.00000000000000000000000000000000, 
0.00000000000000000000000000000000, + 0.86718750000000000000000000000000, 0.03076171875000000000000000000000, + 0.69140625000000000000000000000000, 0.11962890625000000000000000000000, + 0.50390625000000000000000000000000, 0.25976562500000000000000000000000, + 0.69140625000000000000000000000000, 0.11962890625000000000000000000000, + 0.50390625000000000000000000000000, 0.25976562500000000000000000000000, + 0.34765625000000000000000000000000, 0.41601562500000000000000000000000, + 0.22949218750000000000000000000000, 0.56250000000000000000000000000000, + 0.34765625000000000000000000000000, 0.41601562500000000000000000000000, + 0.22949218750000000000000000000000, 0.56250000000000000000000000000000, + 0.14550781250000000000000000000000, 0.68750000000000000000000000000000, + 0.09179687500000000000000000000000, 0.78125000000000000000000000000000, + 0.14550781250000000000000000000000, 0.68750000000000000000000000000000, + 0.09179687500000000000000000000000, 0.78125000000000000000000000000000, + 0.05639648437500000000000000000000, 0.85156250000000000000000000000000, + 0.03540039062500000000000000000000, 0.89843750000000000000000000000000, + 0.05639648437500000000000000000000, 0.85156250000000000000000000000000, + 0.03540039062500000000000000000000, 0.89843750000000000000000000000000, + 0.02124023437500000000000000000000, 0.93359375000000000000000000000000, + 0.01269531250000000000000000000000, 0.95703125000000000000000000000000, + 0.02124023437500000000000000000000, 0.93359375000000000000000000000000, + 0.01269531250000000000000000000000, 0.95703125000000000000000000000000, + 0.00750732421875000000000000000000, 0.97265625000000000000000000000000, + 0.00509643554687500000000000000000, 0.98046875000000000000000000000000, + 0.00750732421875000000000000000000, 0.97265625000000000000000000000000, + 0.00509643554687500000000000000000, 0.98046875000000000000000000000000, + 0.00283813476562500000000000000000, 0.98828125000000000000000000000000, + 0.00000000000000000000000000000000, 
1.00000000000000000000000000000000, + 0.00283813476562500000000000000000, 0.98828125000000000000000000000000, + 0.00000000000000000000000000000000, 1.00000000000000000000000000000000, +}; diff --git a/aie_runtime_lib/AIE2/tanh.h b/aie_runtime_lib/AIE2/tanh.h new file mode 100644 index 0000000000..c956010cea --- /dev/null +++ b/aie_runtime_lib/AIE2/tanh.h @@ -0,0 +1,54 @@ +//===- tanh.h - get hyperbolic tangent values based on linear approximation +//-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// (c) Copyright 2023 Advanced Micro Devices, Inc. +// +// +//===----------------------------------------------------------------------===// +// This is the implementation for computing hyperbolic tangent values based on +// linear approximation +//===----------------------------------------------------------------------===// + +#ifndef __TANH__ +#define __TANH__ + +#include "aie_api/aie.hpp" +#include +#include + +extern float tanh_lut_ab[]; +extern float tanh_lut_cd[]; + +// Computes tanh for each of 16 bfloat16 lanes with aie::linear_approx over +// the tanh_lut_ab/tanh_lut_cd tables. Declared inline because the definition +// sits in a header: a second including source would otherwise redefine it. +inline v16bfloat16 __attribute__((always_inline)) +getTanhBf16(v16bfloat16 vInput) { + aie::vector<bfloat16, 16> input = vInput; + + // tanh.cpp builds the tables from 32 segments over [-4, 4]. NOTE(review): + // step_bits = -2 with bias = 16 presumably indexes them in 0.25 steps. + int step_bits = -2; + int bias = 16; + int LUT_elems = 32; + int shift_offset = 0; // unused + + using lut_type = aie::lut<4, float, bfloat16>; + + lut_type test_lut(LUT_elems, (bfloat16 *)tanh_lut_ab, + (bfloat16 *)tanh_lut_cd); + + aie::linear_approx<bfloat16, lut_type> lin_aprox(test_lut, step_bits, bias, + shift_offset); + + aie::vector<bfloat16, 16> output = + lin_aprox.compute(input).to_vector<bfloat16>(); + + return (v16bfloat16)output; +} + +#endif //__TANH__ diff --git a/aie_runtime_lib/CMakeLists.txt b/aie_runtime_lib/CMakeLists.txt index 17c60d2105..739ed5f6a9 100644 --- a/aie_runtime_lib/CMakeLists.txt +++ b/aie_runtime_lib/CMakeLists.txt @@ -34,7 +34,9 @@ function(add_aie_runtime_libs arch) set(INSTALLS chess_intrinsic_wrapper.cpp 
lut_based_ops.cpp - lut_based_ops.h) + lut_based_ops.h + tanh.h + tanh.cpp) foreach(file ${INSTALLS}) add_custom_target(aie-copy-${arch}-runtime-libs-${file} ALL DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/${file}) diff --git a/lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp b/lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp index 83a8d07681..6e77665372 100644 --- a/lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp +++ b/lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp @@ -1791,14 +1791,54 @@ struct ComputeInvOpByLUTPattern : public OpConversionPattern { arith::TruncFOp truncOp = cast(*divOp->getUsers().begin()); rewriter.setInsertionPoint(truncOp); - auto funcOp = rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( truncOp, TypeRange{truncOp.getResult().getType()}, "getInvBf16", nullptr, nullptr, invOperands); rewriter.eraseOp(divOp); - moduleOp = funcOp->getParentOfType(); return success(); } }; + +// Convert math.tanh to a function call to compute tanh(x) by look up tables. +// Only 16-lane vectors with 16-bit float elements are rewritten, into a call +// to getTanhBf16 (tanh.h). NOTE(review): the width test also admits f16 - +// confirm that only bf16 vectors can reach this pattern. +struct ComputeTanhOpByLUTPattern : public OpConversionPattern<math::TanhOp> { + using OpConversionPattern<math::TanhOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(math::TanhOp tanhOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + VectorType srcType = dyn_cast<VectorType>(tanhOp.getOperand().getType()); + // Test srcType before using it: dyn_cast yields null for non-vector + // operands, and a null VectorType has no element type to query. + if (!srcType || !isa<FloatType>(srcType.getElementType())) { + return failure(); + } + Type scalarType = srcType.getElementType(); + + unsigned laneSize = getVectorLaneSize(srcType); + unsigned elWidth = scalarType.getIntOrFloatBitWidth(); + + if (elWidth != 16 || laneSize != 16) { + return failure(); + } + + StringRef includeName = "tanh.h"; + ModuleOp moduleOp = tanhOp->getParentOfType<ModuleOp>(); + rewriter.setInsertionPointToStart( + &moduleOp.getRegion().getBlocks().front()); + rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName, false); + + rewriter.setInsertionPoint(tanhOp); + SmallVector<Value> tanhOperands = {adaptor.getOperand()}; + rewriter.replaceOpWithNewOp<emitc::CallOp>( + tanhOp, 
TypeRange{tanhOp.getResult().getType()}, "getTanhBf16", nullptr, + nullptr, tanhOperands); + return success(); + } +}; + //===----------------------------------------------------------------------===// // Pattern collection //===----------------------------------------------------------------------===// @@ -1830,6 +1870,7 @@ static void populateAIEVecV2ConversionPatterns(RewritePatternSet &patterns, LowerVectorSubIOpToAIEVecSubElemOp, ComputeExpOpByLUTPattern, ComputeInvOpByLUTPattern, + ComputeTanhOpByLUTPattern, ConvertMulIToAIEVecMulElemOpPattern, LowerVectorAddFOpToAIEVecAddElemOp, LowerVectorSubFOpToAIEVecSubElemOp, @@ -1903,6 +1944,26 @@ static void configureAIEVecCommonLegalizations(ConversionTarget &target, return false; }); + // math.tanh must be converted only when it is a 16-lane vector of 16-bit + // floats, the one shape ComputeTanhOpByLUTPattern rewrites. + target.addDynamicallyLegalOp<math::TanhOp>([](math::TanhOp tanhOp) { + VectorType srcType = dyn_cast<VectorType>(tanhOp.getOperand().getType()); + // Scalar math.tanh ops reach this callback as well, so test srcType for + // null before asking for its element type. + if (!srcType || !isa<FloatType>(srcType.getElementType())) { + return true; + } + Type scalarType = srcType.getElementType(); + + unsigned laneSize = getVectorLaneSize(srcType); + unsigned elWidth = scalarType.getIntOrFloatBitWidth(); + if (elWidth != 16 || laneSize != 16) { + return true; + } + + return false; + }); + target.addDynamicallyLegalOp( [](arith::AddIOp op) { return !isa(op.getType()); }); target.addDynamicallyLegalOp( diff --git a/lib/Targets/AIEVecToCpp/TranslateAIEVecToCpp.cpp b/lib/Targets/AIEVecToCpp/TranslateAIEVecToCpp.cpp index b3520e4c9c..2da75ca84c 100644 --- a/lib/Targets/AIEVecToCpp/TranslateAIEVecToCpp.cpp +++ b/lib/Targets/AIEVecToCpp/TranslateAIEVecToCpp.cpp @@ -2045,9 +2045,13 @@ static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) { static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) { raw_ostream &os = emitter.ostream(); Operation &op = *callOp.getOperation(); - - if (failed(emitter.emitAssignPrefix(op, /*isAcc*/ true))) - return failure(); + if (callOp.getCallee() == "getTanhBf16") { + if (failed(emitter.emitAssignPrefix(op, /*isAcc*/ false))) + return 
failure(); + } else { + if (failed(emitter.emitAssignPrefix(op, /*isAcc*/ true))) + return failure(); + } os << callOp.getCallee(); auto emitArgs = [&](Attribute attr) -> LogicalResult { diff --git a/test/unit_tests/aievec_tests/bf16_tanh/bf16_tanh.mlir b/test/unit_tests/aievec_tests/bf16_tanh/bf16_tanh.mlir new file mode 100644 index 0000000000..8ec8ef72e0 --- /dev/null +++ b/test/unit_tests/aievec_tests/bf16_tanh/bf16_tanh.mlir @@ -0,0 +1,16 @@ +// REQUIRES: valid_xchess_license +// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg))" -o linalg.mlir +// RUN: mlir-opt linalg.mlir --linalg-fuse-elementwise-ops --eliminate-empty-tensors --empty-tensor-to-alloc-tensor --one-shot-bufferize="allow-return-allocs allow-unknown-ops bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map" --drop-equivalent-buffer-results --buffer-results-to-out-params --buffer-deallocation --canonicalize --cse --convert-linalg-to-affine-loops --affine-super-vectorize="virtual-vector-size=16" -o affine.mlir +// RUN: aie-opt affine.mlir --convert-vector-to-aievec="aie-target=aieml" -lower-affine -o aievec.mlir +// RUN: aie-translate aievec.mlir -aieml=true --aievec-to-cpp -o dut.cc +// RUN: xchesscc_wrapper aie2 -f -g +s +w work +o work -I%S -I%aie_runtime_lib%/AIE2 %aie_runtime_lib%/AIE2/tanh.cpp -I %aietools/include -D__AIEARCH__=20 -D__AIENGINE__ -I. 
%S/testbench.cc dut.cc +// RUN: mkdir -p data +// RUN: xca_udm_dbg --aiearch aie-ml -qf -T -P %aietools/data/aie_ml/lib/ -t "%S/../profiling.tcl ./work/a.out" >& xca_udm_dbg.stdout +// RUN: FileCheck --input-file=./xca_udm_dbg.stdout %s +// CHECK: TEST PASSED +// Cycle count: 807 + +func.func @dut(%arg0: tensor<1024xbf16>) -> (tensor<1024xbf16>) { + %0 = "tosa.tanh"(%arg0) : (tensor<1024xbf16>) -> tensor<1024xbf16> + return %0 : tensor<1024xbf16> +} diff --git a/test/unit_tests/aievec_tests/bf16_tanh/defines.h b/test/unit_tests/aievec_tests/bf16_tanh/defines.h new file mode 100644 index 0000000000..3c6fc96a69 --- /dev/null +++ b/test/unit_tests/aievec_tests/bf16_tanh/defines.h @@ -0,0 +1,4 @@ +#pragma once +// Element counts of the test's input and output buffers. +constexpr unsigned const IN0_SIZE = 1024; +constexpr unsigned const OUT0_SIZE = 1024; diff --git a/test/unit_tests/aievec_tests/bf16_tanh/dut.cc b/test/unit_tests/aievec_tests/bf16_tanh/dut.cc new file mode 100644 index 0000000000..1df19232b2 --- /dev/null +++ b/test/unit_tests/aievec_tests/bf16_tanh/dut.cc @@ -0,0 +1,15 @@ +#include "tanh.h" +// Applies getTanhBf16 to the 1024 bfloat16 values at v1, 16 lanes per loop +// iteration, and stores each result vector at the same offset in v2. +void dut(bfloat16 *restrict v1, bfloat16 *restrict v2) { + size_t v3 = 0; + size_t v4 = 1024; + size_t v5 = 16; + for (size_t v6 = v3; v6 < v4; v6 += v5) + chess_prepare_for_pipelining chess_loop_range(64, 64) { + v16bfloat16 v7 = *(v16bfloat16 *)(v1 + v6); + v16bfloat16 v8 = getTanhBf16(v7); + *(v16bfloat16 *)(v2 + v6) = v8; + } + return; +} diff --git a/test/unit_tests/aievec_tests/bf16_tanh/testbench.cc b/test/unit_tests/aievec_tests/bf16_tanh/testbench.cc new file mode 100644 index 0000000000..3069b2bd4e --- /dev/null +++ b/test/unit_tests/aievec_tests/bf16_tanh/testbench.cc @@ -0,0 +1,62 @@ +#include "../common/testbench.h" +#include "defines.h" +#include +#include +#include +#include + +// Kernel under test and the scalar reference it is checked against. +void dut(bfloat16 *restrict in0, bfloat16 *restrict out0); +void dut_ref(bfloat16 *in0, bfloat16 *out0); + +// 32-byte aligned buffers: input, dut() output, dut_ref() output. +alignas(32) bfloat16 g_in0[IN0_SIZE]; +alignas(32) bfloat16 g_out0[OUT0_SIZE]; +alignas(32) bfloat16 g_out0Ref[OUT0_SIZE]; + 
+int main(int argc, char *argv[]) { + std::string dataDir(TO_STR(DATA_DIR)); + srand(10); + std::generate(g_in0, g_in0 + IN0_SIZE, + [&]() { return random_bfloat16(-4, 4, 3); }); + + writeData(g_in0, IN0_SIZE, dataDir + "/in0.txt"); + + chess_memory_fence(); + auto cyclesBegin = chess_cycle_count(); + dut(g_in0, g_out0); + auto cyclesEnd = chess_cycle_count(); + chess_memory_fence(); + + auto cycleCount = (int)(cyclesEnd - cyclesBegin); + reportCycleCount(cycleCount, dataDir + "/cycle_count.txt"); + + writeData(g_out0, OUT0_SIZE, dataDir + "/out0.txt"); + chess_memory_fence(); // fence before timing, as for the dut() measurement + cyclesBegin = chess_cycle_count(); + dut_ref(g_in0, g_out0Ref); + cyclesEnd = chess_cycle_count(); + chess_memory_fence(); + cycleCount = (int)(cyclesEnd - cyclesBegin); + // NOTE(review): same path as the dut() count above - confirm this is meant. + reportCycleCount(cycleCount, dataDir + "/cycle_count.txt"); + writeData(g_out0Ref, OUT0_SIZE, dataDir + "/out0_ref.txt"); + + bool ok = true; + ok &= checkData(g_out0, g_out0Ref, OUT0_SIZE, 0, 1e-2, 1e-2); + + if (ok) + printf("TEST PASSED\n"); + else + printf("TEST FAILED\n"); + + return ok ? 0 : 1; +} + +// Scalar reference: tanh of every input element, rounded back to bfloat16. +void dut_ref(bfloat16 *in0, bfloat16 *out0) { + for (unsigned k = 0; k < OUT0_SIZE; k += 1) { + float out = tanh((float)in0[k]); + out0[k] = (bfloat16)out; + } +}