Skip to content

Commit

Permalink
update (#213)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Apr 15, 2024
1 parent f2d9919 commit 49e0fe8
Show file tree
Hide file tree
Showing 7 changed files with 7 additions and 9 deletions.
2 changes: 1 addition & 1 deletion csrc/cuda/fps_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
CHECK_CUDA(ptr);
CHECK_CUDA(ratio);
CHECK_INPUT(ptr.dim() == 1);
cudaSetDevice(src.get_device());
c10::cuda::MaybeSetDevice(src.get_device());

src = src.view({src.size(0), -1}).contiguous();
ptr = ptr.contiguous();
Expand Down
2 changes: 1 addition & 1 deletion csrc/cuda/graclus_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ torch::Tensor graclus_cuda(torch::Tensor rowptr, torch::Tensor col,
CHECK_INPUT(optional_weight.value().dim() == 1);
CHECK_INPUT(optional_weight.value().numel() == col.numel());
}
cudaSetDevice(rowptr.get_device());
c10::cuda::MaybeSetDevice(rowptr.get_device());

int64_t num_nodes = rowptr.numel() - 1;
auto out = torch::full(num_nodes, -1, rowptr.options());
Expand Down
2 changes: 1 addition & 1 deletion csrc/cuda/grid_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ torch::Tensor grid_cuda(torch::Tensor pos, torch::Tensor size,
torch::optional<torch::Tensor> optional_end) {
CHECK_CUDA(pos);
CHECK_CUDA(size);
cudaSetDevice(pos.get_device());
c10::cuda::MaybeSetDevice(pos.get_device());

if (optional_start.has_value())
CHECK_CUDA(optional_start.value());
Expand Down
2 changes: 1 addition & 1 deletion csrc/cuda/knn_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,

CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel());

cudaSetDevice(x.get_device());
c10::cuda::MaybeSetDevice(x.get_device());

auto row = torch::empty({y.size(0) * k}, ptr_y.value().options());
auto col = torch::full(y.size(0) * k, -1, ptr_y.value().options());
Expand Down
2 changes: 1 addition & 1 deletion csrc/cuda/nearest_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ torch::Tensor nearest_cuda(torch::Tensor x, torch::Tensor y,
CHECK_CUDA(y);
CHECK_CUDA(ptr_x);
CHECK_CUDA(ptr_y);
cudaSetDevice(x.get_device());
c10::cuda::MaybeSetDevice(x.get_device());

x = x.view({x.size(0), -1}).contiguous();
y = y.view({y.size(0), -1}).contiguous();
Expand Down
4 changes: 1 addition & 3 deletions csrc/cuda/radius_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ torch::Tensor radius_cuda(const torch::Tensor x, const torch::Tensor y,
CHECK_INPUT(y.dim() == 2);
CHECK_INPUT(x.size(1) == y.size(1));

cudaSetDevice(x.get_device());
c10::cuda::MaybeSetDevice(x.get_device());

if (ptr_x.has_value()) {
CHECK_CUDA(ptr_x.value());
Expand All @@ -70,8 +70,6 @@ torch::Tensor radius_cuda(const torch::Tensor x, const torch::Tensor y,

CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel());

cudaSetDevice(x.get_device());

auto row =
torch::full(y.size(0) * max_num_neighbors, -1, ptr_y.value().options());
auto col =
Expand Down
2 changes: 1 addition & 1 deletion csrc/cuda/rw_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ random_walk_cuda(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
CHECK_CUDA(start);
cudaSetDevice(rowptr.get_device());
c10::cuda::MaybeSetDevice(rowptr.get_device());

CHECK_INPUT(rowptr.dim() == 1);
CHECK_INPUT(col.dim() == 1);
Expand Down

0 comments on commit 49e0fe8

Please sign in to comment.