Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup hlo_extractor and hlo_bisect dependecies #18170

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
277 changes: 277 additions & 0 deletions xla/literal_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,184 @@ void SetScalarAtIndexImpl(MutableLiteralBase& literal,
literal.Set<NativeT>(multi_index, scalar.Get<NativeT>({}));
}

template <typename FloatT>
void PopulateWithIntNext(Literal* literal) {
using BitRepT = UnsignedIntegerTypeForSizeType<sizeof(FloatT)>;
// Duplicates may be generated if we don't have enough bits.
// Skip bfloat16 and float32 subnormals.
const FloatT kFirstValue =
std::is_same_v<FloatT, bfloat16> || sizeof(FloatT) >= sizeof(float)
? std::numeric_limits<FloatT>::min()
: std::numeric_limits<FloatT>::denorm_min();
// `current` keeps track of the next value we need to populate.
auto current = literal->data<FloatT>().begin();
auto end = literal->data<FloatT>().end();
// `sign` keeps track of the sign of the next value.
bool sign = false;
while (current != end) {
// We start populating values at zero and increase magnitude from there.
*current = sign ? static_cast<FloatT>(-0.0f) : static_cast<FloatT>(0.0f);
current++;
// The next value is either the smallest denormal or normal.
auto value = sign ? -kFirstValue : kFirstValue;
// Fill the array with values of increasing magnitude until we hit a
// non-finite value.
while (current != end && Eigen::numext::isfinite(value)) {
// Populate the value.
*current = value;
// Generate the next value by lexicographically increasing the bit
// representation.
const BitRepT next_value = Eigen::numext::bit_cast<BitRepT>(value) + 1;
value = Eigen::numext::bit_cast<FloatT>(next_value);
current++;
}
// We ran out of finite values, flip the sign and begin again.
sign = !sign;
}
}

template <typename FloatT>
void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) {
PopulateWithIntNext<FloatT>(literal);
std::shuffle(literal->data<FloatT>().begin(), literal->data<FloatT>().end(),
*engine);
}

// Populates a floating point literal with random floating points sampled from a
// uniform-log distribution spanning approximately the entire range of the
// representable floating point.
template <typename FloatT>
void PopulateWithRandomFullRangeFloatingPointData(Literal* literal,
std::minstd_rand0* engine) {
constexpr float kSpecialValueProbability = 1e-6;
constexpr float kSpecialValues[] = {+0.F,
-0.F,
1.F,
-1.F,
std::numeric_limits<float>::infinity(),
-std::numeric_limits<float>::infinity()};
constexpr int kNumSpecialValues = sizeof(kSpecialValues) / sizeof(float);
std::uniform_real_distribution<float> special_value_gen(0, 1);

// Generates floating points with a log-uniform distribution. This causes the
// exponent of the floating point to have a uniform distribution.
const int min_exp = std::numeric_limits<FloatT>::min_exponent;
const int max_exp = std::numeric_limits<FloatT>::max_exponent;
std::uniform_real_distribution<double> generator(min_exp - 1, max_exp - 1);

for (FloatT& value : literal->data<FloatT>()) {
// Each special value has a kSpecialValueProbability chance to be generated
// instead of sampling using the normal distributions.
if (special_value_gen(*engine) <
kSpecialValueProbability * kNumSpecialValues) {
value =
static_cast<FloatT>(kSpecialValues[(*engine)() % kNumSpecialValues]);
} else {
float sign = ((*engine)() % 2 == 0) ? 1 : -1;
value = static_cast<FloatT>(pow(2, generator(*engine)) * sign);
}
}
}

template <typename FloatT, typename GeneratorT>
void PopulateWithRandomFloatingPointData(Literal* literal,
std::minstd_rand0* engine) {
std::uniform_real_distribution<GeneratorT> generator(-0.1f, 0.2f);
for (FloatT& value : literal->data<FloatT>()) {
value = static_cast<FloatT>(generator(*engine));
}
}

