@@ -22,6 +22,60 @@ namespace impl {
2222 using is_param_type_same = __hip_internal::is_same<typename __hip_internal::remove_cvref<T>,
2323 typename __hip_internal::remove_cvref<U>>;
2424
25+ template <typename T, typename = void >
26+ struct has_add : std::false_type {
27+ };
28+
29+ template <typename T>
30+ struct has_add <T,
31+ std::void_t <decltype (__reduce_add_sync<unsigned long long >(0ull , T {}))>
32+ > : std::true_type {};
33+
34+ template <typename T, typename = void >
35+ struct has_min : std::false_type {
36+ };
37+
38+ template <typename T>
39+ struct has_min <T,
40+ std::void_t <decltype (__reduce_min_sync<unsigned long long >(0ull , T {}))>
41+ > : std::true_type {};
42+
43+ template <typename T, typename = void >
44+ struct has_max : std::false_type {
45+ };
46+
47+ template <typename T>
48+ struct has_max <T,
49+ std::void_t <decltype (__reduce_max_sync<unsigned long long >(0ull , T {}))>
50+ > : std::true_type {};
51+
52+ template <typename T, typename = void >
53+ struct has_and : std::false_type {
54+ };
55+
56+ template <typename T>
57+ struct has_and <T,
58+ std::void_t <decltype (__reduce_and_sync<unsigned long long >(0ull , T {}))>
59+ > : std::true_type {};
60+
61+ template <typename T, typename = void >
62+ struct has_or : std::false_type {
63+ };
64+
65+ template <typename T>
66+ struct has_or <T,
67+ std::void_t <decltype (__reduce_or_sync<unsigned long long >(0ull , T {}))>
68+ > : std::true_type {};
69+
70+ template <typename T, typename = void >
71+ struct has_xor : std::false_type {
72+ };
73+
74+ template <typename T>
75+ struct has_xor <T,
76+ std::void_t <decltype (__reduce_xor_sync<unsigned long long >(0ull , T {}))>
77+ > : std::true_type {};
78+
// reduce() can only be called on block tiles that have a compile-time size
2680 template <class TyGroup >
2781 struct isTiledGroup : __hip_internal::false_type {
@@ -69,17 +123,23 @@ __CG_QUALIFIER__ auto reduce(const TyGroup& group, TyVal&& val, TyFn&& op) -> de
69123 // need to apply the active mask
70124 mask &= __activemask ();
71125
72- if constexpr (__hip_internal::is_same<Op, cooperative_groups::plus<Val>>::value) {
126+ if constexpr (__hip_internal::is_same<Op, cooperative_groups::plus<Val>>::value &&
127+ impl::has_add<Val>::value) {
73128 return __reduce_add_sync (mask, val);
74- } else if constexpr (__hip_internal::is_same<Op, cooperative_groups::less<Val>>::value) {
129+ } else if constexpr (__hip_internal::is_same<Op, cooperative_groups::less<Val>>::value &&
130+ impl::has_min<Val>::value) {
75131 return __reduce_min_sync (mask, val);
76- } else if constexpr (__hip_internal::is_same<Op, cooperative_groups::greater<Val>>::value) {
132+ } else if constexpr (__hip_internal::is_same<Op, cooperative_groups::greater<Val>>::value &&
133+ impl::has_max<Val>::value) {
77134 return __reduce_max_sync (mask, val);
78- } else if constexpr (__hip_internal::is_same<Op, cooperative_groups::bit_and<Val>>::value) {
135+ } else if constexpr (__hip_internal::is_same<Op, cooperative_groups::bit_and<Val>>::value &&
136+ impl::has_and<Val>::value) {
79137 return __reduce_and_sync (mask, val);
80- } else if constexpr (__hip_internal::is_same<Op, cooperative_groups::bit_or<Val>>::value) {
138+ } else if constexpr (__hip_internal::is_same<Op, cooperative_groups::bit_or<Val>>::value &&
139+ impl::has_or<Val>::value) {
81140 return __reduce_or_sync (mask, val);
82- } else if constexpr (__hip_internal::is_same<Op, cooperative_groups::bit_xor<Val>>::value) {
141+ } else if constexpr (__hip_internal::is_same<Op, cooperative_groups::bit_xor<Val>>::value &&
142+ impl::has_xor<Val>::value) {
83143 return __reduce_xor_sync (mask, val);
84144 } else {
85145 return __reduce_op_sync (mask, val, op, nullptr );
0 commit comments