support for torch.bfloat16 in radius ops
viktortnk committed Feb 4, 2024
1 parent ef79a92 commit c1f12a2
Showing 3 changed files with 13 additions and 10 deletions.
16 changes: 9 additions & 7 deletions csrc/cuda/radius_cuda.cu
@@ -81,13 +81,15 @@ torch::Tensor radius_cuda(const torch::Tensor x, const torch::Tensor y,

auto stream = at::cuda::getCurrentCUDAStream();
auto scalar_type = x.scalar_type();
-  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
-    radius_kernel<scalar_t><<<BLOCKS, THREADS, 0, stream>>>(
-        x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
-        ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
-        row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), r * r, x.size(0),
-        y.size(0), x.size(1), ptr_x.value().numel() - 1, max_num_neighbors);
-  });
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "_", [&] {
+        radius_kernel<scalar_t><<<BLOCKS, THREADS, 0, stream>>>(
+            x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
+            ptr_x.value().data_ptr<int64_t>(),
+            ptr_y.value().data_ptr<int64_t>(), row.data_ptr<int64_t>(),
+            col.data_ptr<int64_t>(), r * r, x.size(0), y.size(0), x.size(1),
+            ptr_x.value().numel() - 1, max_num_neighbors);
+      });

auto mask = row != -1;
return torch::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
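For context, AT_DISPATCH_FLOATING_TYPES_AND2 instantiates the lambda (and hence radius_kernel) for float and double plus the two extra scalar types listed, so torch.bfloat16 inputs now reach the CUDA kernel. A minimal usage sketch from the Python side, assuming a CUDA device is available; the tensor sizes and the radius value are arbitrary illustrations, not part of this commit:

import torch
from torch_cluster import radius

if torch.cuda.is_available():
    # bfloat16 point clouds on the GPU; shapes and values are arbitrary.
    x = torch.rand(100, 2, dtype=torch.bfloat16, device='cuda')
    y = torch.rand(50, 2, dtype=torch.bfloat16, device='cuda')
    # For every point in y, collect the points in x within radius 0.5.
    edge_index = radius(x, y, 0.5)
    # edge_index is a [2, num_edges] LongTensor pairing indices of y
    # with their neighbors in x.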
6 changes: 3 additions & 3 deletions test/test_radius.py
@@ -4,14 +4,14 @@
import scipy.spatial
import torch
from torch_cluster import radius, radius_graph
-from torch_cluster.testing import devices, grad_dtypes, tensor
+from torch_cluster.testing import devices, floating_dtypes, tensor


def to_set(edge_index):
return set([(i, j) for i, j in edge_index.t().tolist()])


-@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
+@pytest.mark.parametrize('dtype,device', product(floating_dtypes, devices))
def test_radius(dtype, device):
x = tensor([
[-1, -1],
@@ -52,7 +52,7 @@ def test_radius(dtype, device):
(1, 6)])


-@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
+@pytest.mark.parametrize('dtype,device', product(floating_dtypes, devices))
def test_radius_graph(dtype, device):
x = tensor([
[-1, -1],
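radius_graph applies the same operation within a single point cloud and returns the intra-set edges; the updated tests now also exercise it with torch.bfloat16 on every available device. A short sketch under the same caveat that the sizes and radius are illustrative assumptions:

import torch
from torch_cluster import radius_graph

# A small bfloat16 point cloud; the tests run this on CPU and, when
# available, CUDA.
x = torch.rand(32, 3, dtype=torch.bfloat16)
edge_index = radius_graph(x, r=0.25, loop=False)  # no self-loops
# edge_index has shape [2, num_edges] with source/target node indices.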
1 change: 1 addition & 0 deletions torch_cluster/testing.py
@@ -7,6 +7,7 @@
torch.long
]
grad_dtypes = [torch.half, torch.float, torch.double]
+floating_dtypes = grad_dtypes + [torch.bfloat16]

devices = [torch.device('cpu')]
if torch.cuda.is_available():
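With this addition, floating_dtypes evaluates to [torch.half, torch.float, torch.double, torch.bfloat16], and itertools.product pairs it with devices so each test runs once per dtype/device combination, exactly as in the test changes above. A minimal sketch of that pattern; the test name and body are illustrative assumptions:

from itertools import product

import pytest
from torch_cluster.testing import devices, floating_dtypes, tensor

@pytest.mark.parametrize('dtype,device', product(floating_dtypes, devices))
def test_dtype_roundtrip(dtype, device):
    # Runs once per (dtype, device) pair, now including torch.bfloat16.
    x = tensor([[0, 0], [1, 1]], dtype, device)
    assert x.dtype == dtype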
