Skip to content

Commit d460e29

Browse files
committed
Add in another test for tiled_dotProduct
1 parent 3d879eb commit d460e29

File tree

1 file changed

+72
-9
lines changed

1 file changed

+72
-9
lines changed

test/tests-workspaces.cpp

Lines changed: 72 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -437,10 +437,16 @@ TEST(workspaces, precompute3D_renamedIVars_TspV) {
437437

438438
TEST(workspaces, DISABLED_tile_dotProduct_1) {
439439
// FIXME: Disabled because currently the precompute algorithm does not appropriately
440-
// optimize = from += when rewriting a statement for BOTH the producer and consumer
441-
// side of a where statement insertion.
442-
// Although always using += is CORRECT functionally, this fails the GPU tests since it
443-
// would result in scattering.
440+
// find the correct forall substmt to nest the WhereNode in after i has been
441+
// split into i0 and i1. As an example, the first precompute below is incorrect
442+
// since it should transform
443+
// forall(i0, forall(i1, A() += B(i) * C(i))) -->
444+
// forall(i0, where(forall(i1, A() += ws(i1)), forall(i1, ws(i1) += B(i) * C(i))))
445+
//
446+
// But currently the algorithm does
447+
// forall(i0, forall(i1, A() += B(i) * C(i))) -->
448+
// where(forall(i1, A() += ws(i1)), forall(i0, forall(i1, ws(i1) += B(i) * C(i))))
449+
444450
int N = 1024;
445451
Tensor<double> A("A");
446452
Tensor<double> B("B", {N}, Format({Dense}));
@@ -470,17 +476,26 @@ TEST(workspaces, DISABLED_tile_dotProduct_1) {
470476
stmt = stmt.bound(i, i_bounded, (size_t)N, BoundType::MaxExact)
471477
.split(i_bounded, i0, i1, 32);
472478
stmt = stmt.precompute(precomputedExpr, i1, i1, precomputed);
473-
474-
stmt = stmt.precompute(BExpr, i1, i1, B_new)
475-
.precompute(CExpr, i1, i1, C_new);
476-
477-
479+
stmt = stmt.precompute(BExpr, i1, i1, B_new)
480+
.precompute(CExpr, i1, i1, C_new);
481+
478482
stmt = stmt.concretize();
479483

480484
A.compile(stmt);
481485
A.assemble();
482486
A.compute();
483487

488+
ir::IRPrinter irp = ir::IRPrinter(cout);
489+
490+
cout << stmt << endl;
491+
492+
std::shared_ptr<ir::CodeGen> codegen = ir::CodeGen::init_default(cout, ir::CodeGen::ImplementationGen);
493+
ir::Stmt compute = lower(stmt, "compute", false, true);
494+
495+
irp.print(compute);
496+
cout << endl;
497+
codegen->compile(compute, false);
498+
484499
Tensor<double> expected("expected");
485500
expected() = B(i) * C(i);
486501
expected.compile();
@@ -543,3 +558,51 @@ TEST(workspaces, DISABLED_tile_dotProduct_2) {
543558
ASSERT_TENSOR_EQ(expected, A);
544559
}
545560

561+
// Tiled dot product, variant 3: A() = sum_i B(i) * C(i), scheduled with the
// product precomputed at i0 and the B / C operands precomputed at i1, then
// checked against an unscheduled evaluation of the same expression.
TEST(workspaces, tile_dotProduct_3) {
  int N = 1024;
  // Extent reused for the workspace types and the bound() call below.
  const size_t extent = static_cast<size_t>(N);

  Tensor<double> A("A");
  Tensor<double> B("B", {N}, Format({Dense}));
  Tensor<double> C("C", {N}, Format({Dense}));

  // Both operands hold the value idx at coordinate idx.
  for (int idx = 0; idx < N; ++idx) {
    const double value = static_cast<double>(idx);
    B.insert({idx}, value);
    C.insert({idx}, value);
  }
  B.pack();
  C.pack();

  IndexVar i("i");
  IndexVar i_bounded("i_bounded");
  IndexVar i0("i0");
  IndexVar i1("i1");

  IndexExpr BExpr = B(i);
  IndexExpr CExpr = C(i);
  IndexExpr precomputedExpr = BExpr * CExpr;
  A() = precomputedExpr;

  IndexStmt stmt = A.getAssignment().concretize();

  // Workspaces handed to the precompute() calls.
  TensorVar B_new("B_new", Type(Float64, {extent}), taco::dense);
  TensorVar C_new("C_new", Type(Float64, {extent}), taco::dense);
  TensorVar precomputed("precomputed", Type(Float64, {extent}), taco::dense);

  // Scheduling, one transformation per statement, in the original order.
  stmt = stmt.bound(i, i_bounded, extent, BoundType::MaxExact);
  stmt = stmt.split(i_bounded, i0, i1, 32);
  stmt = stmt.precompute(precomputedExpr, i0, i0, precomputed);
  stmt = stmt.precompute(BExpr, i1, i1, B_new);
  stmt = stmt.precompute(CExpr, i1, i1, C_new);
  stmt = stmt.concretize();

  A.compile(stmt);
  A.assemble();
  A.compute();

  // Reference result: the same expression with the default schedule.
  Tensor<double> expected("expected");
  expected() = B(i) * C(i);
  expected.compile();
  expected.assemble();
  expected.compute();

  ASSERT_TENSOR_EQ(expected, A);
}

0 commit comments

Comments
 (0)