template <typename FloatT>
void PopulateWithFloatingPointData(
Literal* literal, std::minstd_rand0* engine, bool no_duplicates,
bool use_large_range, std::optional<int64_t> max_bits_of_precision) {
using ComputeT =
std::conditional_t<sizeof(FloatT) < sizeof(float), float, FloatT>;
CHECK(engine != nullptr);
CHECK_EQ(literal->shape().element_type(),
primitive_util::NativeToPrimitiveType<FloatT>());
if (max_bits_of_precision.has_value()) {
CHECK(!use_large_range) << "Cannot set both use_large_range and "
"max_bits_of_precision for floating points.";
CHECK(!no_duplicates) << "Cannot set both no_duplicates and "
"max_bits_of_precision for floating points.";
std::uniform_int_distribution<int64_t> generator(
-(1 << *max_bits_of_precision), 1 << *max_bits_of_precision);
for (FloatT& value : literal->data<FloatT>()) {
int64_t temp = generator(*engine);
// We want to generate floating point numbers to a fixed precision, while
// keeping them between -1 and 1. This preserves their bits of precision
// while keeping the numbers small.
value = static_cast<FloatT>(temp * pow(2, -ceil(log2(abs(temp)))));
}
} else if (no_duplicates) {
PopulateWithNoDuplicateData<FloatT>(literal, engine);
} else if (use_large_range) {
PopulateWithRandomFullRangeFloatingPointData<FloatT>(literal, engine);
} else {
PopulateWithRandomFloatingPointData<FloatT, ComputeT>(literal, engine);
}
}

template <typename ComplexT>
void PopulateWithComplexData(Literal* result, std::minstd_rand0* engine,
bool no_duplicates, bool use_large_range) {
using InnerFloatT = typename ComplexT::value_type;
CHECK(engine != nullptr);
CHECK_EQ(result->shape().element_type(),
primitive_util::NativeToPrimitiveType<ComplexT>());
Shape floating_point_shape = ShapeUtil::ChangeElementType(
result->shape(), primitive_util::NativeToPrimitiveType<InnerFloatT>());
Literal real_lit(floating_point_shape);
Literal imaginary_lit(floating_point_shape);

PopulateWithFloatingPointData<InnerFloatT>(
&real_lit, engine, no_duplicates, use_large_range,
/*max_bits_of_precision=*/std::nullopt);
PopulateWithFloatingPointData<InnerFloatT>(
&imaginary_lit, engine, no_duplicates, use_large_range,
/*max_bits_of_precision=*/std::nullopt);

absl::Span<const InnerFloatT> real_data = real_lit.data<InnerFloatT>();
absl::Span<const InnerFloatT> imaginary_data =
imaginary_lit.data<InnerFloatT>();
absl::Span<ComplexT> result_data = result->data<ComplexT>();
for (int i = 0; i < real_lit.data<InnerFloatT>().size(); i++) {
result_data[i] = ComplexT(real_data[i], imaginary_data[i]);
}
}

// uniform_int_distribution is not defined for 8-bit integers.
// Use 'short' for those types.
template <typename IntT>
using RngT = std::conditional_t<
sizeof(IntT) < sizeof(uint16_t),
std::conditional_t<std::numeric_limits<IntT>::is_signed, int16_t, uint16_t>,
IntT>;
template <typename IntT>
void PopulateWithRandomIntegralDataWithBounds(Literal* literal,
std::minstd_rand0* engine,
bool no_duplicates, IntT min,
IntT max) {
CHECK(engine != nullptr);
CHECK_EQ(literal->shape().element_type(),
primitive_util::NativeToPrimitiveType<IntT>());
if (no_duplicates &&
ShapeUtil::ElementsIn(literal->shape()) < static_cast<int64_t>(max)) {
std::iota(literal->data<IntT>().begin(), literal->data<IntT>().end(),
static_cast<IntT>(0));
std::shuffle(literal->data<IntT>().begin(), literal->data<IntT>().end(),
*engine);
} else {
std::uniform_int_distribution<RngT<IntT>> generator(
static_cast<RngT<IntT>>(min), static_cast<RngT<IntT>>(max));
for (IntT& value : literal->data<IntT>()) {
value = static_cast<IntT>(generator(*engine));
}
}
}

} // namespace

