Skip to content

Commit 0f46bd5

Browse files
authored
Add operator[] to vec class (#402)
Change-Id: I452cd5d1b258ccbc354a2f424d062f0fa91100ea
1 parent b5e9bc5 commit 0f46bd5

File tree

1 file changed

+35
-14
lines changed

1 file changed

+35
-14
lines changed

hipamd/include/hip/amd_detail/amd_hip_vector_types.h

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -51,31 +51,45 @@ THE SOFTWARE.
5151
#include <type_traits>
5252
#endif // defined(__HIPCC_RTC__)
5353

54-
namespace hip_impl {
55-
inline
56-
constexpr
57-
unsigned int next_pot(unsigned int x) {
58-
// Precondition: x > 1.
59-
return 1u << (32u - __builtin_clz(x - 1u));
60-
}
61-
} // Namespace hip_impl.
54+
template <typename T, unsigned int n> struct HIP_vector_base;
55+
template <typename T, unsigned int rank> struct HIP_vector_type;
56+
57+
namespace hip_impl {
58+
inline constexpr unsigned int next_pot(unsigned int x) {
59+
// Precondition: x > 1.
60+
return 1u << (32u - __builtin_clz(x - 1u));
61+
}
62+
63+
template <typename T, unsigned int n>
64+
__attribute__((always_inline)) __HOST_DEVICE__ typename HIP_vector_base<T, n>::Native_vec_*
65+
get_native_pointer(HIP_vector_base<T, n>& base_vec) {
66+
static_assert(sizeof(base_vec) == sizeof(typename HIP_vector_base<T, n>::Native_vec_));
67+
static_assert(std::alignment_of<HIP_vector_base<T, n>>::value ==
68+
std::alignment_of<typename HIP_vector_base<T, n>::Native_vec_>::value);
69+
return reinterpret_cast<typename HIP_vector_base<T, n>::Native_vec_*>(&base_vec.x);
70+
};
6271

63-
template<typename T, unsigned int n> struct HIP_vector_base;
64-
template <typename T, unsigned int rank> struct HIP_vector_type;
72+
template <typename T, unsigned int n>
73+
__attribute__((always_inline)) __HOST_DEVICE__ const typename HIP_vector_base<T, n>::Native_vec_*
74+
get_native_pointer(const HIP_vector_base<T, n>& base_vec) {
75+
static_assert(sizeof(base_vec) == sizeof(typename HIP_vector_base<T, n>::Native_vec_));
76+
static_assert(std::alignment_of<HIP_vector_base<T, n>>::value ==
77+
std::alignment_of<typename HIP_vector_base<T, n>::Native_vec_>::value);
78+
return reinterpret_cast<const typename HIP_vector_base<T, n>::Native_vec_*>(&base_vec.x);
79+
};
80+
} // Namespace hip_impl.
6581

6682
template <typename T, unsigned int n>
6783
__attribute__((always_inline)) __HOST_DEVICE__ typename HIP_vector_base<T, n>::Native_vec_&
6884
get_native_vector(HIP_vector_base<T, n>& base_vec) {
69-
static_assert(sizeof(base_vec) == sizeof(typename HIP_vector_base<T, n>::Native_vec_));
70-
return *reinterpret_cast<typename HIP_vector_base<T, n>::Native_vec_*>(&base_vec.x);
85+
return *hip_impl::get_native_pointer(base_vec);
7186
};
7287

7388
template <typename T, unsigned int n>
7489
__attribute__((
7590
always_inline)) __HOST_DEVICE__ const typename HIP_vector_base<T, n>::Native_vec_&
7691
get_native_vector(const HIP_vector_base<T, n>& base_vec) {
77-
static_assert(sizeof(base_vec) == sizeof(typename HIP_vector_base<T, n>::Native_vec_));
78-
return *reinterpret_cast<const typename HIP_vector_base<T, n>::Native_vec_*>(&base_vec.x);
92+
return *hip_impl::get_native_pointer(base_vec);
7993
};
8094

8195
template<typename T>
@@ -349,6 +363,13 @@ THE SOFTWARE.
349363
HIP_vector_type& operator=(HIP_vector_type&&) = default;
350364

351365
// Operators
366+
__HOST_DEVICE__
367+
T& operator[](size_t idx) noexcept { return (*hip_impl::get_native_pointer(*this))[idx]; }
368+
__HOST_DEVICE__
369+
const T& operator[](size_t idx) const noexcept {
370+
return (*hip_impl::get_native_pointer(*this))[idx];
371+
}
372+
352373
__HOST_DEVICE__
353374
HIP_vector_type& operator++() noexcept {
354375
HIP_vector_type unity = make_vector_type<T, rank>(1);

0 commit comments

Comments
 (0)