Skip to content

Commit 97edc84

Browse files
authored
Merge pull request #475 from tensor-compiler/multidim-workspace
Multidimensional dense workspaces
2 parents 9f7b71c + 007763e commit 97edc84

15 files changed

+1167
-154
lines changed

include/taco/index_notation/index_notation.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,10 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
549549
/// Takes any index notation and concretizes unknowns to make it concrete notation
550550
IndexStmt concretize() const;
551551

552+
/// Takes any index notation and concretizes unknowns to make it concrete notation
553+
/// given a Provenance Graph of indexVars
554+
IndexStmt concretizeScheduled(ProvenanceGraph provGraph, std::vector<IndexVar> forallIndexVarList) const;
555+
552556
/// The \code{split} transformation splits (strip-mines) an index
553557
/// variable into two nested index variables, where the size of the
554558
/// inner index variable is constant. The size of the outer index
@@ -681,6 +685,12 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
681685
/// reorder computations to increase locality
682686
IndexStmt precompute(IndexExpr expr, IndexVar i, IndexVar iw, TensorVar workspace) const;
683687

688+
/// The precompute transformation is described in kjolstad2019
689+
/// allows us to leverage scratchpad memories and
690+
/// reorder computations to increase locality
691+
IndexStmt precompute(IndexExpr expr, std::vector<IndexVar> i_vars,
692+
std::vector<IndexVar> iw_vars, TensorVar workspace) const;
693+
684694
/// bound specifies a compile-time constraint on an index variable's
685695
/// iteration space that allows knowledge of the
686696
/// size or structured sparsity pattern of the inputs to be
@@ -1119,6 +1129,10 @@ bool isEinsumNotation(IndexStmt, std::string* reason=nullptr);
11191129
/// notation is printed to.
11201130
bool isReductionNotation(IndexStmt, std::string* reason=nullptr);
11211131

1132+
/// Check whether the statement is in the reduction index notation dialect
1133+
/// given a schedule described by the Provenance Graph
1134+
bool isReductionNotationScheduled(IndexStmt, ProvenanceGraph, std::string* reason=nullptr);
1135+
11221136
/// Check whether the statement is in the concrete index notation dialect.
11231137
/// This means every index variable has a forall node, there are no reduction
11241138
/// nodes, and that every reduction variable use is nested inside a compound
@@ -1136,6 +1150,18 @@ IndexStmt makeReductionNotation(IndexStmt);
11361150
/// as needed.
11371151
IndexStmt makeConcreteNotation(IndexStmt);
11381152

1153+
1154+
/// Convert einsum notation to reduction notation, by applying Einstein's
1155+
/// summation convention to sum non-free/reduction variables over their term
1156+
/// taking into account a schedule given by the Provenance Graph.
1157+
Assignment makeReductionNotationScheduled(Assignment, ProvenanceGraph);
1158+
IndexStmt makeReductionNotationScheduled(IndexStmt, ProvenanceGraph);
1159+
1160+
/// Convert reduction notation to concrete notation, by inserting forall nodes,
1161+
/// replacing reduction nodes by compound assignments, and inserting temporaries
1162+
/// as needed while taking into account a schedule given by the Provenance Graph.
1163+
IndexStmt makeConcreteNotationScheduled(IndexStmt, ProvenanceGraph, std::vector<IndexVar> forallIndexVars);
1164+
11391165
/// Returns the results of the index statement, in the order they appear.
11401166
std::vector<TensorVar> getResults(IndexStmt stmt);
11411167

