[mlir][sparse] fix bugs when generating sparse conv_3d kernels. (llvm#7…
PeimingLiu authored Dec 6, 2023
1 parent 861600f commit 78e2b74
Showing 2 changed files with 66 additions and 15 deletions.
19 changes: 5 additions & 14 deletions mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -1454,28 +1454,19 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
void LoopEmitter::forwardsReducedSliceLevelTreeIt(OpBuilder &builder,
Location loc, TensorId tid,
Level rootLvl, Value fcnt) {

auto stt = getSparseTensorType(tensors[tid]);

// Finds a [Lvl, leafLvl) range, and all levels in between are fully reduced
- // level (but not resolved). Since we forward an iterator at higher level of
- // the tree, the subtree need to be pruned.
// sparse levels (but not resolved). Since we forward an iterator at higher
// level of the tree, the subtree needs to be pruned.
Level leafLvl = rootLvl + 1;
- while (leafLvl < stt.getLvlRank() && !dependentLvlMap[tid][leafLvl].empty() &&
-        depFullyReduced(tid, leafLvl)) {
while (leafLvl < stt.getLvlRank() && depFullyReduced(tid, leafLvl) &&
!stt.isDenseLvl(leafLvl)) {
leafLvl++;
}

Level curLvl = rootLvl + 1;
- // Prunes all denses subtree.
- while (curLvl < leafLvl && isDenseLT(lvlTypes[tid][curLvl])) {
-   // One step forward in parent level results in forwarding `slice.size` step
-   // in child dense level.
-   auto [size, stride] = sliceMeta[tid][curLvl].back();
-   assert(stride == 1 && "Not yet implemented");
-   fcnt = MULI(size, fcnt);
-   curLvl++;
- }

Value nxPosPtr = nullptr;
if (curLvl < leafLvl) {
assert(!isDenseLT(lvlTypes[tid][curLvl]));
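The change above is the core of the fix: the `leafLvl` search now also stops at the first dense level (`!stt.isDenseLvl(leafLvl)`), so the deleted loop that scaled `fcnt` by each dense level's size along the way is no longer needed. A minimal standalone C++ sketch of that search, assuming the level kinds and fully-reduced flags can be modeled as plain vectors (this is not the actual LoopEmitter API; `LevelType`, `fullyReduced`, and `findLeafLvl` below are simplified stand-ins):

#include <cassert>
#include <cstdint>
#include <vector>

// Simplified stand-in for the MLIR sparse-tensor level kinds; the real code
// queries SparseTensorType / lvlTypes instead.
enum class LevelType { Dense, Compressed };

// Returns the first level after rootLvl that is either dense or not fully
// reduced, mirroring the updated while-loop condition: the subtree
// [rootLvl + 1, leafLvl) then contains only fully reduced *sparse* levels.
uint64_t findLeafLvl(const std::vector<LevelType> &lvlTypes,
                     const std::vector<bool> &fullyReduced,
                     uint64_t rootLvl) {
  uint64_t lvlRank = lvlTypes.size();
  uint64_t leafLvl = rootLvl + 1;
  while (leafLvl < lvlRank && fullyReduced[leafLvl] &&
         lvlTypes[leafLvl] != LevelType::Dense) {
    leafLvl++;
  }
  return leafLvl;
}

int main() {
  // A DCC-like layout: level 0 dense, levels 1 and 2 compressed.
  std::vector<LevelType> lvls = {LevelType::Dense, LevelType::Compressed,
                                 LevelType::Compressed};
  std::vector<bool> reduced = {true, true, true};
  // Forwarding an iterator at level 0: the search walks both sparse levels.
  assert(findLeafLvl(lvls, reduced, 0) == 3);
  lvls[1] = LevelType::Dense; // A DDC-like layout stops at the dense level.
  assert(findLeafLvl(lvls, reduced, 0) == 1);
  return 0;
}

In this sketch, findLeafLvl returns 3 for a DCC-style layout (all sparse levels consumed) and 1 for a DDC-style layout, matching the early stop at the dense level.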
62 changes: 61 additions & 1 deletion
@@ -38,10 +38,14 @@
map = (d0, d1, d2) -> (d0 : compressed, d1 : dense, d2 : compressed)
}>

- #DDC = #sparse_tensor.encoding<{
#DCC = #sparse_tensor.encoding<{
map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : compressed)
}>

#DDC = #sparse_tensor.encoding<{
map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed)
}>

// Creates and returns a 3-D buffer of size (%s1, %s2, %s3) filled with the value %f
func.func @alloc_3d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %f : f32) -> tensor<?x?x?xf32> {
%buf = tensor.empty(%s1, %s2, %s3) : tensor<?x?x?xf32>
@@ -74,6 +78,15 @@ func.func @conv_3d_CDC(%arg0: tensor<?x?x?xf32, #CDC>, %arg1: tensor<?x?x?xf32>)
return %ret : tensor<?x?x?xf32, #CDC>
}

func.func @conv_3d_DCC(%arg0: tensor<?x?x?xf32, #DCC>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32, #DCC> {
%c6 = arith.constant 6 : index
%s = tensor.empty(%c6, %c6, %c6) : tensor<?x?x?xf32, #DCC>
%ret = linalg.conv_3d
ins (%arg0, %arg1: tensor<?x?x?xf32, #DCC>, tensor<?x?x?xf32>)
outs (%s: tensor<?x?x?xf32, #DCC>) -> tensor<?x?x?xf32, #DCC>
return %ret : tensor<?x?x?xf32, #DCC>
}
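For reference when reading the CHECK blocks below: linalg.conv_3d (which takes no stride or dilation attributes) computes out[i][j][k] += in[i+r][j+s][k+t] * filter[r][s][t], and the sparse encodings only change how the input is stored, not the values produced. A small dense C++ sketch of that contraction, with illustrative all-ones data and sizes rather than this test's actual inputs (those are set up in the elided part of @entry):

#include <cstdio>
#include <vector>

// Naive "valid" 3-D convolution with the same semantics as linalg.conv_3d:
// out[i][j][k] += in[i+r][j+s][k+t] * filter[r][s][t].
int main() {
  const int N = 8, K = 3, M = N - K + 1; // M x M x M output.
  std::vector<float> in(N * N * N, 1.0f), filter(K * K * K, 1.0f),
      out(M * M * M, 0.0f);
  // Flat index into a d x d x d buffer.
  auto at = [](const std::vector<float> &t, int d, int i, int j, int k) {
    return t[(i * d + j) * d + k];
  };
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < M; ++j)
      for (int k = 0; k < M; ++k) {
        float acc = 0.0f;
        for (int r = 0; r < K; ++r)
          for (int s = 0; s < K; ++s)
            for (int t = 0; t < K; ++t)
              acc += at(in, N, i + r, j + s, k + t) * at(filter, K, r, s, t);
        out[(i * M + j) * M + k] = acc;
      }
  std::printf("out[0][0][0] = %g\n", out[0]);
  return 0;
}

With all-ones data every output entry is K*K*K = 27; the test's own fill values are what produce the 108/124 pattern the CHECK lines expect.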

func.func @conv_3d_DDC(%arg0: tensor<?x?x?xf32, #DDC>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32, #DDC> {
%c6 = arith.constant 6 : index
%s = tensor.empty(%c6, %c6, %c6) : tensor<?x?x?xf32, #DDC>
@@ -102,12 +115,15 @@ func.func @entry() {
: tensor<?x?x?xf32> to tensor<?x?x?xf32, #CCC>
%in3D_CDC = sparse_tensor.convert %in3D
: tensor<?x?x?xf32> to tensor<?x?x?xf32, #CDC>
%in3D_DCC = sparse_tensor.convert %in3D
: tensor<?x?x?xf32> to tensor<?x?x?xf32, #DCC>
%in3D_DDC = sparse_tensor.convert %in3D
: tensor<?x?x?xf32> to tensor<?x?x?xf32, #DDC>

%dense_ret = call @conv_3d(%in3D, %filter3D, %out3D) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>)
%CCC_ret = call @conv_3d_CCC(%in3D_CCC, %filter3D) : (tensor<?x?x?xf32, #CCC>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32, #CCC>)
%CDC_ret = call @conv_3d_CDC(%in3D_CDC, %filter3D) : (tensor<?x?x?xf32, #CDC>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32, #CDC>)
%DCC_ret = call @conv_3d_DCC(%in3D_DCC, %filter3D) : (tensor<?x?x?xf32, #DCC>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32, #DCC>)
%DDC_ret = call @conv_3d_DDC(%in3D_DDC, %filter3D) : (tensor<?x?x?xf32, #DDC>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32, #DDC>)

// CHECK:( ( ( 108, 108, 108, 108, 108, 108 ),
@@ -276,6 +292,48 @@ func.func @entry() {
: tensor<?x?x?xf32>, vector<6x6x6xf32>
vector.print %v2 : vector<6x6x6xf32>

// CHECK-NEXT:( ( ( 108, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 124, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 124, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 124, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ),
// CHECK-SAME: ( ( 108, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ),
// CHECK-SAME: ( ( 108, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ),
// CHECK-SAME: ( ( 108, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ),
// CHECK-SAME: ( ( 108, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ),
// CHECK-SAME: ( ( 108, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ) )
%4 = sparse_tensor.convert %DCC_ret
: tensor<?x?x?xf32, #DCC> to tensor<?x?x?xf32>
%v4 = vector.transfer_read %4[%c0, %c0, %c0], %zero
: tensor<?x?x?xf32>, vector<6x6x6xf32>
vector.print %v4 : vector<6x6x6xf32>

// Free the resources
bufferization.dealloc_tensor %in3D : tensor<?x?x?xf32>
bufferization.dealloc_tensor %filter3D : tensor<?x?x?xf32>
@@ -284,9 +342,11 @@ func.func @entry() {
bufferization.dealloc_tensor %in3D_CDC : tensor<?x?x?xf32, #CDC>
bufferization.dealloc_tensor %in3D_CCC : tensor<?x?x?xf32, #CCC>
bufferization.dealloc_tensor %in3D_DDC : tensor<?x?x?xf32, #DDC>
bufferization.dealloc_tensor %in3D_DCC : tensor<?x?x?xf32, #DCC>

bufferization.dealloc_tensor %CCC_ret : tensor<?x?x?xf32, #CCC>
bufferization.dealloc_tensor %CDC_ret : tensor<?x?x?xf32, #CDC>
bufferization.dealloc_tensor %DDC_ret : tensor<?x?x?xf32, #DDC>
bufferization.dealloc_tensor %DCC_ret : tensor<?x?x?xf32, #DCC>
return
}
