@@ -1120,8 +1120,39 @@ bool ProvenanceGraph::isAvailable(IndexVar indexVar, std::set<IndexVar> defined)
1120
1120
1121
1121
bool ProvenanceGraph::isRecoverable (taco::IndexVar indexVar, std::set<taco::IndexVar> defined) const {
1122
1122
// all children are either defined or recoverable from their children
1123
+ // This checks the definedVars list to determine where in the statement the variables are trying to be
1124
+ // recovered from ( either on the producer or consumer side of a where stmt or not in a where stmt)
1125
+ vector<IndexVar> producers;
1126
+ vector<IndexVar> consumers;
1127
+ for (auto & def : defined) {
1128
+ if (childRelMap.count (def) && childRelMap.at (def).getRelType () == IndexVarRelType::PRECOMPUTE) {
1129
+ consumers.push_back (def);
1130
+ }
1131
+ if (parentRelMap.count (def) && parentRelMap.at (def).getRelType () == IndexVarRelType::PRECOMPUTE) {
1132
+ producers.push_back (def);
1133
+ }
1134
+ }
1135
+
1136
+ return isRecoverablePrecompute (indexVar, defined, producers, consumers);
1137
+ }
1138
+
1139
+ bool ProvenanceGraph::isRecoverablePrecompute (taco::IndexVar indexVar, std::set<taco::IndexVar> defined,
1140
+ vector<IndexVar> producers, vector<IndexVar> consumers) const {
1141
+ vector<IndexVar> childPrecompute;
1142
+ if (std::find (consumers.begin (), consumers.end (), indexVar) != consumers.end ()) {
1143
+ return true ;
1144
+ }
1145
+ if (!producers.empty () && (childRelMap.count (indexVar) &&
1146
+ childRelMap.at (indexVar).getRelType () == IndexVarRelType::PRECOMPUTE)) {
1147
+ auto precomputeChild = getChildren (indexVar)[0 ];
1148
+ if (std::find (producers.begin (), producers.end (), precomputeChild) != producers.end ()) {
1149
+ return true ;
1150
+ }
1151
+ return isRecoverablePrecompute (precomputeChild, defined, producers, consumers);
1152
+ }
1123
1153
for (const IndexVar& child : getChildren (indexVar)) {
1124
- if (!defined.count (child) && (isFullyDerived (child) || !isRecoverable (child, defined))) {
1154
+ if (!defined.count (child) && (isFullyDerived (child) ||
1155
+ !isRecoverablePrecompute (child, defined, producers, consumers))) {
1125
1156
return false ;
1126
1157
}
1127
1158
}
0 commit comments