Skip to content

Commit

Permalink
Make powm1 GPU compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
mborland committed Jul 30, 2024
1 parent 8fc0e39 commit 2e45590
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 29 deletions.
6 changes: 3 additions & 3 deletions include/boost/math/special_functions/math_fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -591,11 +591,11 @@ namespace boost

// Power - 1
template <class T1, class T2>
tools::promote_args_t<T1, T2>
BOOST_MATH_GPU_ENABLED tools::promote_args_t<T1, T2>
powm1(const T1 a, const T2 z);

template <class T1, class T2, class Policy>
tools::promote_args_t<T1, T2>
BOOST_MATH_GPU_ENABLED tools::promote_args_t<T1, T2>
powm1(const T1 a, const T2 z, const Policy&);

// sqrt(1+x) - 1
Expand Down Expand Up @@ -1481,7 +1481,7 @@ namespace boost
\
template <class T1, class T2>\
inline boost::math::tools::promote_args_t<T1, T2> \
powm1(const T1 a, const T2 z){ return boost::math::powm1(a, z, Policy()); }\
BOOST_MATH_GPU_ENABLED powm1(const T1 a, const T2 z){ return boost::math::powm1(a, z, Policy()); }\
\
template <class T>\
BOOST_MATH_GPU_ENABLED inline boost::math::tools::promote_args_t<T> sqrt1pm1(const T& val){ return boost::math::sqrt1pm1(val, Policy()); }\
Expand Down
66 changes: 40 additions & 26 deletions include/boost/math/special_functions/powm1.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// (C) Copyright John Maddock 2006.
// (C) Copyright Matt Borland 2024.
// Use, modification and distribution are subject to the
// Boost Software License, Version 1.0. (See accompanying file
// LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
Expand All @@ -12,6 +13,7 @@
#pragma warning(disable:4702) // Unreachable code (release mode only warning)
#endif

#include <boost/math/tools/config.hpp>
#include <boost/math/special_functions/math_fwd.hpp>
#include <boost/math/special_functions/log1p.hpp>
#include <boost/math/special_functions/expm1.hpp>
Expand All @@ -22,32 +24,23 @@
namespace boost{ namespace math{ namespace detail{

template <class T, class Policy>
inline T powm1_imp(const T x, const T y, const Policy& pol)
BOOST_MATH_GPU_ENABLED inline T powm1_imp(const T x, const T y, const Policy& pol)
{
BOOST_MATH_STD_USING
static const char* function = "boost::math::powm1<%1%>(%1%, %1%)";
if (x > 0)
constexpr auto function = "boost::math::powm1<%1%>(%1%, %1%)";

if ((fabs(y * (x - 1)) < T(0.5)) || (fabs(y) < T(0.2)))
{
if ((fabs(y * (x - 1)) < T(0.5)) || (fabs(y) < T(0.2)))
{
// We don't have any good/quick approximation for log(x) * y
// so just try it and see:
T l = y * log(x);
if (l < T(0.5))
return boost::math::expm1(l, pol);
if (l > boost::math::tools::log_max_value<T>())
return boost::math::policies::raise_overflow_error<T>(function, nullptr, pol);
// fall through....
}
}
else if ((boost::math::signbit)(x)) // Need to error check -0 here as well
{
// y had better be an integer:
if (boost::math::trunc(y) != y)
return boost::math::policies::raise_domain_error<T>(function, "For non-integral exponent, expected base > 0 but got %1%", x, pol);
if (boost::math::trunc(y / 2) == y / 2)
return powm1_imp(T(-x), y, pol);
// We don't have any good/quick approximation for log(x) * y
// so just try it and see:
T l = y * log(x);
if (l < T(0.5))
return boost::math::expm1(l, pol);
if (l > boost::math::tools::log_max_value<T>())
return boost::math::policies::raise_overflow_error<T>(function, nullptr, pol);
// fall through....
}

T result = pow(x, y) - 1;
if((boost::math::isinf)(result))
return result < 0 ? -boost::math::policies::raise_overflow_error<T>(function, nullptr, pol) : boost::math::policies::raise_overflow_error<T>(function, nullptr, pol);
Expand All @@ -56,22 +49,43 @@ inline T powm1_imp(const T x, const T y, const Policy& pol)
return result;
}

template <class T, class Policy>
BOOST_MATH_GPU_ENABLED inline T powm1_imp_dispatch(const T x, const T y, const Policy& pol)
{
BOOST_MATH_STD_USING

if ((boost::math::signbit)(x)) // Need to error check -0 here as well
{
constexpr auto function = "boost::math::powm1<%1%>(%1%, %1%)";

// y had better be an integer:
if (boost::math::trunc(y) != y)
return boost::math::policies::raise_domain_error<T>(function, "For non-integral exponent, expected base > 0 but got %1%", x, pol);
if (boost::math::trunc(y / 2) == y / 2)
return powm1_imp(T(-x), T(y), pol);
}
else
{
return powm1_imp(T(x), T(y), pol);
}
}

} // detail

template <class T1, class T2>
inline typename tools::promote_args<T1, T2>::type
BOOST_MATH_GPU_ENABLED inline typename tools::promote_args<T1, T2>::type
powm1(const T1 a, const T2 z)
{
typedef typename tools::promote_args<T1, T2>::type result_type;
return detail::powm1_imp(static_cast<result_type>(a), static_cast<result_type>(z), policies::policy<>());
return detail::powm1_imp_dispatch(static_cast<result_type>(a), static_cast<result_type>(z), policies::policy<>());
}

template <class T1, class T2, class Policy>
inline typename tools::promote_args<T1, T2>::type
BOOST_MATH_GPU_ENABLED inline typename tools::promote_args<T1, T2>::type
powm1(const T1 a, const T2 z, const Policy& pol)
{
typedef typename tools::promote_args<T1, T2>::type result_type;
return detail::powm1_imp(static_cast<result_type>(a), static_cast<result_type>(z), pol);
return detail::powm1_imp_dispatch(static_cast<result_type>(a), static_cast<result_type>(z), pol);
}

} // namespace math
Expand Down

0 comments on commit 2e45590

Please sign in to comment.