Skip to content

Commit

Permalink
Various fixes including better temporary inserts
Browse files Browse the repository at this point in the history
  • Loading branch information
fredrikbk committed May 30, 2019
1 parent f4408d9 commit f964d00
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 24 deletions.
6 changes: 4 additions & 2 deletions src/index_notation/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,10 @@ void unpackResults(size_t numResults, const vector<void*> arguments,
num *= ((int*)tensorData->indices[i][0])[0];
} else if (modeType.getName() == Sparse.getName()) {
auto size = ((int*)tensorData->indices[i][0])[num];
Array pos = Array(type<int>(), tensorData->indices[i][0], num+1, Array::UserOwns);
Array idx = Array(type<int>(), tensorData->indices[i][1], size, Array::UserOwns);
Array pos = Array(type<int>(), tensorData->indices[i][0],
num+1, Array::UserOwns);
Array idx = Array(type<int>(), tensorData->indices[i][1],
size, Array::UserOwns);
modeIndices.push_back(ModeIndex({pos, idx}));
num = size;
} else {
Expand Down
24 changes: 19 additions & 5 deletions src/index_notation/transformations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -644,19 +644,33 @@ static IndexStmt optimizeSpMM(IndexStmt stmt) {
}

TensorVar A = Aaccess.getTensorVar();
if (A.getFormat().getModeFormats()[0].getName() != "dense" ||
A.getFormat().getModeFormats()[1].getName() != "compressed" ||
A.getFormat().getModeOrdering()[0] != 0 ||
A.getFormat().getModeOrdering()[1] != 1) {
return stmt;
}

TensorVar B = Baccess.getTensorVar();
TensorVar C = Caccess.getTensorVar();
if (B.getFormat().getModeFormats()[0].getName() != "dense" ||
B.getFormat().getModeFormats()[1].getName() != "compressed" ||
B.getFormat().getModeOrdering()[0] != 0 ||
B.getFormat().getModeOrdering()[1] != 1) {
return stmt;
}

if (A.getFormat() != CSR ||
B.getFormat() != CSR ||
C.getFormat() != CSR) {
TensorVar C = Caccess.getTensorVar();
if (C.getFormat().getModeFormats()[0].getName() != "dense" ||
C.getFormat().getModeFormats()[1].getName() != "compressed" ||
C.getFormat().getModeOrdering()[0] != 0 ||
C.getFormat().getModeOrdering()[1] != 1) {
return stmt;
}

// It's an SpMM statement so return an optimized SpMM statement
TensorVar w("w",
Type(Float64, {A.getType().getShape().getDimension(1)}),
dense);
taco::dense);
return forall(i,
where(forall(j,
A(i,j) = w(j)),
Expand Down
7 changes: 4 additions & 3 deletions src/lower/lowerer_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,9 @@ static bool hasStores(Stmt stmt) {
return stmt.defined() && FindStores().hasStores(stmt);
}

Stmt LowererImpl::lower(IndexStmt stmt, string name, bool assemble,
bool compute) {
Stmt
LowererImpl::lower(IndexStmt stmt, string name, bool assemble, bool compute)
{
this->assemble = assemble;
this->compute = compute;

Expand All @@ -125,7 +126,7 @@ Stmt LowererImpl::lower(IndexStmt stmt, string name, bool assemble,
// Create iterators
iterators = Iterators::make(stmt, tensorVars, &indexVars);

vector<Access> inputAccesses, resultAccesses;
vector<Access> inputAccesses, resultAccesses;
set<Access> reducedAccesses;
inputAccesses = getArgumentAccesses(stmt);
std::tie(resultAccesses, reducedAccesses) = getResultAccesses(stmt);
Expand Down
18 changes: 7 additions & 11 deletions test/tests-lower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ using taco::error::expr_transposition;
#include "taco/lower/mode_format_dense.h"
taco::ModeFormat dense(std::make_shared<taco::DenseModeFormat>());

static const Dimension n, m, o;
static const Dimension n;
static const Type vectype(Float64, {n});
static const Type mattype(Float64, {n,m});
static const Type tentype(Float64, {n,m,o});
static const Type mattype(Float64, {n,n});
static const Type tentype(Float64, {n,n,n});

static TensorVar alpha("alpha", Float64);
static TensorVar beta("beta", Float64);
Expand Down Expand Up @@ -282,17 +282,14 @@ TEST_P(lower, compile) {
}

{
SCOPED_TRACE("Separate Assembly and Compute\n" +
toString(taco::lower(stmt,"assemble",true,false)) + "\n" +
toString(taco::lower(stmt,"compute",false,true)));
SCOPED_TRACE("Separate Assembly and Compute\n");
ASSERT_TRUE(kernel.assemble(arguments));
ASSERT_TRUE(kernel.compute(arguments));
verifyResults(results, arguments, varsFormatted, expected);
}

{
SCOPED_TRACE("Fused Assembly and Compute\n" +
toString(taco::lower(stmt,"evaluate",true,true)));
SCOPED_TRACE("Fused Assembly and Compute\n");
ASSERT_TRUE(kernel(arguments));
verifyResults(results, arguments, varsFormatted, expected);
}
Expand Down Expand Up @@ -734,8 +731,8 @@ TEST_STMT(DISABLED_where_spmm,
forall(j,
w(j) += B(i,k) * C(k,j))))),
Values(
Formats({{A,Format({dense,dense})},
{B,Format({dense,dense})}, {C,Format({dense,dense})}}),
// Formats({{A,Format({dense,dense})},
// {B,Format({dense,dense})}, {C,Format({dense,dense})}}),
Formats({{A,Format({dense,sparse})},
{B,Format({dense,sparse})}, {C,Format({dense,sparse})}})
),
Expand Down Expand Up @@ -1584,4 +1581,3 @@ TEST_STMT(vector_not,
{{a, {{{1}, 1.0}, {{2}, 1.0}, {{3}, 1.0}}}})
}
)

8 changes: 5 additions & 3 deletions tools/taco.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -811,15 +811,17 @@ int main(int argc, char* argv[]) {
else {
if (newLower) {
IndexStmt stmt = makeConcrete(tensor.getAssignment());
if (printConcrete) {
cout << stmt << endl;
}

string reason;
stmt = reorderLoopsTopologically(stmt);
stmt = insertTemporaries(stmt);
taco_uassert(stmt != IndexStmt()) << reason;
stmt = parallelizeOuterLoop(stmt);

if (printConcrete) {
cout << stmt << endl;
}

compute = lower(stmt, "compute", false, true);
assemble = lower(stmt, "assemble", true, false);
evaluate = lower(stmt, "evaluate", true, true);
Expand Down

0 comments on commit f964d00

Please sign in to comment.