Skip to content

Commit 007763e

Browse files
committed
Add in changes from PR 475 review
1 parent 9700b18 commit 007763e

File tree

3 files changed

+53
-97
lines changed

3 files changed

+53
-97
lines changed

src/index_notation/index_notation.cpp

Lines changed: 50 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "taco/tensor.h"
2727

2828
#include "taco/util/name_generator.h"
29+
#include "taco/util/scopedset.h"
2930
#include "taco/util/scopedmap.h"
3031
#include "taco/util/strings.h"
3132
#include "taco/util/collections.h"
@@ -1525,11 +1526,9 @@ IndexStmt IndexStmt::divide(IndexVar i, IndexVar i1, IndexVar i2, size_t splitFa
15251526
IndexStmt IndexStmt::precompute(IndexExpr expr, std::vector<IndexVar> i_vars,
15261527
std::vector<IndexVar> iw_vars, TensorVar workspace) const {
15271528

1528-
// TODO: need to assert they are same length
15291529
IndexStmt transformed = *this;
15301530
string reason;
15311531

1532-
// FIXME: need to re-enable this later
15331532
taco_uassert(i_vars.size() == iw_vars.size()) << "The precompute transformation requires"
15341533
<< "i_vars and iw_vars to be the same size";
15351534
for (int l = 0; l < (int) i_vars.size(); l++) {
@@ -2343,18 +2342,18 @@ bool isReductionNotation(IndexStmt stmt, std::string* reason) {
23432342
// Reduction notation until proved otherwise
23442343
bool isReduction = true;
23452344

2346-
util::ScopedMap<IndexVar,int> boundVars; // (int) value not used
2345+
util::ScopedSet<IndexVar> boundVars;
23472346
vector<IndexVar> boundVarsList;
23482347
for (auto& var : to<Assignment>(stmt).getFreeVars()) {
2349-
boundVars.insert({var,0});
2348+
boundVars.insert({var});
23502349
boundVarsList.push_back(var);
23512350
}
23522351

23532352
match(stmt,
23542353
std::function<void(const ReductionNode*,Matcher*)>([&](
23552354
const ReductionNode* op, Matcher* ctx) {
23562355
boundVars.scope();
2357-
boundVars.insert({op->var,0});
2356+
boundVars.insert({op->var});
23582357
ctx->match(op->a);
23592358
boundVars.unscope();
23602359
}),
@@ -2386,18 +2385,18 @@ bool isReductionNotationScheduled(IndexStmt stmt, ProvenanceGraph provGraph, std
23862385
// Reduction notation until proved otherwise
23872386
bool isReduction = true;
23882387

2389-
util::ScopedMap<IndexVar,int> boundVars; // (int) value not used
2388+
util::ScopedSet<IndexVar> boundVars;
23902389
vector<IndexVar> boundVarsList;
23912390
for (auto& var : to<Assignment>(stmt).getFreeVars()) {
2392-
boundVars.insert({var,0});
2391+
boundVars.insert({var});
23932392
boundVarsList.push_back(var);
23942393
}
23952394

23962395
match(stmt,
23972396
std::function<void(const ReductionNode*,Matcher*)>([&](
23982397
const ReductionNode* op, Matcher* ctx) {
23992398
boundVars.scope();
2400-
boundVars.insert({op->var,0});
2399+
boundVars.insert({op->var});
24012400
ctx->match(op->a);
24022401
boundVars.unscope();
24032402
}),
@@ -2441,7 +2440,7 @@ bool isConcreteNotation(IndexStmt stmt, std::string* reason) {
24412440

24422441
bool inWhereProducer = false;
24432442
bool inWhereConsumer = false;
2444-
util::ScopedMap<IndexVar,int> boundVars; // (int) value not used
2443+
util::ScopedSet<IndexVar> boundVars;
24452444
std::set<IndexVar> definedVars; // used to check if all variables recoverable TODO: need to actually use scope like above
24462445

24472446
ProvenanceGraph provGraph = ProvenanceGraph(stmt);
@@ -2450,7 +2449,7 @@ bool isConcreteNotation(IndexStmt stmt, std::string* reason) {
24502449
std::function<void(const ForallNode*,Matcher*)>([&](const ForallNode* op,
24512450
Matcher* ctx) {
24522451
boundVars.scope();
2453-
boundVars.insert({op->indexVar,0});
2452+
boundVars.insert({op->indexVar});
24542453
definedVars.insert(op->indexVar);
24552454
ctx->match(op->stmt);
24562455
boundVars.unscope();
@@ -2606,6 +2605,47 @@ IndexStmt makeReductionNotation(IndexStmt stmt) {
26062605
return makeReductionNotation(to<Assignment>(stmt));
26072606
}
26082607

2608+
// Replace other reductions with where and forall statements
2609+
struct ReplaceReductionsWithWheres : IndexNotationRewriter {
2610+
using IndexNotationRewriter::visit;
2611+
2612+
Reduction reduction;
2613+
TensorVar t;
2614+
2615+
void visit(const AssignmentNode* node) {
2616+
reduction = Reduction();
2617+
t = TensorVar();
2618+
2619+
IndexExpr rhs = rewrite(node->rhs);
2620+
2621+
// nothing was rewritten
2622+
if (rhs == node->rhs) {
2623+
stmt = node;
2624+
return;
2625+
}
2626+
2627+
taco_iassert(t.defined() && reduction.defined());
2628+
IndexStmt consumer = Assignment(node->lhs, rhs, node->op);
2629+
IndexStmt producer = forall(reduction.getVar(),
2630+
Assignment(t, reduction.getExpr(),
2631+
reduction.getOp()));
2632+
stmt = where(rewrite(consumer), rewrite(producer));
2633+
}
2634+
2635+
void visit(const ReductionNode* node) {
2636+
// only rewrite one reduction at a time
2637+
if (reduction.defined()) {
2638+
expr = node;
2639+
return;
2640+
}
2641+
2642+
reduction = node;
2643+
t = TensorVar("t" + util::toString(node->var),
2644+
node->getDataType());
2645+
expr = t;
2646+
}
2647+
};
2648+
26092649
IndexStmt makeConcreteNotation(IndexStmt stmt) {
26102650
std::string reason;
26112651
taco_iassert(isReductionNotation(stmt, &reason))
@@ -2646,46 +2686,6 @@ IndexStmt makeConcreteNotation(IndexStmt stmt) {
26462686
stmt = forall(i, stmt);
26472687
}
26482688

2649-
// Replace other reductions with where and forall statements
2650-
struct ReplaceReductionsWithWheres : IndexNotationRewriter {
2651-
using IndexNotationRewriter::visit;
2652-
2653-
Reduction reduction;
2654-
TensorVar t;
2655-
2656-
void visit(const AssignmentNode* node) {
2657-
reduction = Reduction();
2658-
t = TensorVar();
2659-
2660-
IndexExpr rhs = rewrite(node->rhs);
2661-
2662-
// nothing was rewritten
2663-
if (rhs == node->rhs) {
2664-
stmt = node;
2665-
return;
2666-
}
2667-
2668-
taco_iassert(t.defined() && reduction.defined());
2669-
IndexStmt consumer = Assignment(node->lhs, rhs, node->op);
2670-
IndexStmt producer = forall(reduction.getVar(),
2671-
Assignment(t, reduction.getExpr(),
2672-
reduction.getOp()));
2673-
stmt = where(rewrite(consumer), rewrite(producer));
2674-
}
2675-
2676-
void visit(const ReductionNode* node) {
2677-
// only rewrite one reduction at a time
2678-
if (reduction.defined()) {
2679-
expr = node;
2680-
return;
2681-
}
2682-
2683-
reduction = node;
2684-
t = TensorVar("t" + util::toString(node->var),
2685-
node->getDataType());
2686-
expr = t;
2687-
}
2688-
};
26892689
stmt = ReplaceReductionsWithWheres().rewrite(stmt);
26902690
return stmt;
26912691
}
@@ -2882,46 +2882,6 @@ IndexStmt makeConcreteNotationScheduled(IndexStmt stmt, ProvenanceGraph provGrap
28822882
}
28832883
}
28842884

2885-
// Replace other reductions with where and forall statements
2886-
struct ReplaceReductionsWithWheres : IndexNotationRewriter {
2887-
using IndexNotationRewriter::visit;
2888-
2889-
Reduction reduction;
2890-
TensorVar t;
2891-
2892-
void visit(const AssignmentNode* node) {
2893-
reduction = Reduction();
2894-
t = TensorVar();
2895-
2896-
IndexExpr rhs = rewrite(node->rhs);
2897-
2898-
// nothing was rewritten
2899-
if (rhs == node->rhs) {
2900-
stmt = node;
2901-
return;
2902-
}
2903-
2904-
taco_iassert(t.defined() && reduction.defined());
2905-
IndexStmt consumer = Assignment(node->lhs, rhs, node->op);
2906-
IndexStmt producer = forall(reduction.getVar(),
2907-
Assignment(t, reduction.getExpr(),
2908-
reduction.getOp()));
2909-
stmt = where(rewrite(consumer), rewrite(producer));
2910-
}
2911-
2912-
void visit(const ReductionNode* node) {
2913-
// only rewrite one reduction at a time
2914-
if (reduction.defined()) {
2915-
expr = node;
2916-
return;
2917-
}
2918-
2919-
reduction = node;
2920-
t = TensorVar("t" + util::toString(node->var),
2921-
node->getDataType());
2922-
expr = t;
2923-
}
2924-
};
29252885
stmt = ReplaceReductionsWithWheres().rewrite(stmt);
29262886
return stmt;
29272887
}

src/lower/lowerer_impl_imperative.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2086,8 +2086,7 @@ vector<Stmt> LowererImplImperative::codeToInitializeTemporaryParallel(Where wher
20862086
values = ir::Var::make(temporaryAll.getName(),
20872087
temporaryAll.getType().getDataType(),
20882088
true, false);
2089-
// taco_iassert(temporaryAll.getType().getOrder() == 1) << " Temporary order was "
2090-
// << temporaryAll.getType().getOrder(); // TODO
2089+
20912090
Expr size = getTemporarySize(where);
20922091
Expr sizeAll = ir::Mul::make(size, ir::Call::make("omp_get_max_threads", {}, size.type()));
20932092

@@ -2141,8 +2140,7 @@ vector<Stmt> LowererImplImperative::codeToInitializeTemporary(Where where) {
21412140
needComputeValues(where, temporary)) {
21422141
values = ir::Var::make(temporary.getName(),
21432142
temporary.getType().getDataType(), true, false);
2144-
//taco_iassert(temporary.getType().getOrder() == 1)
2145-
// << " Temporary order was " << temporary.getType().getOrder(); // TODO
2143+
21462144
Expr size = getTemporarySize(where);
21472145

21482146
// no decl needed for shared memory

test/tests-scheduling-eval.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,8 @@ IndexStmt scheduleMTTKRPCPU(IndexStmt stmt, Tensor<double> B, int CHUNK_SIZE=16,
163163

164164
stmt = stmt.split(i, i1, i2, CHUNK_SIZE)
165165
.reorder({i1, i2, k, l, j});
166-
167-
cout << stmt << endl;
168166
stmt = stmt.precompute(precomputeExpr, j, j, w);
169-
cout << stmt << endl;
167+
170168
return stmt
171169
.parallelize(i1, ParallelUnit::CPUThread, OutputRaceStrategy::NoRaces);
172170
}

0 commit comments

Comments
 (0)