Skip to content

Commit c8a6254

Browse files
committed
ROCM-1254 - fix that functors defined in the cooperative_groups namespace were not usable with user defined types, because they will try to call reduce_sync operations which are only defined with standard types. Add Unit_Thread_Block_Tile_Reduce_Standard_Op_Custom_Type test
1 parent 99b4ac0 commit c8a6254

File tree

3 files changed

+116
-6
lines changed

3 files changed

+116
-6
lines changed

projects/clr/hipamd/include/hip/amd_detail/amd_hip_cooperative_groups_reduce.h

Lines changed: 66 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,60 @@ namespace impl {
2222
using is_param_type_same = __hip_internal::is_same<typename __hip_internal::remove_cvref<T>,
2323
typename __hip_internal::remove_cvref<U>>;
2424

25+
template <typename T, typename = void>
26+
struct has_add : std::false_type {
27+
};
28+
29+
template <typename T>
30+
struct has_add<T,
31+
std::void_t<decltype(__reduce_add_sync<unsigned long long>(0ull, T {}))>
32+
> : std::true_type {};
33+
34+
template <typename T, typename = void>
35+
struct has_min : std::false_type {
36+
};
37+
38+
template <typename T>
39+
struct has_min<T,
40+
std::void_t<decltype(__reduce_min_sync<unsigned long long>(0ull, T {}))>
41+
> : std::true_type {};
42+
43+
template <typename T, typename = void>
44+
struct has_max : std::false_type {
45+
};
46+
47+
template <typename T>
48+
struct has_max<T,
49+
std::void_t<decltype(__reduce_max_sync<unsigned long long>(0ull, T {}))>
50+
> : std::true_type {};
51+
52+
template <typename T, typename = void>
53+
struct has_and : std::false_type {
54+
};
55+
56+
template <typename T>
57+
struct has_and<T,
58+
std::void_t<decltype(__reduce_and_sync<unsigned long long>(0ull, T {}))>
59+
> : std::true_type {};
60+
61+
template <typename T, typename = void>
62+
struct has_or : std::false_type {
63+
};
64+
65+
template <typename T>
66+
struct has_or<T,
67+
std::void_t<decltype(__reduce_or_sync<unsigned long long>(0ull, T {}))>
68+
> : std::true_type {};
69+
70+
template <typename T, typename = void>
71+
struct has_xor : std::false_type {
72+
};
73+
74+
template <typename T>
75+
struct has_xor<T,
76+
std::void_t<decltype(__reduce_xor_sync<unsigned long long>(0ull, T {}))>
77+
> : std::true_type {};
78+
2579
// we can call reduce() only the block tiles that have a compile-time size
2680
template <class TyGroup>
2781
struct isTiledGroup : __hip_internal::false_type {
@@ -69,17 +123,23 @@ __CG_QUALIFIER__ auto reduce(const TyGroup& group, TyVal&& val, TyFn&& op) -> de
69123
// need to apply the active mask
70124
mask &= __activemask();
71125

72-
if constexpr (__hip_internal::is_same<Op, cooperative_groups::plus<Val>>::value) {
126+
if constexpr (__hip_internal::is_same<Op, cooperative_groups::plus<Val>>::value &&
127+
impl::has_add<Val>::value) {
73128
return __reduce_add_sync(mask, val);
74-
} else if constexpr (__hip_internal::is_same<Op, cooperative_groups::less<Val>>::value) {
129+
} else if constexpr (__hip_internal::is_same<Op, cooperative_groups::less<Val>>::value &&
130+
impl::has_min<Val>::value) {
75131
return __reduce_min_sync(mask, val);
76-
} else if constexpr (__hip_internal::is_same<Op, cooperative_groups::greater<Val>>::value) {
132+
} else if constexpr (__hip_internal::is_same<Op, cooperative_groups::greater<Val>>::value &&
133+
impl::has_max<Val>::value) {
77134
return __reduce_max_sync(mask, val);
78-
} else if constexpr (__hip_internal::is_same<Op, cooperative_groups::bit_and<Val>>::value) {
135+
} else if constexpr (__hip_internal::is_same<Op, cooperative_groups::bit_and<Val>>::value &&
136+
impl::has_and<Val>::value) {
79137
return __reduce_and_sync(mask, val);
80-
} else if constexpr (__hip_internal::is_same<Op, cooperative_groups::bit_or<Val>>::value) {
138+
} else if constexpr (__hip_internal::is_same<Op, cooperative_groups::bit_or<Val>>::value &&
139+
impl::has_or<Val>::value) {
81140
return __reduce_or_sync(mask, val);
82-
} else if constexpr (__hip_internal::is_same<Op, cooperative_groups::bit_xor<Val>>::value) {
141+
} else if constexpr (__hip_internal::is_same<Op, cooperative_groups::bit_xor<Val>>::value &&
142+
impl::has_xor<Val>::value) {
83143
return __reduce_xor_sync(mask, val);
84144
} else {
85145
return __reduce_op_sync(mask, val, op, nullptr);

projects/hip-tests/catch/config/configs/unit/cooperativeGrps.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,8 @@ cooperativeGrps:
242242
<<: *level_2
243243
# Rock_Window_Failures_on_gfx1151
244244
disabled: [amd_windows]
245+
Unit_Thread_Block_Tile_Reduce_Standard_Op_Custom_Type:
246+
<<: *level_2
245247
Unit_Thread_Block_Tile_Reduce_Trivially_Copyable_Parameters:
246248
<<: *level_2
247249
# Rock_Window_Failures_on_gfx1151

projects/hip-tests/catch/unit/cooperativeGrps/thread_block_tile.cc

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,6 +1000,54 @@ TEST_CASE(Unit_Thread_Block_Tile_Reduce_All_Parameter_Sizes)
10001000
}
10011001
}
10021002

1003+
struct Point {
1004+
int x;
1005+
int y;
1006+
1007+
__device__ Point operator+(const Point& rhs)
1008+
{
1009+
return { x + rhs.x, y + rhs.y };
1010+
}
1011+
1012+
};
1013+
1014+
__global__ void sumPoints(Point* result)
1015+
{
1016+
cg::thread_block mygroup = cg::this_thread_block();
1017+
auto mytile = cg::tiled_partition<32>(mygroup);
1018+
Point input;
1019+
1020+
input.x = threadIdx.x;
1021+
input.y = threadIdx.x;
1022+
1023+
__syncwarp();
1024+
*result = cg::reduce(mytile, input, cooperative_groups::plus<Point> {});
1025+
}
1026+
1027+
// using a standard functor in the cooperative_groups namespace with a type that is not primitive
1028+
TEST_CASE(Unit_Thread_Block_Tile_Reduce_Standard_Op_Custom_Type)
1029+
{
1030+
LinearAllocGuard<Point> h_result(LinearAllocs::malloc, sizeof(Point) * 32);
1031+
LinearAllocGuard<Point> d_result(LinearAllocs::hipMalloc, sizeof(Point) * 32);
1032+
dim3 gridDim = { 1 };
1033+
dim3 blockDim = { 32 };
1034+
void* devicePtr = d_result.ptr();
1035+
void* args[] = { &devicePtr };
1036+
int expected = 31 * 16;
1037+
1038+
HIP_CHECK(hipLaunchCooperativeKernel(reinterpret_cast<void*>(sumPoints), gridDim, blockDim, args, 0, nullptr));
1039+
HIP_CHECK(hipDeviceSynchronize());
1040+
HIP_CHECK(hipGetLastError());
1041+
HIP_CHECK(hipMemcpy(h_result.host_ptr(), d_result.ptr(),
1042+
h_result.size_bytes(), hipMemcpyDeviceToHost));
1043+
1044+
for (int i = 0; i < 32; i++) {
1045+
INFO("Expected x: " << expected << " got: " << *h_result.host_ptr());
1046+
INFO("Expected y: " << expected << " got: " << *h_result.host_ptr());
1047+
REQUIRE((h_result.host_ptr()->x == expected && h_result.host_ptr()->y == expected));
1048+
}
1049+
}
1050+
10031051
TEMPLATE_TEST_CASE(Unit_Thread_Block_Coalesced_Reduce_arithmetic, int, unsigned int, long long,
10041052
unsigned long long, float, half, double)
10051053
{

0 commit comments

Comments
 (0)