Commit
refactor _to_sparse_semi_structured()
petrex committed Oct 3, 2024
1 parent ad02e60 commit b90a8d3
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu
@@ -911,20 +911,20 @@ _to_sparse_semi_structured(const Tensor& dense) {
#if defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
AT_ERROR(__func__, " : CUTLASS not supported");
return std::make_tuple(Tensor{}, Tensor{});
#else
#elif defined(USE_ROCM)
// Check dimensions of the dense matrix.
TORCH_CHECK(dense.dim() == 2,
__func__, " : Expected dense argument to be 2D tensor, got ",
dense.dim(), " dims");

#if defined(USE_ROCM)
// Generate sparse tensor using cuSPARSELt compression
auto sparse = torch._cslt_compress(dense);

// Extract the compressed data and metadata
auto compressed_data = sparse.values();
auto metadata = sparse.indices();

// Print to console that we are using hipSPARSELt
printf("Using hipSPARSELt for sparse semi-structured conversion\n");
return std::make_tuple(compressed_data, metadata);

#else // Determine PyTorch datatype for the metadata matrix.
@@ -1052,7 +1052,6 @@ _to_sparse_semi_structured(const Tensor& dense) {
return std::make_tuple(sparse_cpu.to(dense.device()),
meta_reordered_cpu.to(dense.device()));
#endif
#endif
}

} // namespace at::native
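
For context, the ROCm branch in this diff leans on the cuSPARSELt/hipSPARSELt compression operator that ATen exposes. Below is a minimal sketch of that idea, not the committed code: the helper name rocm_semi_structured_sketch is hypothetical, and it assumes a ROCm build where at::_cslt_compress (the C++ counterpart of torch._cslt_compress) is available. That operator returns a single packed tensor, so the metadata slot is left empty here.

// Minimal sketch (hypothetical helper, not the committed code): compress a
// dense 2-D tensor with the cuSPARSELt/hipSPARSELt binding exposed by ATen.
#include <ATen/ATen.h>
#include <c10/util/Exception.h>
#include <tuple>

std::tuple<at::Tensor, at::Tensor> rocm_semi_structured_sketch(
    const at::Tensor& dense) {
  // Same shape check as in _to_sparse_semi_structured().
  TORCH_CHECK(dense.dim() == 2,
              __func__, " : Expected dense argument to be 2D tensor, got ",
              dense.dim(), " dims");

  // at::_cslt_compress() (assumed available in this build) returns one packed
  // tensor holding the compressed values and the sparsity metadata together.
  auto compressed = at::_cslt_compress(dense);

  // Splitting the packed buffer into separate data/metadata tensors is
  // backend-specific, so this sketch returns an empty metadata tensor.
  return std::make_tuple(compressed, at::Tensor{});
}

The two-tensor return shape is kept only to mirror _to_sparse_semi_structured(); how the packed buffer maps onto separate data and metadata tensors depends on the sparse backend.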
