26
26
#include " taco/tensor.h"
27
27
28
28
#include " taco/util/name_generator.h"
29
+ #include " taco/util/scopedset.h"
29
30
#include " taco/util/scopedmap.h"
30
31
#include " taco/util/strings.h"
31
32
#include " taco/util/collections.h"
@@ -1525,11 +1526,9 @@ IndexStmt IndexStmt::divide(IndexVar i, IndexVar i1, IndexVar i2, size_t splitFa
1525
1526
IndexStmt IndexStmt::precompute (IndexExpr expr, std::vector<IndexVar> i_vars,
1526
1527
std::vector<IndexVar> iw_vars, TensorVar workspace) const {
1527
1528
1528
- // TODO: need to assert they are same length
1529
1529
IndexStmt transformed = *this ;
1530
1530
string reason;
1531
1531
1532
- // FIXME: need to re-enable this later
1533
1532
taco_uassert (i_vars.size () == iw_vars.size ()) << " The precompute transformation requires"
1534
1533
<< " i_vars and iw_vars to be the same size" ;
1535
1534
for (int l = 0 ; l < (int ) i_vars.size (); l++) {
@@ -2343,18 +2342,18 @@ bool isReductionNotation(IndexStmt stmt, std::string* reason) {
2343
2342
// Reduction notation until proved otherwise
2344
2343
bool isReduction = true ;
2345
2344
2346
- util::ScopedMap <IndexVar, int > boundVars; // (int) value not used
2345
+ util::ScopedSet <IndexVar> boundVars;
2347
2346
vector<IndexVar> boundVarsList;
2348
2347
for (auto & var : to<Assignment>(stmt).getFreeVars ()) {
2349
- boundVars.insert ({var, 0 });
2348
+ boundVars.insert ({var});
2350
2349
boundVarsList.push_back (var);
2351
2350
}
2352
2351
2353
2352
match (stmt,
2354
2353
std::function<void (const ReductionNode*,Matcher*)>([&](
2355
2354
const ReductionNode* op, Matcher* ctx) {
2356
2355
boundVars.scope ();
2357
- boundVars.insert ({op->var , 0 });
2356
+ boundVars.insert ({op->var });
2358
2357
ctx->match (op->a );
2359
2358
boundVars.unscope ();
2360
2359
}),
@@ -2386,18 +2385,18 @@ bool isReductionNotationScheduled(IndexStmt stmt, ProvenanceGraph provGraph, std
2386
2385
// Reduction notation until proved otherwise
2387
2386
bool isReduction = true ;
2388
2387
2389
- util::ScopedMap <IndexVar, int > boundVars; // (int) value not used
2388
+ util::ScopedSet <IndexVar> boundVars;
2390
2389
vector<IndexVar> boundVarsList;
2391
2390
for (auto & var : to<Assignment>(stmt).getFreeVars ()) {
2392
- boundVars.insert ({var, 0 });
2391
+ boundVars.insert ({var});
2393
2392
boundVarsList.push_back (var);
2394
2393
}
2395
2394
2396
2395
match (stmt,
2397
2396
std::function<void (const ReductionNode*,Matcher*)>([&](
2398
2397
const ReductionNode* op, Matcher* ctx) {
2399
2398
boundVars.scope ();
2400
- boundVars.insert ({op->var , 0 });
2399
+ boundVars.insert ({op->var });
2401
2400
ctx->match (op->a );
2402
2401
boundVars.unscope ();
2403
2402
}),
@@ -2441,7 +2440,7 @@ bool isConcreteNotation(IndexStmt stmt, std::string* reason) {
2441
2440
2442
2441
bool inWhereProducer = false ;
2443
2442
bool inWhereConsumer = false ;
2444
- util::ScopedMap <IndexVar, int > boundVars; // (int) value not used
2443
+ util::ScopedSet <IndexVar> boundVars;
2445
2444
std::set<IndexVar> definedVars; // used to check if all variables recoverable TODO: need to actually use scope like above
2446
2445
2447
2446
ProvenanceGraph provGraph = ProvenanceGraph (stmt);
@@ -2450,7 +2449,7 @@ bool isConcreteNotation(IndexStmt stmt, std::string* reason) {
2450
2449
std::function<void (const ForallNode*,Matcher*)>([&](const ForallNode* op,
2451
2450
Matcher* ctx) {
2452
2451
boundVars.scope ();
2453
- boundVars.insert ({op->indexVar , 0 });
2452
+ boundVars.insert ({op->indexVar });
2454
2453
definedVars.insert (op->indexVar );
2455
2454
ctx->match (op->stmt );
2456
2455
boundVars.unscope ();
@@ -2606,6 +2605,47 @@ IndexStmt makeReductionNotation(IndexStmt stmt) {
2606
2605
return makeReductionNotation (to<Assignment>(stmt));
2607
2606
}
2608
2607
2608
+ // Replace other reductions with where and forall statements
2609
+ struct ReplaceReductionsWithWheres : IndexNotationRewriter {
2610
+ using IndexNotationRewriter::visit;
2611
+
2612
+ Reduction reduction;
2613
+ TensorVar t;
2614
+
2615
+ void visit (const AssignmentNode* node) {
2616
+ reduction = Reduction ();
2617
+ t = TensorVar ();
2618
+
2619
+ IndexExpr rhs = rewrite (node->rhs );
2620
+
2621
+ // nothing was rewritten
2622
+ if (rhs == node->rhs ) {
2623
+ stmt = node;
2624
+ return ;
2625
+ }
2626
+
2627
+ taco_iassert (t.defined () && reduction.defined ());
2628
+ IndexStmt consumer = Assignment (node->lhs , rhs, node->op );
2629
+ IndexStmt producer = forall (reduction.getVar (),
2630
+ Assignment (t, reduction.getExpr (),
2631
+ reduction.getOp ()));
2632
+ stmt = where (rewrite (consumer), rewrite (producer));
2633
+ }
2634
+
2635
+ void visit (const ReductionNode* node) {
2636
+ // only rewrite one reduction at a time
2637
+ if (reduction.defined ()) {
2638
+ expr = node;
2639
+ return ;
2640
+ }
2641
+
2642
+ reduction = node;
2643
+ t = TensorVar (" t" + util::toString (node->var ),
2644
+ node->getDataType ());
2645
+ expr = t;
2646
+ }
2647
+ };
2648
+
2609
2649
IndexStmt makeConcreteNotation (IndexStmt stmt) {
2610
2650
std::string reason;
2611
2651
taco_iassert (isReductionNotation (stmt, &reason))
@@ -2646,46 +2686,6 @@ IndexStmt makeConcreteNotation(IndexStmt stmt) {
2646
2686
stmt = forall (i, stmt);
2647
2687
}
2648
2688
2649
- // Replace other reductions with where and forall statements
2650
- struct ReplaceReductionsWithWheres : IndexNotationRewriter {
2651
- using IndexNotationRewriter::visit;
2652
-
2653
- Reduction reduction;
2654
- TensorVar t;
2655
-
2656
- void visit (const AssignmentNode* node) {
2657
- reduction = Reduction ();
2658
- t = TensorVar ();
2659
-
2660
- IndexExpr rhs = rewrite (node->rhs );
2661
-
2662
- // nothing was rewritten
2663
- if (rhs == node->rhs ) {
2664
- stmt = node;
2665
- return ;
2666
- }
2667
-
2668
- taco_iassert (t.defined () && reduction.defined ());
2669
- IndexStmt consumer = Assignment (node->lhs , rhs, node->op );
2670
- IndexStmt producer = forall (reduction.getVar (),
2671
- Assignment (t, reduction.getExpr (),
2672
- reduction.getOp ()));
2673
- stmt = where (rewrite (consumer), rewrite (producer));
2674
- }
2675
-
2676
- void visit (const ReductionNode* node) {
2677
- // only rewrite one reduction at a time
2678
- if (reduction.defined ()) {
2679
- expr = node;
2680
- return ;
2681
- }
2682
-
2683
- reduction = node;
2684
- t = TensorVar (" t" + util::toString (node->var ),
2685
- node->getDataType ());
2686
- expr = t;
2687
- }
2688
- };
2689
2689
stmt = ReplaceReductionsWithWheres ().rewrite (stmt);
2690
2690
return stmt;
2691
2691
}
@@ -2882,46 +2882,6 @@ IndexStmt makeConcreteNotationScheduled(IndexStmt stmt, ProvenanceGraph provGrap
2882
2882
}
2883
2883
}
2884
2884
2885
- // Replace other reductions with where and forall statements
2886
- struct ReplaceReductionsWithWheres : IndexNotationRewriter {
2887
- using IndexNotationRewriter::visit;
2888
-
2889
- Reduction reduction;
2890
- TensorVar t;
2891
-
2892
- void visit (const AssignmentNode* node) {
2893
- reduction = Reduction ();
2894
- t = TensorVar ();
2895
-
2896
- IndexExpr rhs = rewrite (node->rhs );
2897
-
2898
- // nothing was rewritten
2899
- if (rhs == node->rhs ) {
2900
- stmt = node;
2901
- return ;
2902
- }
2903
-
2904
- taco_iassert (t.defined () && reduction.defined ());
2905
- IndexStmt consumer = Assignment (node->lhs , rhs, node->op );
2906
- IndexStmt producer = forall (reduction.getVar (),
2907
- Assignment (t, reduction.getExpr (),
2908
- reduction.getOp ()));
2909
- stmt = where (rewrite (consumer), rewrite (producer));
2910
- }
2911
-
2912
- void visit (const ReductionNode* node) {
2913
- // only rewrite one reduction at a time
2914
- if (reduction.defined ()) {
2915
- expr = node;
2916
- return ;
2917
- }
2918
-
2919
- reduction = node;
2920
- t = TensorVar (" t" + util::toString (node->var ),
2921
- node->getDataType ());
2922
- expr = t;
2923
- }
2924
- };
2925
2885
stmt = ReplaceReductionsWithWheres ().rewrite (stmt);
2926
2886
return stmt;
2927
2887
}
0 commit comments