include/taco/index_notation/transformations.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,12 @@ class Precompute : public TransformationInterface {
8989
public:
9090
Precompute();
9191
Precompute(IndexExpr expr, IndexVar i, IndexVar iw, TensorVar workspace);
92-
92+
Precompute(IndexExpr expr, std::vector<IndexVar> i_vars,
93+
std::vector<IndexVar> iw_vars, TensorVar workspace);
94+
9395
IndexExpr getExpr() const;
94-
IndexVar geti() const;
95-
IndexVar getiw() const;
96+
std::vector<IndexVar>& getIVars() const;
97+
std::vector<IndexVar>& getIWVars() const;
9698
TensorVar getWorkspace() const;
9799

98100
/// Apply the precompute optimization to a concrete index statement.

include/taco/ir/workspace_rewriter.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#ifndef TACO_WORKSPACE_REWRITER_H
2+
#define TACO_WORKSPACE_REWRITER_H
3+
4+
#include <vector>
5+
#include <map>
6+
7+
8+
namespace taco {
9+
class TensorVar;
10+
11+
namespace ir {
12+
class Stmt;
13+
class Expr;
14+
}
15+
16+
/// Rewrite a post-lowered IR statement to take into account multidimensional temporaries.
17+
/// Replaces Dimension GetProperty nodes that correspond to temporary workspaces with
18+
/// their corresponding dimension found in the temporarySizeMap.
19+
ir::Stmt rewriteTemporaryGP(const ir::Stmt& stmt, std::vector<TensorVar> whereTemps,
20+
std::map<TensorVar, std::vector<ir::Expr>> temporarySizeMap);
21+
22+
}
23+
#endif

include/taco/lower/lowerer_impl_imperative.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,12 @@ class LowererImplImperative : public LowererImpl {
494494
std::vector<TensorVar> whereTemps;
495495
std::map<TensorVar, const AccessNode *> whereTempsToResult;
496496

497+
// Map temporary tensorVars to a list of size expressions for each mode
498+
std::map<TensorVar, std::vector<ir::Expr>> temporarySizeMap;
499+
500+
// List that contains all temporary tensorVars
501+
std::vector<TensorVar> temporaries;
502+
497503
bool captureNextLocatePos = false;
498504
ir::Stmt capturedLocatePos; // used for whereConsumer when want to replicate same locating
499505

include/taco/parser/schedule_parser.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ namespace parser {
1212
// [ [ "reorder", "i", "j" ], [ "precompute", "D(i,j)*E(j,k)", "j", "j_pre" ] ]
1313
std::vector<std::vector<std::string>> ScheduleParser(const std::string);
1414

15+
std::vector<std::string> varListParser(const std::string);
16+
1517
// serialize the result of a parse (for debugging)
1618
std::string serializeParsedSchedule(std::vector<std::vector<std::string>>);
1719

src/error/error_checks.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ std::pair<bool, string> dimensionsTypecheck(const std::vector<IndexVar>& resultV
4141
for (size_t mode = 0; mode < resultVars.size(); mode++) {
4242
IndexVar var = resultVars[mode];
4343
auto dimension = shape.getDimension(mode);
44-
if (util::contains(indexVarDims,var) && indexVarDims.at(var) != dimension) {
44+
if (util::contains(indexVarDims,var) && indexVarDims.at(var) != dimension &&
45+
!(indexVarDims.at(var).isIndexVarSized() && indexVarDims.at(var).getIndexVarSize() == var) &&
46+
!(dimension.isIndexVarSized() && dimension.getIndexVarSize() == var)) {
4547
errors.push_back(addDimensionError(var, indexVarDims.at(var), dimension));
4648
} else {
4749
indexVarDims.insert({var, dimension});
@@ -63,7 +65,9 @@ std::pair<bool, string> dimensionsTypecheck(const std::vector<IndexVar>& resultV
6365
dimension = Dimension(a.getIndexSet(mode).size());
6466
}
6567

66-
if (util::contains(indexVarDims,var) && indexVarDims.at(var) != dimension) {
68+
if (util::contains(indexVarDims,var) && indexVarDims.at(var) != dimension &&
69+
!(indexVarDims.at(var).isIndexVarSized() && indexVarDims.at(var).getIndexVarSize() == var) &&
70+
!(dimension.isIndexVarSized() && dimension.getIndexVarSize() == var)) {
6771
errors.push_back(addDimensionError(var, indexVarDims.at(var), dimension));
6872
} else {
6973
indexVarDims.insert({var, dimension});

0 commit comments

Comments
 (0)