diff --git a/cpp/include/kvikio/utils.hpp b/cpp/include/kvikio/utils.hpp index 7a54b2793b..acdafba4da 100644 --- a/cpp/include/kvikio/utils.hpp +++ b/cpp/include/kvikio/utils.hpp @@ -21,7 +21,9 @@ #include #include #include +#include #include +#include #include @@ -69,6 +71,29 @@ inline constexpr std::size_t page_size = 4096; return reinterpret_cast(devPtr); } +/** + * @brief Help function to convert value to 64 bit signed integer + */ +template >* = nullptr> +[[nodiscard]] std::int64_t convert_to_64bit(T value) +{ + if constexpr (std::numeric_limits::max() > std::numeric_limits::max()) { + if (value > std::numeric_limits::max()) { + throw std::overflow_error("convert_to_64bit(x): x too large to fit std::int64_t"); + } + } + return std::int64_t(value); +} + +/** + * @brief Help function to convert value to 64 bit float + */ +template >* = nullptr> +[[nodiscard]] double convert_to_64bit(T value) +{ + return double(value); +} + /** * @brief Check if `ptr` points to host memory (as opposed to device memory) * @@ -280,7 +305,7 @@ struct libkvikio_domain { { \ nvtx3::event_attributes \ { \ - msg, nvtx3::payload { val } \ + msg, nvtx3::payload { convert_to_64bit(val) } \ } \ } #define GET_KVIKIO_NVTX_FUNC_RANGE_MACRO(_1, _2, NAME, ...) NAME