Skip to content

Commit 6ba7b52

Browse files
[SYCL] Fix SYCL kernel lambda argument type detection (#11679)
We have a helper that extracts the type of the first SYCL kernel lambda argument so that we can do some error checking and special handling based on that type. That helper, however, missed the case where the kernel lambda also accepts a `kernel_handler` argument, and always fell back to the suggested type in that case. This led to situations where we couldn't compile code like: ```c++ sycl::queue q; q.parallel_for(sycl::range{1}, [=](sycl::item<1, false>, kernel_handler) {}); ``` This patch adds extra overloads of some internal helpers to fix the error. This is a follow-up to #11625
1 parent 4156f78 commit 6ba7b52

File tree

2 files changed

+25
-7
lines changed

2 files changed

+25
-7
lines changed

sycl/include/sycl/handler.hpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -189,11 +189,15 @@ static Arg member_ptr_helper(RetType (Func::*)(Arg) const);
189189
template <typename RetType, typename Func, typename Arg>
190190
static Arg member_ptr_helper(RetType (Func::*)(Arg));
191191

192-
// template <typename RetType, typename Func>
193-
// static void member_ptr_helper(RetType (Func::*)() const);
192+
// Version with two arguments to handle the case when kernel_handler is passed
193+
// to a lambda
194+
template <typename RetType, typename Func, typename Arg1, typename Arg2>
195+
static Arg1 member_ptr_helper(RetType (Func::*)(Arg1, Arg2) const);
194196

195-
// template <typename RetType, typename Func>
196-
// static void member_ptr_helper(RetType (Func::*)());
197+
// Non-const version of the above template to match functors whose 'operator()'
198+
// is declared without the 'const' qualifier.
199+
template <typename RetType, typename Func, typename Arg1, typename Arg2>
200+
static Arg1 member_ptr_helper(RetType (Func::*)(Arg1, Arg2));
197201

198202
template <typename F, typename SuggestedArgType>
199203
decltype(member_ptr_helper(&F::operator())) argument_helper(int);
@@ -1280,7 +1284,7 @@ class __SYCL_EXPORT handler {
12801284
using KName = std::conditional_t<std::is_same<KernelType, NameT>::value,
12811285
decltype(Wrapper), NameWT>;
12821286

1283-
kernel_parallel_for_wrapper<KName, item<Dims>, decltype(Wrapper),
1287+
kernel_parallel_for_wrapper<KName, TransformedArgType, decltype(Wrapper),
12841288
PropertiesT>(Wrapper);
12851289
#ifndef __SYCL_DEVICE_ONLY__
12861290
// We are executing over the rounded range, but there are still
@@ -1290,7 +1294,7 @@ class __SYCL_EXPORT handler {
12901294
// of the user range, instead of the rounded range.
12911295
detail::checkValueRange<Dims>(UserRange);
12921296
MNDRDesc.set(*RoundedRange);
1293-
StoreLambda<KName, decltype(Wrapper), Dims, item<Dims>>(
1297+
StoreLambda<KName, decltype(Wrapper), Dims, TransformedArgType>(
12941298
std::move(Wrapper));
12951299
setType(detail::CG::Kernel);
12961300
#endif

sycl/test/basic_tests/handler/parallel_for_args.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,14 @@ int main() {
3636
q.parallel_for(r2, [=](sycl::item<2> it) {});
3737
q.parallel_for(r3, [=](sycl::item<3> it) {});
3838

39+
q.parallel_for(r1, [=](sycl::item<1, false> it) {});
40+
q.parallel_for(r2, [=](sycl::item<2, false> it) {});
41+
q.parallel_for(r3, [=](sycl::item<3, false> it) {});
42+
43+
// int, size_t -> sycl::item
44+
q.parallel_for(r1, [=](int it) {});
45+
q.parallel_for(r1, [=](size_t it) {});
46+
3947
// sycl::item -> sycl::id
4048
q.parallel_for(r1, [=](sycl::id<1> it) {});
4149
q.parallel_for(r2, [=](sycl::id<2> it) {});
@@ -51,6 +59,13 @@ int main() {
5159
q.parallel_for(r2, [=](sycl::item<2> it, sycl::kernel_handler kh) {});
5260
q.parallel_for(r3, [=](sycl::item<3> it, sycl::kernel_handler kh) {});
5361

62+
q.parallel_for(r1, [=](int it, sycl::kernel_handler kh) {});
63+
q.parallel_for(r1, [=](size_t it, sycl::kernel_handler kh) {});
64+
65+
q.parallel_for(r1, [=](sycl::item<1, false> it, sycl::kernel_handler kh) {});
66+
q.parallel_for(r2, [=](sycl::item<2, false> it, sycl::kernel_handler kh) {});
67+
q.parallel_for(r3, [=](sycl::item<3, false> it, sycl::kernel_handler kh) {});
68+
5469
q.parallel_for(r1, [=](sycl::id<1> it, sycl::kernel_handler kh) {});
5570
q.parallel_for(r2, [=](sycl::id<2> it, sycl::kernel_handler kh) {});
5671
q.parallel_for(r3, [=](sycl::id<3> it, sycl::kernel_handler kh) {});
@@ -90,5 +105,4 @@ int main() {
90105
[=](ConvertibleFromNDItem<3> it, sycl::kernel_handler kh) {});
91106

92107
// TODO: consider adding test cases for hierarchical parallelism
93-
// TODO: consider adding cases for sycl::item with offset
94108
}

0 commit comments

Comments
 (0)