From 882fd073fd85cddd4089e27c6cd003927d8a17f6 Mon Sep 17 00:00:00 2001 From: Harsha HS Date: Fri, 22 Mar 2024 12:33:33 +0000 Subject: [PATCH] Reslove divergence in array2d.h between local_xla and upstream xla --- third_party/xla/xla/array2d.h | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/third_party/xla/xla/array2d.h b/third_party/xla/xla/array2d.h index 957f2d2678e785..2e8c1547a967a3 100644 --- a/third_party/xla/xla/array2d.h +++ b/third_party/xla/xla/array2d.h @@ -95,19 +95,14 @@ std::unique_ptr> MakeLinspaceArray2D(double from, double to, int64_t n1, int64_t n2) { auto array = std::make_unique>(n1, n2); int64_t count = n1 * n2; - // For types of smaller widths, do the arithmetics in double, since for - // sufficiently large n1 & n2, this could overflow and generate nans - using ArithT = - typename std::conditional<(sizeof(NativeT) < 4), double, NativeT>::type; - ArithT step = - static_cast((count > 1) ? (to - from) / (count - 1) : 0); + NativeT step = + static_cast((count > 1) ? (to - from) / (count - 1) : 0); auto set = [&array, n2](int64_t index, NativeT value) { (*array)(index / n2, index % n2) = value; }; for (int64_t i = 0; i < count - 1; ++i) { - set(i, static_cast(static_cast(from) + - static_cast(i) * - static_cast(step))); + set(i, (static_cast(from) + + static_cast(i) * static_cast(step))); } set(count - 1, static_cast(to)); return array;