Skip to content

Commit

Permalink
Improve detail namespace usage, fix documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
lowener committed Nov 5, 2024
1 parent b870aac commit 288ea20
Show file tree
Hide file tree
Showing 19 changed files with 101 additions and 75 deletions.
20 changes: 19 additions & 1 deletion cpp/include/raft/sparse/linalg/detail/cusparse_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,25 @@ namespace linalg {
namespace detail {

/**
* @brief create a cuSparse dense descriptor
* @brief create a cuSparse dense descriptor for a vector
* @tparam ValueType Data type of vector_view (float/double)
* @tparam IndexType Type of vector_view
* @param[in] vector_view input raft::device_vector_view
* @returns dense vector descriptor to be used by cuSparse API
*/
template <typename ValueType, typename IndexType>
cusparseDnVecDescr_t create_descriptor(raft::device_vector_view<ValueType, IndexType> vector_view)
{
cusparseDnVecDescr_t descr;
RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednvec(
&descr,
vector_view.extent(0),
const_cast<std::remove_const_t<ValueType>*>(vector_view.data_handle())));
return descr;
}

/**
* @brief create a cuSparse dense descriptor for a matrix
* @tparam ValueType Data type of dense_view (float/double)
* @tparam IndexType Type of dense_view
* @tparam LayoutPolicy layout of dense_view
Expand Down
97 changes: 36 additions & 61 deletions cpp/include/raft/sparse/solver/detail/lanczos.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
#include <raft/matrix/triangular.cuh>
#include <raft/random/rng.cuh>
#include <raft/sparse/detail/cusparse_wrappers.h>
#include <raft/sparse/linalg/detail/cusparse_utils.hpp>
#include <raft/sparse/solver/lanczos_types.hpp>
#include <raft/spectral/detail/lapack.hpp>
#include <raft/spectral/detail/warn_dbg.hpp>
Expand Down Expand Up @@ -1553,26 +1554,18 @@ void lanczos_aux(raft::resources const& handle,
{
auto stream = resource::get_cuda_stream(handle);

auto A_structure = A.structure_view();
IndexTypeT n = A_structure.get_n_rows();
IndexTypeT n = A.structure_view().get_n_rows();
auto v_vector = raft::make_device_vector_view<const ValueTypeT>(v.data_handle(), n);
auto u_vector = raft::make_device_vector_view<const ValueTypeT>(u.data_handle(), n);

raft::copy(
v.data_handle(), V.data_handle() + start_idx * V.stride(0), n, stream); // V(start_idx, 0)

auto cusparse_h = resource::get_cusparse_handle(handle);
cusparseSpMatDescr_t cusparse_A;
raft::sparse::detail::cusparsecreatecsr(&cusparse_A,
A_structure.get_n_rows(),
A_structure.get_n_cols(),
A_structure.get_nnz(),
const_cast<IndexTypeT*>(A_structure.get_indptr().data()),
const_cast<IndexTypeT*>(A_structure.get_indices().data()),
const_cast<ValueTypeT*>(A.get_elements().data()));

cusparseDnVecDescr_t cusparse_v;
cusparseDnVecDescr_t cusparse_u;
raft::sparse::detail::cusparsecreatednvec(&cusparse_v, n, v.data_handle());
raft::sparse::detail::cusparsecreatednvec(&cusparse_u, n, u.data_handle());
auto cusparse_h = resource::get_cusparse_handle(handle);
cusparseSpMatDescr_t cusparse_A = raft::sparse::linalg::detail::create_descriptor(A);

cusparseDnVecDescr_t cusparse_v = raft::sparse::linalg::detail::create_descriptor(v_vector);
cusparseDnVecDescr_t cusparse_u = raft::sparse::linalg::detail::create_descriptor(u_vector);

ValueTypeT one = 1;
ValueTypeT zero = 0;
Expand Down Expand Up @@ -1603,8 +1596,6 @@ void lanczos_aux(raft::resources const& handle,

auto alpha_i =
raft::make_device_scalar_view(alpha.data_handle() + i * alpha.stride(1)); // alpha(0, i)
auto v_vector = raft::make_device_vector_view<const ValueTypeT>(v.data_handle(), n);
auto u_vector = raft::make_device_vector_view<const ValueTypeT>(u.data_handle(), n);
raft::linalg::dot(handle, v_vector, u_vector, alpha_i);

raft::matrix::fill(handle, vv, zero);
Expand Down Expand Up @@ -1706,17 +1697,17 @@ auto lanczos_smallest(
ValueTypeT* v0,
uint64_t seed) -> int
{
auto A_structure = A.structure_view();
int n = A_structure.get_n_rows();
int ncv = restartIter;
auto stream = resource::get_cuda_stream(handle);
int n = A.structure_view().get_n_rows();
int ncv = restartIter;
auto stream = resource::get_cuda_stream(handle);

auto V = raft::make_device_matrix<ValueTypeT, uint32_t, raft::row_major>(handle, ncv, n);
auto V_0_view =
raft::make_device_matrix_view<ValueTypeT, uint32_t>(V.data_handle(), 1, n); // First Row V[0]
auto v0_view = raft::make_device_matrix_view<const ValueTypeT, uint32_t>(v0, 1, n);

auto u = raft::make_device_matrix<ValueTypeT, uint32_t, raft::row_major>(handle, 1, n);
auto u = raft::make_device_matrix<ValueTypeT, uint32_t, raft::row_major>(handle, 1, n);
auto u_vector = raft::make_device_vector_view<ValueTypeT, uint32_t>(u.data_handle(), n);
raft::copy(u.data_handle(), v0, n, stream);

auto cublas_h = resource::get_cublas_handle(handle);
Expand Down Expand Up @@ -1835,7 +1826,7 @@ auto lanczos_smallest(
ValueTypeT one = 1;
ValueTypeT mone = -1;

// Using raft::linalg::gemv leads to Reason=7:CUBLAS_STATUS_INVALID_VALUE
// Using raft::linalg::gemv leads to Reason=7:CUBLAS_STATUS_INVALID_VALUE (issue raft#2484)
raft::linalg::detail::cublasgemv(cublas_h,
CUBLAS_OP_T,
nEigVecs,
Expand Down Expand Up @@ -1866,41 +1857,28 @@ auto lanczos_smallest(

auto V_0_view =
raft::make_device_matrix_view<ValueTypeT>(V.data_handle() + (nEigVecs * n), 1, n);

auto unrm = raft::make_device_vector<ValueTypeT, uint32_t>(handle, 1);
auto input = raft::make_device_matrix_view<const ValueTypeT, uint32_t>(u.data_handle(), 1, n);
auto V_0_view_vector =
raft::make_device_vector_view<ValueTypeT, uint32_t>(V_0_view.data_handle(), n);
auto unrm = raft::make_device_vector<ValueTypeT, uint32_t>(handle, 1);
raft::linalg::norm(handle,
input,
raft::make_const_mdspan(u.view()),
unrm.view(),
raft::linalg::L2Norm,
raft::linalg::Apply::ALONG_ROWS,
raft::sqrt_op());

auto u_vector_const = raft::make_device_vector_view<const ValueTypeT>(u.data_handle(), n);

raft::linalg::unary_op(
handle, u_vector_const, V_0_view, [device_scalar = unrm.data_handle()] __device__(auto y) {
return y / *device_scalar;
});
handle,
raft::make_const_mdspan(u_vector),
V_0_view,
[device_scalar = unrm.data_handle()] __device__(auto y) { return y / *device_scalar; });

auto cusparse_h = resource::get_cusparse_handle(handle);
cusparseSpMatDescr_t cusparse_A = raft::sparse::linalg::detail::create_descriptor(A);

auto cusparse_h = resource::get_cusparse_handle(handle);
cusparseSpMatDescr_t cusparse_A;
// input_config.a_indptr = const_cast<IndexType*>(x_structure.get_indptr().data());
// input_config.a_indices = const_cast<IndexType*>(x_structure.get_indices().data());
// input_config.a_data = const_cast<ElementType*>(x.get_elements().data());
raft::sparse::detail::cusparsecreatecsr(
&cusparse_A,
A_structure.get_n_rows(),
A_structure.get_n_cols(),
A_structure.get_nnz(),
const_cast<IndexTypeT*>(A_structure.get_indptr().data()),
const_cast<IndexTypeT*>(A_structure.get_indices().data()),
const_cast<ValueTypeT*>(A.get_elements().data()));

cusparseDnVecDescr_t cusparse_v;
cusparseDnVecDescr_t cusparse_u;
raft::sparse::detail::cusparsecreatednvec(&cusparse_v, n, V_0_view.data_handle());
raft::sparse::detail::cusparsecreatednvec(&cusparse_u, n, u.data_handle());
cusparseDnVecDescr_t cusparse_v =
raft::sparse::linalg::detail::create_descriptor(V_0_view_vector);
cusparseDnVecDescr_t cusparse_u = raft::sparse::linalg::detail::create_descriptor(u_vector);

ValueTypeT zero = 0;
size_t bufferSize;
Expand Down Expand Up @@ -1928,22 +1906,20 @@ auto lanczos_smallest(
stream);

auto alpha_k = raft::make_device_scalar_view<ValueTypeT>(alpha.data_handle() + nEigVecs);
auto V_0_view_vector = raft::make_device_vector_view<ValueTypeT>(V_0_view.data_handle(), n);
auto u_view_vector = raft::make_device_vector_view<ValueTypeT>(u.data_handle(), n);

raft::linalg::dot(
handle, make_const_mdspan(V_0_view_vector), make_const_mdspan(u_view_vector), alpha_k);
handle, make_const_mdspan(V_0_view_vector), make_const_mdspan(u_vector), alpha_k);

raft::linalg::binary_op(handle,
make_const_mdspan(u_view_vector),
make_const_mdspan(u_vector),
make_const_mdspan(V_0_view_vector),
u_view_vector,
u_vector,
[device_scalar_ptr = alpha_k.data_handle()] __device__(
ValueTypeT u_element, ValueTypeT V_0_element) {
return u_element - (*device_scalar_ptr) * V_0_element;
});

auto temp = raft::make_device_vector<ValueTypeT>(handle, n);
auto temp = raft::make_device_vector<ValueTypeT, uint32_t>(handle, n);

auto V_k = raft::make_device_matrix_view<ValueTypeT, uint32_t, raft::row_major>(
V.data_handle(), nEigVecs, n);
Expand Down Expand Up @@ -1994,9 +1970,9 @@ auto lanczos_smallest(

auto one_scalar = raft::make_device_scalar<ValueTypeT>(handle, 1);
raft::linalg::binary_op(handle,
make_const_mdspan(u_view_vector),
make_const_mdspan(u_vector),
make_const_mdspan(temp.view()),
u_view_vector,
u_vector,
[device_scalar_ptr = one_scalar.data_handle()] __device__(
ValueTypeT u_element, ValueTypeT temp_element) {
return u_element - (*device_scalar_ptr) * temp_element;
Expand All @@ -2013,11 +1989,10 @@ auto lanczos_smallest(

auto V_kplus1 =
raft::make_device_vector_view<ValueTypeT>(V.data_handle() + V.stride(0) * (nEigVecs + 1), n);
auto u_vector = raft::make_device_vector_view<const ValueTypeT>(u.data_handle(), n);

raft::linalg::unary_op(
handle,
u_vector,
make_const_mdspan(u_vector),
V_kplus1,
[device_scalar = (beta.data_handle() + beta.stride(1) * nEigVecs)] __device__(auto y) {
return y / *device_scalar;
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/sparse/solver/lanczos_types.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024-2024, NVIDIA CORPORATION.
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft_runtime/solver/lanczos.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024-2024, NVIDIA CORPORATION.
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/raft_runtime/solver/lanczos_solver.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024-2024, NVIDIA CORPORATION.
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/raft_runtime/solver/lanczos_solver_int64_double.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024-2024, NVIDIA CORPORATION.
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/raft_runtime/solver/lanczos_solver_int64_float.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024-2024, NVIDIA CORPORATION.
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/raft_runtime/solver/lanczos_solver_int_double.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024-2024, NVIDIA CORPORATION.
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/raft_runtime/solver/lanczos_solver_int_float.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024-2024, NVIDIA CORPORATION.
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion python/pylibraft/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ add_subdirectory(pylibraft/distance)
add_subdirectory(pylibraft/matrix)
add_subdirectory(pylibraft/neighbors)
add_subdirectory(pylibraft/random)
add_subdirectory(pylibraft/solver)
add_subdirectory(pylibraft/sparse)
add_subdirectory(pylibraft/cluster)

if(DEFINED cython_lib_dir)
Expand Down
15 changes: 15 additions & 0 deletions python/pylibraft/pylibraft/sparse/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# =============================================================================
# Copyright (c) 2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.
# =============================================================================

add_subdirectory(linalg)
18 changes: 18 additions & 0 deletions python/pylibraft/pylibraft/sparse/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pylibraft.sparse import linalg

__all__ = ["linalg"]
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# =============================================================================
# Copyright (c) 2022-2024, NVIDIA CORPORATION.
# Copyright (c) 2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
Expand All @@ -23,5 +23,5 @@ set(linked_libraries raft::raft raft::compiled)
rapids_cython_create_modules(
CXX
SOURCE_FILES "${cython_sources}"
LINKED_LIBRARIES "${linked_libraries}" ASSOCIATED_TARGETS raft MODULE_PREFIX solver_
LINKED_LIBRARIES "${linked_libraries}" ASSOCIATED_TARGETS raft MODULE_PREFIX sparse_
)
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION.
# Copyright (c) 2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2024-2024, NVIDIA CORPORATION.
# Copyright (c) 2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import pytest
from cupyx.scipy import sparse

from pylibraft.solver import eigsh
from pylibraft.sparse.linalg import eigsh


def shaped_random(
Expand Down

0 comments on commit 288ea20

Please sign in to comment.