diff --git a/csrc/cuda/fps_cuda.cu b/csrc/cuda/fps_cuda.cu index dd3671a..38195fc 100644 --- a/csrc/cuda/fps_cuda.cu +++ b/csrc/cuda/fps_cuda.cu @@ -71,7 +71,7 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, CHECK_CUDA(ptr); CHECK_CUDA(ratio); CHECK_INPUT(ptr.dim() == 1); - cudaSetDevice(src.get_device()); + c10::cuda::MaybeSetDevice(src.get_device()); src = src.view({src.size(0), -1}).contiguous(); ptr = ptr.contiguous(); diff --git a/csrc/cuda/graclus_cuda.cu b/csrc/cuda/graclus_cuda.cu index 61e7d70..3bb118b 100644 --- a/csrc/cuda/graclus_cuda.cu +++ b/csrc/cuda/graclus_cuda.cu @@ -223,7 +223,7 @@ torch::Tensor graclus_cuda(torch::Tensor rowptr, torch::Tensor col, CHECK_INPUT(optional_weight.value().dim() == 1); CHECK_INPUT(optional_weight.value().numel() == col.numel()); } - cudaSetDevice(rowptr.get_device()); + c10::cuda::MaybeSetDevice(rowptr.get_device()); int64_t num_nodes = rowptr.numel() - 1; auto out = torch::full(num_nodes, -1, rowptr.options()); diff --git a/csrc/cuda/grid_cuda.cu b/csrc/cuda/grid_cuda.cu index 8696b9f..64037bd 100644 --- a/csrc/cuda/grid_cuda.cu +++ b/csrc/cuda/grid_cuda.cu @@ -29,7 +29,7 @@ torch::Tensor grid_cuda(torch::Tensor pos, torch::Tensor size, torch::optional optional_end) { CHECK_CUDA(pos); CHECK_CUDA(size); - cudaSetDevice(pos.get_device()); + c10::cuda::MaybeSetDevice(pos.get_device()); if (optional_start.has_value()) CHECK_CUDA(optional_start.value()); diff --git a/csrc/cuda/knn_cuda.cu b/csrc/cuda/knn_cuda.cu index caa5c96..c4dac2a 100644 --- a/csrc/cuda/knn_cuda.cu +++ b/csrc/cuda/knn_cuda.cu @@ -113,7 +113,7 @@ torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y, CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel()); - cudaSetDevice(x.get_device()); + c10::cuda::MaybeSetDevice(x.get_device()); auto row = torch::empty({y.size(0) * k}, ptr_y.value().options()); auto col = torch::full(y.size(0) * k, -1, ptr_y.value().options()); diff --git a/csrc/cuda/nearest_cuda.cu b/csrc/cuda/nearest_cuda.cu index 81eef92..7a3458e 100644 --- a/csrc/cuda/nearest_cuda.cu +++ b/csrc/cuda/nearest_cuda.cu @@ -71,7 +71,7 @@ torch::Tensor nearest_cuda(torch::Tensor x, torch::Tensor y, CHECK_CUDA(y); CHECK_CUDA(ptr_x); CHECK_CUDA(ptr_y); - cudaSetDevice(x.get_device()); + c10::cuda::MaybeSetDevice(x.get_device()); x = x.view({x.size(0), -1}).contiguous(); y = y.view({y.size(0), -1}).contiguous(); diff --git a/csrc/cuda/radius_cuda.cu b/csrc/cuda/radius_cuda.cu index 7340d70..7efb2ff 100644 --- a/csrc/cuda/radius_cuda.cu +++ b/csrc/cuda/radius_cuda.cu @@ -52,7 +52,7 @@ torch::Tensor radius_cuda(const torch::Tensor x, const torch::Tensor y, CHECK_INPUT(y.dim() == 2); CHECK_INPUT(x.size(1) == y.size(1)); - cudaSetDevice(x.get_device()); + c10::cuda::MaybeSetDevice(x.get_device()); if (ptr_x.has_value()) { CHECK_CUDA(ptr_x.value()); @@ -70,8 +70,6 @@ torch::Tensor radius_cuda(const torch::Tensor x, const torch::Tensor y, CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel()); - cudaSetDevice(x.get_device()); - auto row = torch::full(y.size(0) * max_num_neighbors, -1, ptr_y.value().options()); auto col = diff --git a/csrc/cuda/rw_cuda.cu b/csrc/cuda/rw_cuda.cu index 763b861..7ecd4fc 100644 --- a/csrc/cuda/rw_cuda.cu +++ b/csrc/cuda/rw_cuda.cu @@ -121,7 +121,7 @@ random_walk_cuda(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start, CHECK_CUDA(rowptr); CHECK_CUDA(col); CHECK_CUDA(start); - cudaSetDevice(rowptr.get_device()); + c10::cuda::MaybeSetDevice(rowptr.get_device()); CHECK_INPUT(rowptr.dim() == 1); CHECK_INPUT(col.dim() == 1);