Skip to content

Commit

Permalink
Update CPU reference
Browse files Browse the repository at this point in the history
1. Revert the default threshold of relative difference to (100 * std::numeric_limits<T>::epsilon())
2. Update the CPU reference so that the difference between the CPU reference and the output of the contraction instance
is less than (100 * std::numeric_limits<T>::epsilon()).
  • Loading branch information
CongMa13 committed Dec 8, 2023
1 parent fec9065 commit b21fe0b
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 14 deletions.
29 changes: 21 additions & 8 deletions library/src/contraction/contraction_cpu_reference_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,25 @@
namespace hiptensor
{
// hardcoded for NumDimM == NumDimN == NumDimK == 2
//
// ck::bhalf_t is ushort, so bhalf_t * bhalf_t cannot be performed directly.
// CK does not use ck::bhalf_t as AccDataType, but we still
// add this guard here.
template <
ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename AccDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename ComputeDataType = ADataType,
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2 && DsDataType::Size() <= 1,
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2 && DsDataType::Size() <= 1
&& !std::is_same_v<AccDataType, ck::bhalf_t>,
bool>
= false>
struct ReferenceContraction_M2_N2_K2
Expand Down Expand Up @@ -151,7 +157,7 @@ namespace hiptensor
};

auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
float accum = 0.0f;
AccDataType accum = 0;

auto K0 = arg.mA_ms_ks_lengths[2];
auto K1 = arg.mA_ms_ks_lengths[3];
Expand All @@ -165,16 +171,19 @@ namespace hiptensor
auto indexB
= offset(std::vector<size_t>{n0, n1, k0, k1}, arg.mB_ns_ks_strides);

ADataType valA;
BDataType valB;
AccDataType valA;
AccDataType valB;

// Element-wise ops
arg.mOpA(valA, ((ADataType*)arg.mA)[indexA]);
arg.mOpB(valB, ((BDataType*)arg.mB)[indexB]);
arg.mOpA(
valA,
ck::type_convert<ComputeDataType>(((ADataType*)arg.mA)[indexA]));
arg.mOpB(
valB,
ck::type_convert<ComputeDataType>(((BDataType*)arg.mB)[indexB]));

// Mult / accum
accum += ck::type_convert<float>(ck::type_convert<ComputeDataType>(
ck::type_convert<float>(valA) * ck::type_convert<float>(valB)));
accum += valA * valB;
}
}