/* static */ Literal LiteralUtil::CreateFromDimensions(
Expand Down Expand Up @@ -498,4 +676,103 @@ void SetScalarAtIndexImpl(MutableLiteralBase& literal,
return l.GetFirstInteger();
}

absl::StatusOr<Literal> MakeFakeLiteral(const Shape& shape, bool pseudo_random,
bool use_large_range) {
auto engine = pseudo_random ? std::make_unique<std::minstd_rand0>() : nullptr;
return MakeFakeLiteral(shape, engine.get(), /*limit=*/std::nullopt,
/*is_sorted=*/false,
/*no_duplicates=*/false, use_large_range,
/*max_bits_of_precision=*/std::nullopt);
}

absl::StatusOr<Literal> MakeFakeLiteral(
const Shape& shape, std::minstd_rand0* engine,
std::optional<std::pair<int64_t, int64_t>> limit, bool is_sorted,
bool no_duplicates, bool use_large_range,
std::optional<int64_t> max_bits_of_precision) {
if (shape.IsTuple()) {
std::vector<Literal> elements;
const auto& shape_tuple_shapes = shape.tuple_shapes();
elements.reserve(shape_tuple_shapes.size());
for (const Shape& element_shape : shape_tuple_shapes) {
TF_ASSIGN_OR_RETURN(
Literal element,
MakeFakeLiteral(element_shape, engine, limit, is_sorted,
no_duplicates, use_large_range,
max_bits_of_precision));
elements.push_back(std::move(element));
}
return LiteralUtil::MakeTupleOwned(std::move(elements));
}
if (engine == nullptr) {
return Literal::CreateFromShape(shape);
}
// Clear tiles/element size in shape's layout before using it for creating
// literal.
Shape new_shape = shape;
new_shape.mutable_layout()->clear_tiles();
new_shape.mutable_layout()->set_tail_padding_alignment_in_elements(1);
new_shape.mutable_layout()->set_element_size_in_bits(0);
Literal literal(new_shape);

TF_RETURN_IF_ERROR(primitive_util::PrimitiveTypeSwitch<absl::Status>(
[&](auto primitive_type_constant) -> absl::Status {
if constexpr (primitive_util::IsArrayType(primitive_type_constant)) {
using NativeT = primitive_util::NativeTypeOf<primitive_type_constant>;
if constexpr (primitive_util::IsFloatingPointType(
primitive_type_constant)) {
PopulateWithFloatingPointData<NativeT>(
&literal, engine, no_duplicates, use_large_range,
max_bits_of_precision);
return absl::OkStatus();
}
if constexpr (primitive_type_constant == PRED) {
std::uniform_int_distribution<int> generator(0, 1);
TF_CHECK_OK(literal.Populate<bool>(
[&](absl::Span<const int64_t> /*indices*/) {
return generator(*engine);
}));
return absl::OkStatus();
}
if constexpr (primitive_util::IsIntegralType(
primitive_type_constant)) {
NativeT max = std::numeric_limits<NativeT>::max();
NativeT min = std::numeric_limits<NativeT>::lowest();
if (limit.has_value()) {
max = static_cast<NativeT>(limit->second);
min = static_cast<NativeT>(limit->first);
}
if (max_bits_of_precision.has_value()) {
max = std::min(max,
static_cast<NativeT>(1 << *max_bits_of_precision));
if (primitive_util::IsSignedIntegralType(
primitive_type_constant)) {
min = std::max(
min, static_cast<NativeT>(-(1 << *max_bits_of_precision)));
}
}
PopulateWithRandomIntegralDataWithBounds<NativeT>(
&literal, engine, /*no_duplicate*/ no_duplicates, min, max);
if (is_sorted) {
std::sort(literal.data<NativeT>().begin(),
literal.data<NativeT>().end());
}
return absl::OkStatus();
}
if constexpr (primitive_util::IsComplexType(
primitive_type_constant)) {
PopulateWithComplexData<NativeT>(&literal, engine, no_duplicates,
use_large_range);
return absl::OkStatus();
}
}
return Unimplemented(
"Unsupported type for fake random literal generation with bounds: "
"%s",
ShapeUtil::HumanString(shape));
},
shape.element_type()));
return std::move(literal);
}

} // namespace xla
26 changes: 26 additions & 0 deletions xla/literal_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,32 @@ template <PrimitiveType type, typename T>
return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
}

// Generates fake data in a literal of the given shape, or returns an error
// status if the element type is currently unhandled for fake data
// generation. See below for documentation of pseudo_random and use_large_range.
absl::StatusOr<Literal> MakeFakeLiteral(const Shape& shape,
bool pseudo_random = true,
bool use_large_range = false);

// Similar to MakeFakeLiteral above but takes a random number generator engine
// to enable reusing the engine across randomly generated literals. 'limit' is a
// optional pair that contains the min and the max values to be sample for
// integers (integer format only). 'is_sorted' sorts the sample data for
// integers (integer format only). 'no_duplicates' indicates that there should
// be no duplicate values in each generated array. This is uniqueness is
// best-effort only. Some types (half and bfloat16) are not supported and
// uniqueness cannot be guaranteed if the number of elements exceeds the number
// of different values supported by the type. (floating point format only)
// 'use_large_range' indicates the sampled data is from the full range of the
// floating point format. (floating point format only)
// 'max_bits_of_precision' sets the data to have the given number of bits or
// less (integer or floating point formats only).
absl::StatusOr<Literal> MakeFakeLiteral(
const Shape& shape, std::minstd_rand0* engine,
std::optional<std::pair<int64_t, int64_t>> limit, bool is_sorted,
bool no_duplicates, bool use_large_range,
std::optional<int64_t> max_bits_of_precision);

} // namespace xla

#endif // XLA_LITERAL_UTIL_H_
Loading
Loading