@@ -437,10 +437,16 @@ TEST(workspaces, precompute3D_renamedIVars_TspV) {

 TEST(workspaces, DISABLED_tile_dotProduct_1) {
   // FIXME: Disabled because currently the precompute algorithm does not appropriately
-  // optimize = from += when rewriting a statement for BOTH the producer and consumer
-  // side of a where statement insertion.
-  // Although always using += is CORRECT functionally, this fails the GPU tests since it
-  // would result in scattering.
+  // find the correct forall substmt to nest the WhereNode in after i has been
+  // split into i0 and i1. As an example, the first precompute below is incorrect
+  // since it should transform
+  //   forall(i0, forall(i1, A() += B(i) * C(i))) -->
+  //   forall(i0, where(forall(i1, A() += ws(i1)), forall(i1, ws(i1) += B(i) * C(i))))
+  //
+  // But currently the algorithm does
+  //   forall(i0, forall(i1, A() += B(i) * C(i))) -->
+  //   where(forall(i1, A() += ws(i1)), forall(i0, forall(i1, ws(i1) += B(i) * C(i))))
+
   int N = 1024;
   Tensor<double> A("A");
   Tensor<double> B("B", {N}, Format({Dense}));
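Note: for readers of the FIXME above, here is a minimal hand-written sketch (not TACO-generated code; the function tiledDot and the arrays B_vals/C_vals are illustrative names, and N is assumed to be a multiple of the split factor 32) of the loop nest that the intended where placement would correspond to for this tiled dot product:

// Sketch of the lowering the FIXME expects for the first precompute, i.e.
//   forall(i0, where(forall(i1, A() += ws(i1)), forall(i1, ws(i1) += B(i) * C(i))))
// B_vals and C_vals stand in for the dense value arrays of B and C.
#include <vector>

double tiledDot(const std::vector<double>& B_vals,
                const std::vector<double>& C_vals, int N) {
  double A_val = 0.0;
  for (int i0 = 0; i0 < N / 32; i0++) {
    double ws[32];                      // per-tile workspace
    for (int i1 = 0; i1 < 32; i1++) {   // producer: ws(i1) += B(i) * C(i)
      ws[i1] = B_vals[i0 * 32 + i1] * C_vals[i0 * 32 + i1];
    }
    for (int i1 = 0; i1 < 32; i1++) {   // consumer: A() += ws(i1)
      A_val += ws[i1];
    }
  }
  return A_val;
}

The key point of the FIXME is that the workspace loops should sit inside the i0 loop (as above), rather than hoisting the whole producer nest outside it.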
@@ -470,17 +476,26 @@ TEST(workspaces, DISABLED_tile_dotProduct_1) {
   stmt = stmt.bound(i, i_bounded, (size_t)N, BoundType::MaxExact)
              .split(i_bounded, i0, i1, 32);
   stmt = stmt.precompute(precomputedExpr, i1, i1, precomputed);
-
-  stmt = stmt.precompute(BExpr, i1, i1, B_new)
-             .precompute(CExpr, i1, i1, C_new);
-
-
+  stmt = stmt.precompute(BExpr, i1, i1, B_new)
+             .precompute(CExpr, i1, i1, C_new);
+
   stmt = stmt.concretize();

   A.compile(stmt);
   A.assemble();
   A.compute();

+  ir::IRPrinter irp = ir::IRPrinter(cout);
+
+  cout << stmt << endl;
+
+  std::shared_ptr<ir::CodeGen> codegen = ir::CodeGen::init_default(cout, ir::CodeGen::ImplementationGen);
+  ir::Stmt compute = lower(stmt, "compute", false, true);
+
+  irp.print(compute);
+  cout << endl;
+  codegen->compile(compute, false);
+
   Tensor<double> expected("expected");
   expected() = B(i) * C(i);
   expected.compile();
@@ -543,3 +558,51 @@ TEST(workspaces, DISABLED_tile_dotProduct_2) {
   ASSERT_TENSOR_EQ(expected, A);
 }

+TEST(workspaces, tile_dotProduct_3) {
+  int N = 1024;
+  Tensor<double> A("A");
+  Tensor<double> B("B", {N}, Format({Dense}));
+  Tensor<double> C("C", {N}, Format({Dense}));
+
+  for (int i = 0; i < N; i++) {
+    B.insert({i}, (double) i);
+    C.insert({i}, (double) i);
+  }
+
+  B.pack();
+  C.pack();
+
+  IndexVar i("i");
+  IndexVar i_bounded("i_bounded");
+  IndexVar i0("i0"), i1("i1");
+  IndexExpr BExpr = B(i);
+  IndexExpr CExpr = C(i);
+  IndexExpr precomputedExpr = (BExpr) * (CExpr);
+  A() = precomputedExpr;
+
+  IndexStmt stmt = A.getAssignment().concretize();
+  TensorVar B_new("B_new", Type(Float64, {(size_t)N}), taco::dense);
+  TensorVar C_new("C_new", Type(Float64, {(size_t)N}), taco::dense);
+  TensorVar precomputed("precomputed", Type(Float64, {(size_t)N}), taco::dense);
+
+  stmt = stmt.bound(i, i_bounded, (size_t)N, BoundType::MaxExact)
+             .split(i_bounded, i0, i1, 32);
+  stmt = stmt.precompute(precomputedExpr, i0, i0, precomputed);
+
+  stmt = stmt.precompute(BExpr, i1, i1, B_new)
+             .precompute(CExpr, i1, i1, C_new);
+
+
+  stmt = stmt.concretize();
+
+  A.compile(stmt);
+  A.assemble();
+  A.compute();
+
+  Tensor<double> expected("expected");
+  expected() = B(i) * C(i);
+  expected.compile();
+  expected.assemble();
+  expected.compute();
+  ASSERT_TENSOR_EQ(expected, A);
+}