Skip to content

Commit

Permalink
Reslove divergence in array2d.h between local_xla and upstream xla
Browse files Browse the repository at this point in the history
  • Loading branch information
hsharsha committed Mar 22, 2024
1 parent b378542 commit 882fd07
Showing 1 changed file with 4 additions and 9 deletions.
13 changes: 4 additions & 9 deletions third_party/xla/xla/array2d.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,19 +95,14 @@ std::unique_ptr<Array2D<NativeT>> MakeLinspaceArray2D(double from, double to,
int64_t n1, int64_t n2) {
auto array = std::make_unique<Array2D<NativeT>>(n1, n2);
int64_t count = n1 * n2;
// For types of smaller widths, do the arithmetics in double, since for
// sufficiently large n1 & n2, this could overflow and generate nans
using ArithT =
typename std::conditional<(sizeof(NativeT) < 4), double, NativeT>::type;
ArithT step =
static_cast<ArithT>((count > 1) ? (to - from) / (count - 1) : 0);
NativeT step =
static_cast<NativeT>((count > 1) ? (to - from) / (count - 1) : 0);
auto set = [&array, n2](int64_t index, NativeT value) {
(*array)(index / n2, index % n2) = value;
};
for (int64_t i = 0; i < count - 1; ++i) {
set(i, static_cast<NativeT>(static_cast<ArithT>(from) +
static_cast<ArithT>(i) *
static_cast<ArithT>(step)));
set(i, (static_cast<NativeT>(from) +
static_cast<NativeT>(i) * static_cast<NativeT>(step)));
}
set(count - 1, static_cast<NativeT>(to));
return array;
Expand Down

0 comments on commit 882fd07

Please sign in to comment.