diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp index 6a2d7c33356f9c..ff8561534a3768 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp @@ -1454,28 +1454,19 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc, void LoopEmitter::forwardsReducedSliceLevelTreeIt(OpBuilder &builder, Location loc, TensorId tid, Level rootLvl, Value fcnt) { + auto stt = getSparseTensorType(tensors[tid]); // Finds a [Lvl, leafLvl) range, and all level in between are fully reduced - // level (but not resolved). Since we forward an iterator at higher level of - // the tree, the subtree need to be pruned. + // sparse levels (but not resolved). Since we forward an iterator at higher + // level of the tree, the subtree need to be pruned. Level leafLvl = rootLvl + 1; - while (leafLvl < stt.getLvlRank() && !dependentLvlMap[tid][leafLvl].empty() && - depFullyReduced(tid, leafLvl)) { + while (leafLvl < stt.getLvlRank() && depFullyReduced(tid, leafLvl) && + !stt.isDenseLvl(leafLvl)) { leafLvl++; } Level curLvl = rootLvl + 1; - // Prunes all denses subtree. - while (curLvl < leafLvl && isDenseLT(lvlTypes[tid][curLvl])) { - // One step forward in parent level results in forwarding `slice.size` step - // in child dense level. - auto [size, stride] = sliceMeta[tid][curLvl].back(); - assert(stride == 1 && "Not yet implemented"); - fcnt = MULI(size, fcnt); - curLvl++; - } - Value nxPosPtr = nullptr; if (curLvl < leafLvl) { assert(!isDenseLT(lvlTypes[tid][curLvl])); diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d.mlir index dfb1bb71a68c41..451d2b87694614 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d.mlir @@ -38,10 +38,14 @@ map = (d0, d1, d2) -> (d0 : compressed, d1 : dense, d2 : compressed) }> -#DDC = #sparse_tensor.encoding<{ +#DCC = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : compressed) }> +#DDC = #sparse_tensor.encoding<{ + map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed) +}> + // Creates and returns 3-D buffer of size (%s1, %s2, %s3) filled with the value %f func.func @alloc_3d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %f : f32) -> tensor { %buf = tensor.empty(%s1, %s2, %s3) : tensor @@ -74,6 +78,15 @@ func.func @conv_3d_CDC(%arg0: tensor, %arg1: tensor) return %ret : tensor } +func.func @conv_3d_DCC(%arg0: tensor, %arg1: tensor) -> tensor { + %c6 = arith.constant 6 : index + %s = tensor.empty(%c6, %c6, %c6) : tensor + %ret = linalg.conv_3d + ins (%arg0, %arg1: tensor, tensor) + outs (%s: tensor) -> tensor + return %ret : tensor +} + func.func @conv_3d_DDC(%arg0: tensor, %arg1: tensor) -> tensor { %c6 = arith.constant 6 : index %s = tensor.empty(%c6, %c6, %c6) : tensor @@ -102,12 +115,15 @@ func.func @entry() { : tensor to tensor %in3D_CDC = sparse_tensor.convert %in3D : tensor to tensor + %in3D_DCC = sparse_tensor.convert %in3D + : tensor to tensor %in3D_DDC = sparse_tensor.convert %in3D : tensor to tensor %dense_ret = call @conv_3d(%in3D, %filter3D, %out3D) : (tensor, tensor, tensor) -> (tensor) %CCC_ret = call @conv_3d_CCC(%in3D_CCC, %filter3D) : (tensor, tensor) -> (tensor) %CDC_ret = call @conv_3d_CDC(%in3D_CDC, %filter3D) : (tensor, tensor) -> (tensor) + %DCC_ret = call @conv_3d_DCC(%in3D_DCC, %filter3D) : (tensor, tensor) -> (tensor) %DDC_ret = call @conv_3d_DDC(%in3D_DDC, %filter3D) : (tensor, tensor) -> (tensor) // CHECK:( ( ( 108, 108, 108, 108, 108, 108 ), @@ -276,6 +292,48 @@ func.func @entry() { : tensor, vector<6x6x6xf32> vector.print %v2 : vector<6x6x6xf32> + // CHECK-NEXT:( ( ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 124, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 124, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 124, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ), + // CHECK-SAME: ( ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ), + // CHECK-SAME: ( ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ), + // CHECK-SAME: ( ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ), + // CHECK-SAME: ( ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ), + // CHECK-SAME: ( ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ) ) + %4 = sparse_tensor.convert %DCC_ret + : tensor to tensor + %v4 = vector.transfer_read %3[%c0, %c0, %c0], %zero + : tensor, vector<6x6x6xf32> + vector.print %v2 : vector<6x6x6xf32> + // Free the resources bufferization.dealloc_tensor %in3D : tensor bufferization.dealloc_tensor %filter3D : tensor @@ -284,9 +342,11 @@ func.func @entry() { bufferization.dealloc_tensor %in3D_CDC : tensor bufferization.dealloc_tensor %in3D_CCC : tensor bufferization.dealloc_tensor %in3D_DDC : tensor + bufferization.dealloc_tensor %in3D_DCC : tensor bufferization.dealloc_tensor %CCC_ret : tensor bufferization.dealloc_tensor %CDC_ret : tensor bufferization.dealloc_tensor %DDC_ret : tensor + bufferization.dealloc_tensor %DCC_ret : tensor return }