Expand Down Expand Up @@ -322,6 +331,7 @@ namespace hiptensor
ck::index_t NumDimsK,
typename ADataType,
typename BDataType,
typename AccDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
Expand All @@ -333,6 +343,7 @@ namespace hiptensor
NumDimsK,
ADataType,
BDataType,
AccDataType,
DsDataType,
EDataType,
AElementwiseOperation,
Expand All @@ -359,6 +370,7 @@ namespace hiptensor
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename AccDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
Expand All @@ -372,6 +384,7 @@ namespace hiptensor
NumDimK,
ADataType,
BDataType,
AccDataType,
DsDataType,
EDataType,
AElementwiseOperation,
Expand Down
14 changes: 14 additions & 0 deletions library/src/contraction/contraction_cpu_reference_instances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ namespace hiptensor
2,
ck::half_t,
ck::half_t,
float,
ck::Tuple<ck::half_t>,
ck::half_t,
ck::tensor_operation::element_wise::PassThrough,
Expand All @@ -53,6 +54,7 @@ namespace hiptensor
2,
ck::bhalf_t,
ck::bhalf_t,
float,
ck::Tuple<ck::bhalf_t>,
ck::bhalf_t,
ck::tensor_operation::element_wise::PassThrough,
Expand All @@ -67,6 +69,7 @@ namespace hiptensor
2,
float,
float,
float,
ck::Tuple<float>,
float,
ck::tensor_operation::element_wise::PassThrough,
Expand All @@ -80,6 +83,7 @@ namespace hiptensor
2,
float,
float,
float,
ck::Tuple<float>,
float,
ck::tensor_operation::element_wise::PassThrough,
Expand All @@ -93,6 +97,7 @@ namespace hiptensor
2,
float,
float,
float,
ck::Tuple<float>,
float,
ck::tensor_operation::element_wise::PassThrough,
Expand All @@ -107,6 +112,7 @@ namespace hiptensor
2,
double,
double,
float,
ck::Tuple<double>,
double,
ck::tensor_operation::element_wise::PassThrough,
Expand All @@ -120,6 +126,7 @@ namespace hiptensor
2,
double,
double,
double,
ck::Tuple<double>,
double,
ck::tensor_operation::element_wise::PassThrough,
Expand All @@ -134,6 +141,7 @@ namespace hiptensor
2,
ck::half_t,
ck::half_t,
float,
ck::Tuple<>,
ck::half_t,
ck::tensor_operation::element_wise::PassThrough,
Expand All @@ -148,6 +156,7 @@ namespace hiptensor
2,
ck::bhalf_t,
ck::bhalf_t,
float,
ck::Tuple<>,
ck::bhalf_t,
ck::tensor_operation::element_wise::PassThrough,
Expand All @@ -162,6 +171,7 @@ namespace hiptensor
2,
float,
float,
float,
ck::Tuple<>,
float,
ck::tensor_operation::element_wise::PassThrough,
Expand All @@ -175,6 +185,7 @@ namespace hiptensor
2,
float,
float,
float,
ck::Tuple<>,
float,
ck::tensor_operation::element_wise::PassThrough,
Expand All @@ -188,6 +199,7 @@ namespace hiptensor
2,
float,
float,
float,
ck::Tuple<>,
float,
ck::tensor_operation::element_wise::PassThrough,
Expand All @@ -202,6 +214,7 @@ namespace hiptensor
2,
double,
double,
float,
ck::Tuple<>,
double,
ck::tensor_operation::element_wise::PassThrough,
Expand All @@ -215,6 +228,7 @@ namespace hiptensor
2,
double,
double,
double,
ck::Tuple<>,
double,
ck::tensor_operation::element_wise::PassThrough,
Expand Down
2 changes: 1 addition & 1 deletion test/01_contraction/configs/bilinear_test_params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Betas:
Lengths:
- [ 5, 6, 3, 4, 3, 4 ]
- [ 4, 3, 4, 3, 6, 5 ]
- [ 24, 18, 2, 4, 9, 1 ]
- [ 24, 18, 2, 4, 9, 2 ]
Strides:
- []
...
2 changes: 1 addition & 1 deletion test/01_contraction/configs/scale_test_params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Betas:
Lengths:
- [ 5, 6, 3, 4, 3, 4 ]
- [ 4, 3, 4, 3, 6, 5 ]
- [ 24, 18, 2, 4, 9, 1 ]
- [ 24, 18, 2, 4, 9, 2 ]
Strides:
- []
...
9 changes: 5 additions & 4 deletions test/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ template <typename DDataType>
std::pair<bool, double> compareEqual(DDataType const* deviceD,
DDataType const* hostD,
std::size_t elementsD,
double tolerance = 0.001)
double tolerance = 100.0)
{
bool retval = true;
double max_relative_error = 0.0;
Expand Down Expand Up @@ -202,7 +202,7 @@ std::pair<bool, double> compareEqual(DDataType const* deviceD,
retval = false;
max_relative_error = std::numeric_limits<DDataType>::signaling_NaN();
}
else if(max_relative_error > tolerance)
else if(max_relative_error > (eps * tolerance))
{
retval = false;
}
Expand All @@ -214,7 +214,7 @@ template <typename DDataType>
std::pair<bool, double> compareEqualLaunchKernel(DDataType* deviceD,
DDataType* hostD,
std::size_t elementsD,
double tolerance = 0.001)
double tolerance = 100.0)
{
auto blockDim = dim3(1024, 1, 1);
auto gridDim = dim3(ceilDiv(elementsD, blockDim.x), 1, 1);
Expand Down Expand Up @@ -276,12 +276,13 @@ std::pair<bool, double> compareEqualLaunchKernel(DDataType* deviceD,
auto toDouble
= [](DDataType const& val) { return static_cast<double>(static_cast<float>(val)); };

auto eps = toDouble(std::numeric_limits<DDataType>::epsilon());
if(isNaN)
{
retval = false;
maxRelativeError = std::numeric_limits<DDataType>::signaling_NaN();
}
else if(maxRelativeError > tolerance)
else if(maxRelativeError > (eps * tolerance))
{
retval = false;
}
Expand Down

0 comments on commit b21fe0b

Please sign in to comment.