diff --git a/src/index_notation/kernel.cpp b/src/index_notation/kernel.cpp index f6235609a..c61778ee6 100644 --- a/src/index_notation/kernel.cpp +++ b/src/index_notation/kernel.cpp @@ -65,8 +65,10 @@ void unpackResults(size_t numResults, const vector arguments, num *= ((int*)tensorData->indices[i][0])[0]; } else if (modeType.getName() == Sparse.getName()) { auto size = ((int*)tensorData->indices[i][0])[num]; - Array pos = Array(type(), tensorData->indices[i][0], num+1, Array::UserOwns); - Array idx = Array(type(), tensorData->indices[i][1], size, Array::UserOwns); + Array pos = Array(type(), tensorData->indices[i][0], + num+1, Array::UserOwns); + Array idx = Array(type(), tensorData->indices[i][1], + size, Array::UserOwns); modeIndices.push_back(ModeIndex({pos, idx})); num = size; } else { diff --git a/src/index_notation/transformations.cpp b/src/index_notation/transformations.cpp index fa62df2e4..db7ef4804 100644 --- a/src/index_notation/transformations.cpp +++ b/src/index_notation/transformations.cpp @@ -644,19 +644,33 @@ static IndexStmt optimizeSpMM(IndexStmt stmt) { } TensorVar A = Aaccess.getTensorVar(); + if (A.getFormat().getModeFormats()[0].getName() != "dense" || + A.getFormat().getModeFormats()[1].getName() != "compressed" || + A.getFormat().getModeOrdering()[0] != 0 || + A.getFormat().getModeOrdering()[1] != 1) { + return stmt; + } + TensorVar B = Baccess.getTensorVar(); - TensorVar C = Caccess.getTensorVar(); + if (B.getFormat().getModeFormats()[0].getName() != "dense" || + B.getFormat().getModeFormats()[1].getName() != "compressed" || + B.getFormat().getModeOrdering()[0] != 0 || + B.getFormat().getModeOrdering()[1] != 1) { + return stmt; + } - if (A.getFormat() != CSR || - B.getFormat() != CSR || - C.getFormat() != CSR) { + TensorVar C = Caccess.getTensorVar(); + if (C.getFormat().getModeFormats()[0].getName() != "dense" || + C.getFormat().getModeFormats()[1].getName() != "compressed" || + C.getFormat().getModeOrdering()[0] != 0 || + C.getFormat().getModeOrdering()[1] != 1) { return stmt; } // It's an SpMM statement so return an optimized SpMM statement TensorVar w("w", Type(Float64, {A.getType().getShape().getDimension(1)}), - dense); + taco::dense); return forall(i, where(forall(j, A(i,j) = w(j)), diff --git a/src/lower/lowerer_impl.cpp b/src/lower/lowerer_impl.cpp index a2a49f065..6798035c8 100644 --- a/src/lower/lowerer_impl.cpp +++ b/src/lower/lowerer_impl.cpp @@ -102,8 +102,9 @@ static bool hasStores(Stmt stmt) { return stmt.defined() && FindStores().hasStores(stmt); } -Stmt LowererImpl::lower(IndexStmt stmt, string name, bool assemble, - bool compute) { +Stmt +LowererImpl::lower(IndexStmt stmt, string name, bool assemble, bool compute) +{ this->assemble = assemble; this->compute = compute; @@ -125,7 +126,7 @@ Stmt LowererImpl::lower(IndexStmt stmt, string name, bool assemble, // Create iterators iterators = Iterators::make(stmt, tensorVars, &indexVars); - vector inputAccesses, resultAccesses; + vector inputAccesses, resultAccesses; set reducedAccesses; inputAccesses = getArgumentAccesses(stmt); std::tie(resultAccesses, reducedAccesses) = getResultAccesses(stmt); diff --git a/test/tests-lower.cpp b/test/tests-lower.cpp index 9017bab48..77be414af 100644 --- a/test/tests-lower.cpp +++ b/test/tests-lower.cpp @@ -45,10 +45,10 @@ using taco::error::expr_transposition; #include "taco/lower/mode_format_dense.h" taco::ModeFormat dense(std::make_shared()); -static const Dimension n, m, o; +static const Dimension n; static const Type vectype(Float64, {n}); -static const Type mattype(Float64, {n,m}); -static const Type tentype(Float64, {n,m,o}); +static const Type mattype(Float64, {n,n}); +static const Type tentype(Float64, {n,n,n}); static TensorVar alpha("alpha", Float64); static TensorVar beta("beta", Float64); @@ -282,17 +282,14 @@ TEST_P(lower, compile) { } { - SCOPED_TRACE("Separate Assembly and Compute\n" + - toString(taco::lower(stmt,"assemble",true,false)) + "\n" + - toString(taco::lower(stmt,"compute",false,true))); + SCOPED_TRACE("Separate Assembly and Compute\n"); ASSERT_TRUE(kernel.assemble(arguments)); ASSERT_TRUE(kernel.compute(arguments)); verifyResults(results, arguments, varsFormatted, expected); } { - SCOPED_TRACE("Fused Assembly and Compute\n" + - toString(taco::lower(stmt,"evaluate",true,true))); + SCOPED_TRACE("Fused Assembly and Compute\n"); ASSERT_TRUE(kernel(arguments)); verifyResults(results, arguments, varsFormatted, expected); } @@ -734,8 +731,8 @@ TEST_STMT(DISABLED_where_spmm, forall(j, w(j) += B(i,k) * C(k,j))))), Values( - Formats({{A,Format({dense,dense})}, - {B,Format({dense,dense})}, {C,Format({dense,dense})}}), +// Formats({{A,Format({dense,dense})}, +// {B,Format({dense,dense})}, {C,Format({dense,dense})}}), Formats({{A,Format({dense,sparse})}, {B,Format({dense,sparse})}, {C,Format({dense,sparse})}}) ), @@ -1584,4 +1581,3 @@ TEST_STMT(vector_not, {{a, {{{1}, 1.0}, {{2}, 1.0}, {{3}, 1.0}}}}) } ) - diff --git a/tools/taco.cpp b/tools/taco.cpp index 51a8bee60..568a31f4b 100644 --- a/tools/taco.cpp +++ b/tools/taco.cpp @@ -811,15 +811,17 @@ int main(int argc, char* argv[]) { else { if (newLower) { IndexStmt stmt = makeConcrete(tensor.getAssignment()); - if (printConcrete) { - cout << stmt << endl; - } string reason; stmt = reorderLoopsTopologically(stmt); stmt = insertTemporaries(stmt); taco_uassert(stmt != IndexStmt()) << reason; stmt = parallelizeOuterLoop(stmt); + + if (printConcrete) { + cout << stmt << endl; + } + compute = lower(stmt, "compute", false, true); assemble = lower(stmt, "assemble", true, false); evaluate = lower(stmt, "evaluate", true, true);