SWDEV-546808 optimize match_any() with dpp wave_rot1 #183

Closed
wants to merge 8 commits into from
36 changes: 28 additions & 8 deletions hipamd/include/hip/amd_detail/amd_warp_sync_functions.h
@@ -224,17 +224,37 @@ unsigned long long __match_any(T value) {
       (sizeof(T) == 4 || sizeof(T) == 8),
       "T can be int, unsigned int, long, unsigned long, long long, unsigned "
       "long long, float or double.");
-  bool done = false;
-  unsigned long long retval = 0;

-  while (__any(!done)) {
-    if (!done) {
-      T chosen = __hip_readfirstlane(value);
-      if (chosen == value) {
-        retval = __activemask();
-        done = true;
+  unsigned long long actvmask = __activemask();
+  unsigned long long retval = 0;
+  if (actvmask != ~0ull) {
+    bool done = false;
+    while (__any(!done)) {
+      if (!done) {
+        T chosen = __hip_readfirstlane(value);
+        if (chosen == value) {
+          retval = __activemask();
+          done = true;
         }
       }
     }
+  } else {
+    union dill { unsigned int i[2]; unsigned long long ill; decltype(value) val; } dill_ = { .val = value };
Contributor

If we push this union declaration into the if constexpr, then we will only need one integer member with the correct width.

Author

@amd-hhashemi amd-hhashemi Aug 8, 2025

Doesn't match_any() need to handle int and long?

Contributor

Yeah, I was kinda assuming the compiler will optimize out the upper bits since the whole function is a template anyway. But this works too.
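
The single-integer idea from this thread can be sketched on the host side as follows. This is a minimal illustration in plain C++, not HIP device code and not what the PR does: the names `bits_t`, `to_bits`, and `from_bits` are hypothetical, and `std::memcpy` stands in for the union so that the sketch avoids reading an inactive union member.

```cpp
#include <cstdint>
#include <cstring>
#include <type_traits>

// Hypothetical helper: the unsigned integer type with the same width as T.
// Selecting it at compile time means only one integer member is ever needed.
template <typename T>
using bits_t = std::conditional_t<sizeof(T) == 8, std::uint64_t, std::uint32_t>;

// Copy the object representation of a 4- or 8-byte value into an integer.
template <typename T>
bits_t<T> to_bits(T value) {
  static_assert(sizeof(T) == 4 || sizeof(T) == 8, "T must be 4 or 8 bytes wide");
  bits_t<T> bits;
  std::memcpy(&bits, &value, sizeof(T));
  return bits;
}

// Inverse of to_bits: rebuild a T from its integer representation.
template <typename T>
T from_bits(bits_t<T> bits) {
  static_assert(sizeof(T) == 4 || sizeof(T) == 8, "T must be 4 or 8 bytes wide");
  T value;
  std::memcpy(&value, &bits, sizeof(T));
  return value;
}
```

With this shape, `to_bits(1.0f)` yields a 32-bit integer and `to_bits(1.0)` a 64-bit one, so 4-byte and 8-byte types are both covered without carrying unused upper bits.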

+    retval = 1;
+    //Do a full rotate of the wave lanes, using dpp with "wave_rol1" control (ID: 0x134).
Contributor

Suggested change
-    //Do a full rotate of the wave lanes, using dpp with "wave_rol1" control (ID: 0x134).
+    // Do a full rotate of the wave lanes, using dpp with "wave_rol1" control (ID: 0x134).

Please add a space after the // for every comment line. I don't know if HIP has yet adopted a coding style, but this kind of whitespace is very commonly considered correct.

+    //wave_rol1 dpp rotates the value from each lane to one lane left of it, across the whole wave.
+    //In doing so each lane gets a mask of matches with all other lanes in the wave in retval.
+    for (int i = 1; i < static_cast<int>(warpSize); i++) {
+      if constexpr (sizeof(value) == 8)
+        dill_.ill = __builtin_amdgcn_mov_dpp(dill_.ill, 0x134 /*dpp_ctrl=wave_rol1*/, 0xf/*row_mask*/, 0xf/*clmn_mask*/, 1/*bound_ctrl*/);
Contributor

The comments are partially helpful, but I am still not okay with magic constants appearing somewhere in code. Those constants like 0x134 need to be declared in one suitable place, maybe as an enum or as constant integers with meaningful names.
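
One way the request above could be met is sketched below. The identifiers are hypothetical and chosen only for illustration. Only the numeric values come from the diff, and the `kDppBankMaskAll` name rests on the assumption that the fourth argument of `__builtin_amdgcn_mov_dpp` is the bank mask (the diff labels it `clmn_mask`).

```cpp
// Hypothetical names for the DPP arguments used in the diff.
// Only the numeric values are taken from the code under review.
enum DppCtrl : int {
  kDppWaveRol1 = 0x134,  // "wave_rol1": rotate the whole wave by one lane
};

constexpr int kDppRowMaskAll = 0xf;    // third argument: enable all rows
constexpr int kDppBankMaskAll = 0xf;   // fourth argument: enable all banks (assumed meaning)
constexpr bool kDppBoundCtrl = true;   // fifth argument, written as 1 in the diff

static_assert(kDppWaveRol1 == 0x134, "value must match the constant in the diff");
```

The call in the loop would then read roughly `__builtin_amdgcn_mov_dpp(x, kDppWaveRol1, kDppRowMaskAll, kDppBankMaskAll, kDppBoundCtrl)`. Since all four are compile-time constants they should still satisfy the builtin's constant-argument requirement, though that has not been checked against a device compiler here.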

+      else
+        dill_.i[0] = __builtin_amdgcn_mov_dpp(dill_.i[0], 0x134 /*dpp_ctrl=wave_rol1*/, 0xf/*full*/, 0xf/*full*/, 1/*bound_ctrl*/);
+      retval |= ((unsigned long long)(dill_.val == value)) << i;
+    }
+    //At this point each lane has a rotated match_any mask, where it is in the LSB.
+    //So we just need to rotate the mask by the lane's actual position to get the correct mask.
+    int rotv = __lane_id();
Contributor

What does rotv mean? More comments will be useful for future reference, like what is the overall strategy here.

Author

It's the rotate amount. I added a comment to the code.
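
To make the overall strategy concrete, here is a host-side simulation in plain C++. It is a sketch under stated assumptions, not the PR's device code: it assumes a 64-lane wave with every lane active, and it assumes each rotation step lets lane `L` see the value held by lane `(L + 1) % 64`, which is the direction the final left rotation in the diff implies. The names `kWaveSize`, `rotl64`, `match_any_by_rotation`, and `match_any_reference` are hypothetical.

```cpp
#include <array>
#include <cstdint>

constexpr int kWaveSize = 64;  // assumption: wave64 with all lanes active

// Rotate left by n bits. The n == 0 case is handled separately because
// shifting a 64-bit value by 64 is undefined behaviour in C++.
inline std::uint64_t rotl64(std::uint64_t x, int n) {
  n &= 63;
  return n == 0 ? x : (x << n) | (x >> (64 - n));
}

// Simulates the rotate-and-compare scheme for every lane of one wave.
template <typename T>
std::array<std::uint64_t, kWaveSize> match_any_by_rotation(
    const std::array<T, kWaveSize>& values) {
  std::array<std::uint64_t, kWaveSize> masks{};
  for (int lane = 0; lane < kWaveSize; ++lane) {
    // Bit 0 stands for the lane itself, which always matches.
    std::uint64_t relative = 1;
    for (int i = 1; i < kWaveSize; ++i) {
      // After i rotation steps the lane sees the value of lane (lane + i).
      const T seen = values[(lane + i) % kWaveSize];
      relative |= static_cast<std::uint64_t>(seen == values[lane]) << i;
    }
    // The mask is relative to this lane's position, so rotate it by the
    // lane index (the "rotv" of the diff) to get absolute lane numbering.
    masks[lane] = rotl64(relative, lane);
  }
  return masks;
}

// Straightforward reference: compare every pair of lanes directly.
template <typename T>
std::array<std::uint64_t, kWaveSize> match_any_reference(
    const std::array<T, kWaveSize>& values) {
  std::array<std::uint64_t, kWaveSize> masks{};
  for (int lane = 0; lane < kWaveSize; ++lane) {
    for (int other = 0; other < kWaveSize; ++other) {
      if (values[other] == values[lane]) {
        masks[lane] |= std::uint64_t{1} << other;
      }
    }
  }
  return masks;
}
```

The guard in `rotl64` matters for lane 0: with a rotate amount of 0, an expression of the form `x >> (64 - 0)` shifts by the full width of the type, which C++ leaves undefined.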

+    retval = (retval << rotv) | (retval >> (static_cast<int>(warpSize) - rotv));
+  }

   return retval;