diff --git a/sycl/plugins/opencl/pi_opencl.cpp b/sycl/plugins/opencl/pi_opencl.cpp index 61111af7f40c9..cbb906516b004 100644 --- a/sycl/plugins/opencl/pi_opencl.cpp +++ b/sycl/plugins/opencl/pi_opencl.cpp @@ -1313,6 +1313,32 @@ pi_result piKernelGetSubGroupInfo(pi_kernel kernel, pi_device device, cast(param_name), input_value_size, input_value, sizeof(size_t), &ret_val, param_value_size_ret)); + if (ret_err == CL_INVALID_OPERATION) { + // clGetKernelSubGroupInfo returns CL_INVALID_OPERATION if the device does + // not support subgroups. + + if (param_name == PI_KERNEL_MAX_NUM_SUB_GROUPS) { + ret_val = 1; // Minimum required by SYCL 2020 spec + ret_err = CL_SUCCESS; + } else if (param_name == PI_KERNEL_COMPILE_NUM_SUB_GROUPS) { + ret_val = 0; // Not specified by kernel + ret_err = CL_SUCCESS; + } else if (param_name == PI_KERNEL_MAX_SUB_GROUP_SIZE) { + // Return the maximum work group size for the kernel + size_t kernel_work_group_size = 0; + pi_result pi_ret_err = piKernelGetGroupInfo( + kernel, device, PI_KERNEL_GROUP_INFO_WORK_GROUP_SIZE, sizeof(size_t), + &kernel_work_group_size, nullptr); + if (pi_ret_err != PI_SUCCESS) + return pi_ret_err; + ret_val = kernel_work_group_size; + ret_err = CL_SUCCESS; + } else if (param_name == PI_KERNEL_COMPILE_SUB_GROUP_SIZE_INTEL) { + ret_val = 0; // Not specified by kernel + ret_err = CL_SUCCESS; + } + } + if (ret_err != CL_SUCCESS) return cast(ret_err); diff --git a/sycl/test-e2e/Basic/kernel_info.cpp b/sycl/test-e2e/Basic/kernel_info.cpp index 0703ec2d97213..b41aa2e97ec93 100644 --- a/sycl/test-e2e/Basic/kernel_info.cpp +++ b/sycl/test-e2e/Basic/kernel_info.cpp @@ -54,6 +54,18 @@ int main() { const size_t prefWGSizeMult = krn.get_info< info::kernel_device_specific::preferred_work_group_size_multiple>(dev); assert(prefWGSizeMult > 0); + const cl_uint maxSgSize = + krn.get_info(dev); + assert(0 < maxSgSize && maxSgSize <= wgSize); + const cl_uint compileSgSize = + krn.get_info(dev); + assert(compileSgSize <= maxSgSize); + const cl_uint maxNumSg = + krn.get_info(dev); + assert(0 < maxNumSg); + const cl_uint compileNumSg = + krn.get_info(dev); + assert(compileNumSg <= maxNumSg); try { krn.get_info(dev);