Skip to content

Commit d464c7c

Browse files
authored
Merge pull request #474 from tensor-compiler/test-pr470
Modified PR 470: update isRecoverable check and iterators to fix precompute bugs
2 parents 88fcb6c + 4cf3b56 commit d464c7c

File tree

4 files changed

+52
-1
lines changed

4 files changed

+52
-1
lines changed

include/taco/index_notation/provenance_graph.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,9 @@ class ProvenanceGraph {
354354
/// Node is recoverable if children appear in defined
355355
bool isRecoverable(IndexVar indexVar, std::set<IndexVar> defined) const;
356356

357+
/// isRecoverable helper method to handle precompute relations and where statements in the provenance graph
358+
bool isRecoverablePrecompute(IndexVar indexVar, std::set<IndexVar> defined, std::vector<IndexVar> producers, std::vector<IndexVar> consumers) const;
359+
357360
/// Node is recoverable if at most 1 unknown variable in relationship (parents + siblings)
358361
bool isChildRecoverable(taco::IndexVar indexVar, std::set<taco::IndexVar> defined) const;
359362

src/index_notation/index_notation.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1514,6 +1514,14 @@ IndexStmt IndexStmt::precompute(IndexExpr expr, IndexVar i, IndexVar iw, TensorV
15141514
IndexStmt transformed = *this;
15151515
string reason;
15161516

1517+
if (i != iw) {
1518+
IndexVarRel rel = IndexVarRel(new PrecomputeRelNode(i, iw));
1519+
transformed = Transformation(AddSuchThatPredicates({rel})).apply(transformed, &reason);
1520+
if (!transformed.defined()) {
1521+
taco_uerror << reason;
1522+
}
1523+
}
1524+
15171525
transformed = Transformation(Precompute(expr, i, iw, workspace)).apply(transformed, &reason);
15181526
if (!transformed.defined()) {
15191527
taco_uerror << reason;

src/index_notation/provenance_graph.cpp

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1120,8 +1120,39 @@ bool ProvenanceGraph::isAvailable(IndexVar indexVar, std::set<IndexVar> defined)
11201120

11211121
bool ProvenanceGraph::isRecoverable(taco::IndexVar indexVar, std::set<taco::IndexVar> defined) const {
11221122
// all children are either defined or recoverable from their children
1123+
// This checks the definedVars list to determine where in the statement the variables are trying to be
1124+
// recovered from ( either on the producer or consumer side of a where stmt or not in a where stmt)
1125+
vector<IndexVar> producers;
1126+
vector<IndexVar> consumers;
1127+
for (auto& def : defined) {
1128+
if (childRelMap.count(def) && childRelMap.at(def).getRelType() == IndexVarRelType::PRECOMPUTE) {
1129+
consumers.push_back(def);
1130+
}
1131+
if (parentRelMap.count(def) && parentRelMap.at(def).getRelType() == IndexVarRelType::PRECOMPUTE) {
1132+
producers.push_back(def);
1133+
}
1134+
}
1135+
1136+
return isRecoverablePrecompute(indexVar, defined, producers, consumers);
1137+
}
1138+
1139+
bool ProvenanceGraph::isRecoverablePrecompute(taco::IndexVar indexVar, std::set<taco::IndexVar> defined,
1140+
vector<IndexVar> producers, vector<IndexVar> consumers) const {
1141+
vector<IndexVar> childPrecompute;
1142+
if (std::find(consumers.begin(), consumers.end(), indexVar) != consumers.end()) {
1143+
return true;
1144+
}
1145+
if (!producers.empty() && (childRelMap.count(indexVar) &&
1146+
childRelMap.at(indexVar).getRelType() == IndexVarRelType::PRECOMPUTE)) {
1147+
auto precomputeChild = getChildren(indexVar)[0];
1148+
if (std::find(producers.begin(), producers.end(), precomputeChild) != producers.end()) {
1149+
return true;
1150+
}
1151+
return isRecoverablePrecompute(precomputeChild, defined, producers, consumers);
1152+
}
11231153
for (const IndexVar& child : getChildren(indexVar)) {
1124-
if (!defined.count(child) && (isFullyDerived(child) || !isRecoverable(child, defined))) {
1154+
if (!defined.count(child) && (isFullyDerived(child) ||
1155+
!isRecoverablePrecompute(child, defined, producers, consumers))) {
11251156
return false;
11261157
}
11271158
}

src/lower/iterator.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,15 @@ Iterators::Iterators(IndexStmt stmt, const map<TensorVar, Expr>& tensorVars)
531531
underivedAdded.insert(underived);
532532
}
533533
}
534+
535+
// Insert all children of current index variable into iterators as well
536+
for (const IndexVar& child : provGraph.getChildren(n->indexVar)) {
537+
if (!underivedAdded.count(child)) {
538+
content->modeIterators.insert({child, child});
539+
underivedAdded.insert(child);
540+
}
541+
}
542+
534543
m->match(n->stmt);
535544
})
536545
);

0 commit comments

Comments
 (0)