Skip to content

Commit

Permalink
Enable heterogeneous insert for static_map (#381)
Browse files Browse the repository at this point in the history
  • Loading branch information
sleeepyjack authored Oct 12, 2023
1 parent 72ca959 commit e37c12d
Show file tree
Hide file tree
Showing 8 changed files with 218 additions and 55 deletions.
29 changes: 17 additions & 12 deletions include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include <thrust/distance.h>
#include <thrust/pair.h>
#include <thrust/tuple.h>

#include <cuda/atomic>

Expand Down Expand Up @@ -865,9 +866,9 @@ class open_addressing_ref_impl {
Value const& value) const noexcept
{
if constexpr (this->has_payload) {
return value.first;
return thrust::get<0>(thrust::raw_reference_cast(value));
} else {
return value;
return thrust::raw_reference_cast(value);
}
}

Expand All @@ -886,7 +887,7 @@ class open_addressing_ref_impl {
[[nodiscard]] __host__ __device__ constexpr auto const& extract_payload(
Value const& value) const noexcept
{
return value.second;
return thrust::get<1>(thrust::raw_reference_cast(value));
}

/**
Expand Down Expand Up @@ -952,10 +953,10 @@ class open_addressing_ref_impl {
auto const expected_key = expected.first;
auto const expected_payload = expected.second;

auto old_key =
compare_and_swap(&address->first, expected_key, static_cast<key_type>(desired.first));
auto old_key = compare_and_swap(
&address->first, expected_key, static_cast<key_type>(thrust::get<0>(desired)));
auto old_payload = compare_and_swap(
&address->second, expected_payload, static_cast<mapped_type>(desired.second));
&address->second, expected_payload, static_cast<mapped_type>(thrust::get<1>(desired)));

auto* old_key_ptr = reinterpret_cast<key_type*>(&old_key);
auto* old_payload_ptr = reinterpret_cast<mapped_type*>(&old_payload);
Expand All @@ -964,7 +965,7 @@ class open_addressing_ref_impl {
if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) {
while (not cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) {
old_payload = compare_and_swap(
&address->second, expected_payload, static_cast<mapped_type>(desired.second));
&address->second, expected_payload, static_cast<mapped_type>(thrust::get<1>(desired)));
}
return insert_result::SUCCESS;
} else if (cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) {
Expand All @@ -973,7 +974,9 @@ class open_addressing_ref_impl {

// Our key was already present in the slot, so our key is a duplicate
// Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare
if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) {
if (this->predicate_.equal_to(*old_key_ptr,
thrust::get<0>(thrust::raw_reference_cast(desired))) ==
detail::equal_result::EQUAL) {
return insert_result::DUPLICATE;
}

Expand All @@ -999,20 +1002,22 @@ class open_addressing_ref_impl {

auto const expected_key = expected.first;

auto old_key =
compare_and_swap(&address->first, expected_key, static_cast<key_type>(desired.first));
auto old_key = compare_and_swap(
&address->first, expected_key, static_cast<key_type>(thrust::get<0>(desired)));

auto* old_key_ptr = reinterpret_cast<key_type*>(&old_key);

// if key success
if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) {
atomic_store(&address->second, static_cast<mapped_type>(desired.second));
atomic_store(&address->second, static_cast<mapped_type>(thrust::get<1>(desired)));
return insert_result::SUCCESS;
}

// Our key was already present in the slot, so our key is a duplicate
// Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare
if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) {
if (this->predicate_.equal_to(*old_key_ptr,
thrust::get<0>(thrust::raw_reference_cast(desired))) ==
detail::equal_result::EQUAL) {
return insert_result::DUPLICATE;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,11 @@ __host__ __device__ constexpr bool operator==(cuco::pair<T1, T2> const& lhs,
}

} // namespace cuco

namespace thrust {
#include <cuco/detail/pair/tuple_helpers.inl>
} // namespace thrust

namespace cuda::std {
#include <cuco/detail/pair/tuple_helpers.inl>
} // namespace cuda::std
118 changes: 118 additions & 0 deletions include/cuco/detail/pair/tuple_helpers.inl
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

template <typename T1, typename T2>
struct tuple_size<cuco::pair<T1, T2>> : integral_constant<size_t, 2> {
};

template <typename T1, typename T2>
struct tuple_size<const cuco::pair<T1, T2>> : tuple_size<cuco::pair<T1, T2>> {
};

template <typename T1, typename T2>
struct tuple_size<volatile cuco::pair<T1, T2>> : tuple_size<cuco::pair<T1, T2>> {
};

template <typename T1, typename T2>
struct tuple_size<const volatile cuco::pair<T1, T2>> : tuple_size<cuco::pair<T1, T2>> {
};

template <std::size_t I, typename T1, typename T2>
struct tuple_element<I, cuco::pair<T1, T2>> {
using type = void;
};

template <typename T1, typename T2>
struct tuple_element<0, cuco::pair<T1, T2>> {
using type = T1;
};

template <typename T1, typename T2>
struct tuple_element<1, cuco::pair<T1, T2>> {
using type = T2;
};

template <typename T1, typename T2>
struct tuple_element<0, const cuco::pair<T1, T2>> : tuple_element<0, cuco::pair<T1, T2>> {
};

template <typename T1, typename T2>
struct tuple_element<1, const cuco::pair<T1, T2>> : tuple_element<1, cuco::pair<T1, T2>> {
};

template <typename T1, typename T2>
struct tuple_element<0, volatile cuco::pair<T1, T2>> : tuple_element<0, cuco::pair<T1, T2>> {
};

template <typename T1, typename T2>
struct tuple_element<1, volatile cuco::pair<T1, T2>> : tuple_element<1, cuco::pair<T1, T2>> {
};

template <typename T1, typename T2>
struct tuple_element<0, const volatile cuco::pair<T1, T2>> : tuple_element<0, cuco::pair<T1, T2>> {
};

template <typename T1, typename T2>
struct tuple_element<1, const volatile cuco::pair<T1, T2>> : tuple_element<1, cuco::pair<T1, T2>> {
};

template <std::size_t I, typename T1, typename T2>
__host__ __device__ constexpr auto get(cuco::pair<T1, T2>& p) ->
typename tuple_element<I, cuco::pair<T1, T2>>::type&
{
static_assert(I < 2);
if constexpr (I == 0) {
return p.first;
} else {
return p.second;
}
}

template <std::size_t I, typename T1, typename T2>
__host__ __device__ constexpr auto get(cuco::pair<T1, T2>&& p) ->
typename tuple_element<I, cuco::pair<T1, T2>>::type&&
{
static_assert(I < 2);
if constexpr (I == 0) {
return std::move(p.first);
} else {
return std::move(p.second);
}
}

template <std::size_t I, typename T1, typename T2>
__host__ __device__ constexpr auto get(cuco::pair<T1, T2> const& p) ->
typename tuple_element<I, cuco::pair<T1, T2>>::type const&
{
static_assert(I < 2);
if constexpr (I == 0) {
return p.first;
} else {
return p.second;
}
}

template <std::size_t I, typename T1, typename T2>
__host__ __device__ constexpr auto get(cuco::pair<T1, T2> const&& p) ->
typename tuple_element<I, cuco::pair<T1, T2>>::type const&&
{
static_assert(I < 2);
if constexpr (I == 0) {
return std::move(p.first);
} else {
return std::move(p.second);
}
}
11 changes: 6 additions & 5 deletions include/cuco/detail/static_map/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <cub/block/block_reduce.cuh>

#include <cuda/atomic>
#include <iterator>

#include <cooperative_groups.h>

Expand All @@ -39,22 +40,22 @@ namespace detail {
*
* @tparam CGSize Number of threads in each CG
* @tparam BlockSize Number of threads in each block
* @tparam InputIterator Device accessible input iterator whose `value_type` is
* @tparam InputIt Device accessible input iterator whose `value_type` is
* convertible to the `value_type` of the data structure
* @tparam Ref Type of non-owning device ref allowing access to storage
*
* @param first Beginning of the sequence of input elements
* @param n Number of input elements
* @param ref Non-owning container device ref used to access the slot storage
*/
template <int32_t CGSize, int32_t BlockSize, typename InputIterator, typename Ref>
__global__ void insert_or_assign(InputIterator first, cuco::detail::index_type n, Ref ref)
template <int32_t CGSize, int32_t BlockSize, typename InputIt, typename Ref>
__global__ void insert_or_assign(InputIt first, cuco::detail::index_type n, Ref ref)
{
auto const loop_stride = cuco::detail::grid_stride() / CGSize;
auto idx = cuco::detail::global_thread_id() / CGSize;

while (idx < n) {
typename Ref::value_type const insert_pair{*(first + idx)};
typename std::iterator_traits<InputIt>::value_type const& insert_pair = *(first + idx);
if constexpr (CGSize == 1) {
ref.insert_or_assign(insert_pair);
} else {
Expand Down Expand Up @@ -100,7 +101,7 @@ __global__ void find(InputIt first, cuco::detail::index_type n, OutputIt output_

while (idx - thread_idx < n) { // the whole thread block falls into the same iteration
if (idx < n) {
auto const key = *(first + idx);
typename std::iterator_traits<InputIt>::value_type const& key = *(first + idx);
if constexpr (CGSize == 1) {
auto const found = ref.find(key);
/*
Expand Down
38 changes: 30 additions & 8 deletions include/cuco/detail/static_map/static_map_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,14 @@ class operator_impl<
/**
* @brief Inserts an element.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
*
* @param value The element to insert
*
* @return True if the given element is successfully inserted
*/
__device__ bool insert(value_type const& value) noexcept
template <typename Value>
__device__ bool insert(Value const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
return ref_.impl_.insert(value);
Expand All @@ -192,12 +196,16 @@ class operator_impl<
/**
* @brief Inserts an element.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
*
* @param group The Cooperative Group used to perform group insert
* @param value The element to insert
*
* @return True if the given element is successfully inserted
*/
template <typename Value>
__device__ bool insert(cooperative_groups::thread_block_tile<cg_size> const& group,
value_type const& value) noexcept
Value const& value) noexcept
{
auto& ref_ = static_cast<ref_type&>(*this);
return ref_.impl_.insert(group, value);
Expand Down Expand Up @@ -230,9 +238,12 @@ class operator_impl<
* @brief Inserts a key-value pair `{k, v}` if it's not present in the map. Otherwise, assigns `v`
* to the mapped_type corresponding to the key `k`.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
*
* @param value The element to insert
*/
__device__ void insert_or_assign(value_type const& value) noexcept
template <typename Value>
__device__ void insert_or_assign(Value const& value) noexcept
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");

Expand Down Expand Up @@ -275,11 +286,14 @@ class operator_impl<
* @brief Inserts a key-value pair `{k, v}` if it's not present in the map. Otherwise, assigns `v`
* to the mapped_type corresponding to the key `k`.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
*
* @param group The Cooperative Group used to perform group insert
* @param value The element to insert
*/
template <typename Value>
__device__ void insert_or_assign(cooperative_groups::thread_block_tile<cg_size> const& group,
value_type const& value) noexcept
Value const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);

Expand Down Expand Up @@ -350,13 +364,15 @@ class operator_impl<
* @brief Inserts a key-value pair `{k, v}` if it's not present in the map. Otherwise, assigns `v`
* to the mapped_type corresponding to the key `k`.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
*
* @param group The Cooperative Group used to perform group insert
* @param value The element to insert
*
* @return Returns `true` if the given `value` is inserted or `value` has a match in the map.
*/
__device__ constexpr bool attempt_insert_or_assign(value_type* slot,
value_type const& value) noexcept
template <typename Value>
__device__ constexpr bool attempt_insert_or_assign(value_type* slot, Value const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
auto const expected_key = ref_.impl_.empty_slot_sentinel().first;
Expand Down Expand Up @@ -430,12 +446,15 @@ class operator_impl<
* element that prevented the insertion) and a `bool` denoting whether the insertion took place or
* not.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
*
* @param value The element to insert
*
* @return a pair consisting of an iterator to the element and a bool indicating whether the
* insertion is successful or not.
*/
__device__ thrust::pair<iterator, bool> insert_and_find(value_type const& value) noexcept
template <typename Value>
__device__ thrust::pair<iterator, bool> insert_and_find(Value const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
return ref_.impl_.insert_and_find(value);
Expand All @@ -448,14 +467,17 @@ class operator_impl<
* element that prevented the insertion) and a `bool` denoting whether the insertion took place or
* not.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
*
* @param group The Cooperative Group used to perform group insert_and_find
* @param value The element to insert
*
* @return a pair consisting of an iterator to the element and a bool indicating whether the
* insertion is successful or not.
*/
template <typename Value>
__device__ thrust::pair<iterator, bool> insert_and_find(
cooperative_groups::thread_block_tile<cg_size> const& group, value_type const& value) noexcept
cooperative_groups::thread_block_tile<cg_size> const& group, Value const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
return ref_.impl_.insert_and_find(group, value);
Expand Down
Loading

0 comments on commit e37c12d

Please sign in to comment.