Skip to content

Commit 1404ff1

Browse files
[SYCL][Reduction] Fix identityless span reductions in variadic (#8655)
Using a span reduction without an identity in a parallel_for with multiple reductions currently fails due to a missed use of the identity. This commit addresses that missed case. Signed-off-by: Larsen, Steffen <[email protected]>
1 parent bd64197 commit 1404ff1

File tree

2 files changed

+34
-5
lines changed

2 files changed

+34
-5
lines changed

sycl/include/sycl/reduction.hpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2151,11 +2151,15 @@ void reduAuxCGFuncImplArrayHelper(bool IsOneWG, nd_item<Dims> NDIt, size_t LID,
21512151
// Add the initial value of user's variable to the final result.
21522152
if (LID == 0) {
21532153
size_t GrID = NDIt.get_group_linear_id();
2154-
Out[GrID * NElements + E] =
2155-
IsOneWG ? BOp(LocalReds[0], IsInitializeToIdentity
2156-
? IdentityContainer.getIdentity()
2157-
: Out[E])
2158-
: LocalReds[0];
2154+
if constexpr (Reduction::has_identity) {
2155+
Out[GrID * NElements + E] =
2156+
IsOneWG ? BOp(LocalReds[0], IsInitializeToIdentity
2157+
? IdentityContainer.getIdentity()
2158+
: Out[E])
2159+
: LocalReds[0];
2160+
} else {
2161+
Out[GrID * NElements + E] = LocalReds[0];
2162+
}
21592163
}
21602164

21612165
// Ensure item 0 is finished with LocalReds before next iteration
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: %clangxx -fsycl -fsyntax-only %s
2+
3+
// Tests that identityless reductions compile when applied to a span.
4+
5+
#include <sycl/sycl.hpp>
6+
7+
// User-defined binary functor that adds its two operands. Going by its name
// and this test's header comment, it serves as a reduction operator for which
// no identity value is supplied.
template <class T> struct PlusWithoutIdentity {
8+
// Returns the sum A + B; neither operand is modified.
T operator()(const T &A, const T &B) const { return A + B; }
9+
};
10+
11+
// Builds one scalar reduction and one span reduction, both using
// PlusWithoutIdentity<int> and neither given an identity value, then passes
// the pair to parallel_for over a range and over an nd_range. The RUN line of
// this test uses -fsyntax-only, so the code is only required to compile.
int main() {
12+
sycl::queue Q;
13+
14+
// One int backs the scalar reduction, eight ints back the span reduction.
// NOTE(review): these allocations are never released in this function;
// presumably acceptable because the test is syntax-only -- confirm.
int *ScalarMem = sycl::malloc_shared<int>(1, Q);
15+
int *SpanMem = sycl::malloc_shared<int>(8, Q);
16+
// Only the target and the binary operation are passed: no identity argument.
auto ScalarRed = sycl::reduction(ScalarMem, PlusWithoutIdentity<int>{});
17+
auto SpanRed = sycl::reduction(sycl::span<int, 8>{SpanMem, 8},
18+
PlusWithoutIdentity<int>{});
19+
// Range form with two reductions; the kernel body is intentionally empty.
Q.parallel_for(sycl::range<1>{1024}, ScalarRed, SpanRed,
20+
[=](sycl::item<1>, auto &, auto &) {});
21+
// nd_range form with the same two reductions (global 1024, local 1024).
Q.parallel_for(sycl::nd_range<1>{1024, 1024}, ScalarRed, SpanRed,
22+
[=](sycl::nd_item<1>, auto &, auto &) {});
23+
24+
return 0;
25+
}

0 commit comments

Comments
 (0)