diff --git a/include/alpaka/atomic/AtomicGenericSycl.hpp b/include/alpaka/atomic/AtomicGenericSycl.hpp index 8ebf608dc570..bdfa53baf250 100644 --- a/include/alpaka/atomic/AtomicGenericSycl.hpp +++ b/include/alpaka/atomic/AtomicGenericSycl.hpp @@ -51,47 +51,14 @@ namespace alpaka static constexpr auto value = sycl::memory_scope::work_group; }; - template - inline auto get_global_ptr(T* const addr) - { - return sycl::address_space_cast( - addr); - } - - template - inline auto get_local_ptr(T* const addr) - { - return sycl::address_space_cast( - addr); - } - - template - using global_ref = sycl::atomic_ref< - T, - sycl::memory_order::relaxed, - SyclMemoryScope::value, - sycl::access::address_space::global_space>; - template - using local_ref = sycl::atomic_ref< - T, - sycl::memory_order::relaxed, - SyclMemoryScope::value, - sycl::access::address_space::local_space>; + using sycl_atomic_ref = sycl::atomic_ref::value>; template inline auto callAtomicOp(T* const addr, TOp&& op) { - if(auto ptr = get_global_ptr(addr); ptr != nullptr) - { - auto ref = global_ref{*addr}; - return op(ref); - } - else - { - auto ref = local_ref{*addr}; - return op(ref); - } + auto ref = sycl_atomic_ref{*addr}; + return op(ref); } template @@ -178,7 +145,7 @@ namespace alpaka::trait struct AtomicOp { static_assert( - (std::is_integral_v || std::is_floating_point_v) &&(sizeof(T) == 4 || sizeof(T) == 8), + (std::is_integral_v || std::is_floating_point_v) and(sizeof(T) == 4 || sizeof(T) == 8), "SYCL atomics do not support this type"); static auto atomicOp(AtomicGenericSycl const&, T* const addr, T const& value) -> T @@ -200,10 +167,7 @@ namespace alpaka::trait { auto inc = [&value](auto old_val) { return (old_val >= value) ? static_cast(0) : (old_val + static_cast(1)); }; - if(auto ptr = alpaka::detail::get_global_ptr(addr); ptr != nullptr) - return alpaka::detail::casWithCondition>(addr, inc); - else - return alpaka::detail::casWithCondition>(addr, inc); + return alpaka::detail::casWithCondition>(addr, inc); } }; @@ -220,10 +184,7 @@ namespace alpaka::trait { auto dec = [&value](auto& old_val) { return ((old_val == 0) || (old_val > value)) ? value : (old_val - static_cast(1)); }; - if(auto ptr = alpaka::detail::get_global_ptr(addr); ptr != nullptr) - return alpaka::detail::casWithCondition>(addr, dec); - else - return alpaka::detail::casWithCondition>(addr, dec); + return alpaka::detail::casWithCondition>(addr, dec); } }; @@ -294,16 +255,7 @@ namespace alpaka::trait return expected_; }; - if(auto ptr = alpaka::detail::get_global_ptr(addr); ptr != nullptr) - { - auto ref = alpaka::detail::global_ref{*addr}; - return cas(ref); - } - else - { - auto ref = alpaka::detail::local_ref{*addr}; - return cas(ref); - } + return alpaka::detail::callAtomicOp(addr, cas); } }; } // namespace alpaka::trait