Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove use_bfloat16 from reduce_window_test.cc #18669

Merged
merged 1 commit into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions xla/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1772,24 +1772,28 @@ xla_test_library(
deps = [
":client_library_test_base",
":hlo_test_base",
":literal_test_util",
":test_macros_header",
":xla_internal_test_main",
"//xla:array2d",
"//xla:array3d",
"//xla:array4d",
"//xla:error_spec",
"//xla:literal",
"//xla:literal_util",
"//xla:reference_util",
"//xla:shape_util",
"//xla:xla_data_proto_cc",
"//xla/client:local_client",
"//xla/hlo/builder:padding",
"//xla/hlo/builder:xla_builder",
"//xla/hlo/builder:xla_computation",
"//xla/hlo/builder/lib:arithmetic",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:status",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
],
)
Expand Down
24 changes: 0 additions & 24 deletions xla/tests/client_library_test_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,23 +44,6 @@ limitations under the License.

namespace xla {

// Sets the use_bfloat16 on a container of test cases according to the values in
// use_bfloat16_params. Generates one set of test cases for each values in
// use_bfloat16_params with that value. Returns the result.
template <typename TestCase>
std::vector<TestCase> ExpandUseBfloat16(
absl::Span<const bool> use_bfloat16_params,
absl::Span<const TestCase> specs) {
std::vector<TestCase> expanded;
for (bool use_bfloat16 : use_bfloat16_params) {
for (const auto& spec : specs) {
expanded.push_back(spec);
expanded.back().use_bfloat16 = use_bfloat16;
}
}
return expanded;
}

template <typename TestCase>
std::vector<TestCase> ExpandTestType(
absl::Span<const PrimitiveType> test_type_params,
Expand Down Expand Up @@ -413,13 +396,6 @@ class ClientLibraryTestBase : public ::testing::Test {
XlaBuilder* builder,
XlaOp* data_handle);

// TODO(ralphnathan): These will eventually be removed. Please have new tests
// support multiple primitive types, not just BF16.
// Getter and setter for the test_type flag, which indicates whether to run
// tests with all float-type input/output converted to bfloat16.
bool use_bfloat16() const { return test_type_ == BF16; }
void set_use_bfloat16(bool value) { test_type_ = value ? BF16 : F32; }

// The float type used in this test.
PrimitiveType FloatType() const { return test_type_; }
void set_float_type(PrimitiveType type) { test_type_ = type; }
Expand Down
119 changes: 69 additions & 50 deletions xla/tests/reduce_window_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,51 +15,66 @@ limitations under the License.

// Tests the reduce-window XLA operation.

#include <limits>
#include <algorithm>
#include <array>
#include <cstdint>
#include <iterator>
#include <memory>

#include <numeric>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include <gtest/gtest.h>
#include "absl/log/check.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "xla/array2d.h"
#include "xla/array3d.h"
#include "xla/array4d.h"
#include "xla/client/local_client.h"
#include "xla/error_spec.h"
#include "xla/hlo/builder/lib/arithmetic.h"
#include "xla/hlo/builder/padding.h"
#include "xla/hlo/builder/xla_builder.h"
#include "xla/hlo/builder/xla_computation.h"
#include "xla/layout_util.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/primitive_util.h"
#include "xla/reference_util.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tests/client_library_test_base.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tests/literal_test_util.h"
#include "xla/tests/test_macros.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/status.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"

namespace xla {
namespace {

static std::array<bool, 2> use_bfloat16_params{false, true};
static std::array<PrimitiveType, 2> test_type_params = {F32, BF16};

class ReduceWindowTestBase : public ClientLibraryTestBase {
public:
ErrorSpec DefaultErrorSpec() const {
if (use_bfloat16()) {
if (FloatType() == BF16) {
return ErrorSpec(2e-1, 6e-2);
} else {
return ErrorSpec(1e-3, 1e-3);
}
}
};

class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
class ReduceWindowTest : public ::testing::WithParamInterface<PrimitiveType>,
public ReduceWindowTestBase {
public:
ReduceWindowTest() : builder_(TestName()) { set_use_bfloat16(GetParam()); }
ReduceWindowTest() : builder_(TestName()) { set_float_type(GetParam()); }

void ReduceWindowAdd(const XlaOp input,
absl::Span<const int64_t> window_dimensions,
Expand Down Expand Up @@ -563,7 +578,7 @@ XLA_TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) {
}

INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest,
::testing::ValuesIn(use_bfloat16_params));
::testing::ValuesIn(test_type_params));

enum Reducer { kAdd, kMax };

Expand All @@ -580,7 +595,7 @@ struct R4ReduceWindowTestData {

std::string R4ReduceWindowTestDataToString(
const ::testing::TestParamInfo<
::testing::tuple<R4ReduceWindowTestData, bool>>& data) {
::testing::tuple<R4ReduceWindowTestData, PrimitiveType>>& data) {
const auto& param = ::testing::get<0>(data.param);
std::string str = absl::StrCat(
"base_bounds_", absl::StrJoin(param.base_bounds, "x"), //
Expand All @@ -594,17 +609,18 @@ std::string R4ReduceWindowTestDataToString(

// Test names are not allowed to contain the '-' character.
std::replace(str.begin(), str.end(), '-', 'n');
if (::testing::get<1>(data.param)) {
absl::StrAppend(&str, "_bfloat16");
}
absl::StrAppend(&str, "_",
primitive_util::LowercasePrimitiveTypeName(
::testing::get<1>(data.param)));
return str;
}

class R4ReduceWindowTest : public ReduceWindowTestBase,
public ::testing::WithParamInterface<
::testing::tuple<R4ReduceWindowTestData, bool>> {
class R4ReduceWindowTest
: public ReduceWindowTestBase,
public ::testing::WithParamInterface<
::testing::tuple<R4ReduceWindowTestData, PrimitiveType>> {
protected:
R4ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
R4ReduceWindowTest() { set_float_type(::testing::get<1>(GetParam())); }

void DoIt() {
XlaBuilder b(TestName());
Expand Down Expand Up @@ -878,7 +894,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
INSTANTIATE_TEST_CASE_P(
R4ReduceWindowTestInstantiation, R4ReduceWindowTest,
::testing::Combine(::testing::ValuesIn(kR4ReduceWindowTestValues),
::testing::ValuesIn(use_bfloat16_params)),
::testing::ValuesIn(test_type_params)),
R4ReduceWindowTestDataToString);

class R4ReduceWindowLargeTest : public R4ReduceWindowTest {};
Expand Down Expand Up @@ -967,7 +983,7 @@ const R4ReduceWindowTestData kR4ReduceWindowLargeTestValues[] = {
INSTANTIATE_TEST_CASE_P(
R4ReduceWindowLargeTestInstantiation, R4ReduceWindowLargeTest,
::testing::Combine(::testing::ValuesIn(kR4ReduceWindowLargeTestValues),
::testing::ValuesIn(use_bfloat16_params)),
::testing::ValuesIn(test_type_params)),
R4ReduceWindowTestDataToString);

struct R3ReduceWindowTestData {
Expand Down Expand Up @@ -1017,7 +1033,7 @@ R3ReduceWindowTestData kR3TestCases[] = {

std::string R3ReduceWindowTestDataToString(
const ::testing::TestParamInfo<
::testing::tuple<R3ReduceWindowTestData, bool>>& data) {
::testing::tuple<R3ReduceWindowTestData, PrimitiveType>>& data) {
const auto& param = ::testing::get<0>(data.param);
std::string str = absl::StrCat(
"base_bounds_", absl::StrJoin(param.base_bounds, "x"), "__window_bounds_",
Expand All @@ -1026,17 +1042,18 @@ std::string R3ReduceWindowTestDataToString(
param.padding == Padding::kSame ? "same" : "valid", "__layout_",
param.layout[0], "_", param.layout[1], "_", param.layout[2], "__reducer_",
param.reducer == kAdd ? "add" : "max");
if (::testing::get<1>(data.param)) {
absl::StrAppend(&str, "_bfloat16");
}
absl::StrAppend(&str, "_",
primitive_util::LowercasePrimitiveTypeName(
::testing::get<1>(data.param)));
return str;
}

class R3ReduceWindowTest : public ReduceWindowTestBase,
public ::testing::WithParamInterface<
::testing::tuple<R3ReduceWindowTestData, bool>> {
class R3ReduceWindowTest
: public ReduceWindowTestBase,
public ::testing::WithParamInterface<
::testing::tuple<R3ReduceWindowTestData, PrimitiveType>> {
protected:
R3ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
R3ReduceWindowTest() { set_float_type(::testing::get<1>(GetParam())); }

void DoIt() {
XlaBuilder b(TestName());
Expand All @@ -1052,7 +1069,7 @@ class R3ReduceWindowTest : public ReduceWindowTestBase,
Literal input_literal = LiteralUtil::CreateR3FromArray3DWithLayout(
input, LayoutUtil::MakeLayout(param.layout));
auto reducer = param.reducer;
if (use_bfloat16()) {
if (FloatType() == BF16) {
input_literal = LiteralUtil::ConvertF32ToBF16(input_literal);

// To avoid numerical issues, force the reducer to be kMax for bf16
Expand Down Expand Up @@ -1083,7 +1100,7 @@ XLA_TEST_P(R3ReduceWindowTest, DoIt) { DoIt(); }
INSTANTIATE_TEST_CASE_P(
R3ReduceWindowTestInstantiation, R3ReduceWindowTest,
::testing::Combine(::testing::ValuesIn(kR3TestCases),
::testing::ValuesIn(use_bfloat16_params)),
::testing::ValuesIn(test_type_params)),
R3ReduceWindowTestDataToString);

class R3ReduceWindowLargeTest : public R3ReduceWindowTest {};
Expand All @@ -1106,7 +1123,7 @@ const R3ReduceWindowTestData kR3ReduceWindowLargeTestValues[] = {
INSTANTIATE_TEST_CASE_P(
R3ReduceWindowLargeTestInstantiation, R3ReduceWindowLargeTest,
::testing::Combine(::testing::ValuesIn(kR3ReduceWindowLargeTestValues),
::testing::ValuesIn(use_bfloat16_params)),
::testing::ValuesIn(test_type_params)),
R3ReduceWindowTestDataToString);

struct R2ReduceWindowTestData {
Expand Down Expand Up @@ -1268,7 +1285,7 @@ struct R2ReduceWindowTestData {

std::string R2ReduceWindowTestDataToString(
const ::testing::TestParamInfo<
::testing::tuple<R2ReduceWindowTestData, bool>>& data) {
::testing::tuple<R2ReduceWindowTestData, PrimitiveType>>& data) {
const auto& param = ::testing::get<0>(data.param);
std::string str = absl::StrCat(
"base_bounds_", absl::StrJoin(param.base_bounds, "x"), //
Expand All @@ -1283,24 +1300,25 @@ std::string R2ReduceWindowTestDataToString(

// Test names are not allowed to contain the '-' character.
std::replace(str.begin(), str.end(), '-', 'n');
if (::testing::get<1>(data.param)) {
absl::StrAppend(&str, "_bfloat16");
}
absl::StrAppend(&str, "_",
primitive_util::LowercasePrimitiveTypeName(
::testing::get<1>(data.param)));
return str;
}

class R2ReduceWindowTest : public ReduceWindowTestBase,
public ::testing::WithParamInterface<
::testing::tuple<R2ReduceWindowTestData, bool>> {
class R2ReduceWindowTest
: public ReduceWindowTestBase,
public ::testing::WithParamInterface<
::testing::tuple<R2ReduceWindowTestData, PrimitiveType>> {
protected:
R2ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
R2ReduceWindowTest() { set_float_type(::testing::get<1>(GetParam())); }

void DoIt() {
XlaBuilder b(TestName());
const auto& param = ::testing::get<0>(GetParam());

Array2D<float> input(param.base_bounds[0], param.base_bounds[1], 1.0f);
if (!::testing::get<1>(GetParam())) {
if (FloatType() == F32) {
// We only do this in F32 mode, to avoid precision issues with BF16.
input = *MakeLinspaceArray2D(0, 100, param.base_bounds[0],
param.base_bounds[1]);
Expand Down Expand Up @@ -1343,7 +1361,7 @@ XLA_TEST_P(R2ReduceWindowTest, DoIt) { DoIt(); }
INSTANTIATE_TEST_CASE_P(
R2ReduceWindowTestInstantiation, R2ReduceWindowTest,
::testing::Combine(::testing::ValuesIn(kR2TestCases),
::testing::ValuesIn(use_bfloat16_params)),
::testing::ValuesIn(test_type_params)),
R2ReduceWindowTestDataToString);

struct R1ReduceWindowTestData {
Expand Down Expand Up @@ -1499,7 +1517,7 @@ struct R1ReduceWindowTestData {

std::string R1ReduceWindowTestDataToString(
const ::testing::TestParamInfo<
::testing::tuple<R1ReduceWindowTestData, bool>>& data) {
::testing::tuple<R1ReduceWindowTestData, PrimitiveType>>& data) {
const auto& param = ::testing::get<0>(data.param);
std::string str =
absl::StrCat("base_bounds_", absl::StrJoin(param.base_bounds, "x"),
Expand All @@ -1511,17 +1529,18 @@ std::string R1ReduceWindowTestDataToString(

// Test names are not allowed to contain the '-' character.
std::replace(str.begin(), str.end(), '-', 'n');
if (::testing::get<1>(data.param)) {
absl::StrAppend(&str, "_bfloat16");
}
absl::StrAppend(&str, "_",
primitive_util::LowercasePrimitiveTypeName(
::testing::get<1>(data.param)));
return str;
}

class R1ReduceWindowTest : public ReduceWindowTestBase,
public ::testing::WithParamInterface<
::testing::tuple<R1ReduceWindowTestData, bool>> {
class R1ReduceWindowTest
: public ReduceWindowTestBase,
public ::testing::WithParamInterface<
::testing::tuple<R1ReduceWindowTestData, PrimitiveType>> {
protected:
R1ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
R1ReduceWindowTest() { set_float_type(::testing::get<1>(GetParam())); }
};

XLA_TEST_P(R1ReduceWindowTest, DoIt) {
Expand Down Expand Up @@ -1575,7 +1594,7 @@ XLA_TEST_P(R1ReduceWindowTest, DoIt) {
INSTANTIATE_TEST_CASE_P(
R1ReduceWindowTestInstantiation, R1ReduceWindowTest,
::testing::Combine(::testing::ValuesIn(kR1TestCases),
::testing::ValuesIn(use_bfloat16_params)),
::testing::ValuesIn(test_type_params)),
R1ReduceWindowTestDataToString);

// Test class for text-based test cases. Note that this compares with the
Expand Down
Loading