Skip to content

Commit

Permalink
update renumber_sampled_edgelist to use the existing expand_sparse_of…
Browse files Browse the repository at this point in the history
…fsets
  • Loading branch information
seunghwak committed Jul 21, 2023
1 parent 9a8abee commit d4d4c78
Showing 1 changed file with 3 additions and 24 deletions.
27 changes: 3 additions & 24 deletions cpp/src/sampling/renumber_sampled_edgelist_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,7 @@ compute_renumber_map(raft::handle_t const& handle,
std::optional<rmm::device_uvector<label_index_t>> edgelist_label_indices{std::nullopt};
if (label_offsets) {
edgelist_label_indices =
rmm::device_uvector<label_index_t>(edgelist_srcs.size(), handle.get_stream());
thrust::transform(
handle.get_thrust_policy(),
thrust::make_counting_iterator(size_t{0}),
thrust::make_counting_iterator(edgelist_srcs.size()),
(*edgelist_label_indices).begin(),
[offsets = raft::device_span<size_t const>(
(*label_offsets).data() + 1, (*label_offsets).size() - 1)] __device__(size_t i) {
return static_cast<size_t>(thrust::distance(
offsets.begin(), thrust::upper_bound(thrust::seq, offsets.begin(), offsets.end(), i)));
});
detail::expand_sparse_offsets(*label_offsets, label_index_t{0}, handle.get_stream());
}

std::optional<rmm::device_uvector<label_index_t>> unique_label_src_pair_label_indices{
Expand Down Expand Up @@ -635,19 +625,8 @@ renumber_sampled_edgelist(
new_vertices.shrink_to_fit(handle.get_stream());
d_tmp_storage.shrink_to_fit(handle.get_stream());

rmm::device_uvector<label_index_t> edgelist_label_indices(edgelist_srcs.size(),
handle.get_stream());
thrust::transform(
handle.get_thrust_policy(),
thrust::make_counting_iterator(size_t{0}),
thrust::make_counting_iterator(edgelist_srcs.size()),
edgelist_label_indices.begin(),
[offsets = raft::device_span<size_t const>(
std::get<1>(*label_offsets).data() + 1,
std::get<1>(*label_offsets).size() - 1)] __device__(size_t i) {
return static_cast<size_t>(thrust::distance(
offsets.begin(), thrust::upper_bound(thrust::seq, offsets.begin(), offsets.end(), i)));
});
auto edgelist_label_indices = detail::expand_sparse_offsets(
std::get<1>(*label_offsets), label_index_t{0}, handle.get_stream());

auto pair_first =
thrust::make_zip_iterator(edgelist_srcs.begin(), edgelist_label_indices.begin());
Expand Down

0 comments on commit d4d4c78

Please sign in to comment.