Skip to content

Fold unaligned vec4 load and store into function #4684

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 6 additions & 23 deletions fbgemm_gpu/codegen/genscript/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,32 +1179,15 @@ def partial_rowwise_adam() -> Dict[str, Any]:
Vec4T<momentum1_ph_t> m_t;

if (enable_optimizer_offloading) {
// When offloading is enabled, we need to ensure proper alignment
// Create a temporary aligned array on the stack
alignas(16) momentum1_ph_t local_momentum1[4];

// Load values from momentum1_start into the aligned array
#pragma unroll
for (int i = 0; i < 4; i++) {
local_momentum1[i] = momentum1_start[d + i];
}

// Use the aligned array for computation
m_t = Vec4T<momentum1_ph_t>(local_momentum1);
// When offloading is enabled, we need to ensure proper alignment, so
// first copy to a temporary aligned array before loading to Vec4T
m_t = vec4_load_unaligned(momentum1_start + d);
m_t.mul_(beta1);
m_t.fma_(grad, 1.0 - beta1);

// Store results back to the aligned array
m_t.store(local_momentum1);

// Copy results back to momentum1_start
#pragma unroll
for (int i = 0; i < 4; i++) {
momentum1_start[d + i] = local_momentum1[i];
}
vec4_store_unaligned(m_t, momentum1_start + d);

} else {
// When not offloading, we can directly use momentum1_start
// This avoids the extra copy operations and temporary array
// When offloading is not enabled, we can directly use momentum1_start
m_t = Vec4T<momentum1_ph_t>(&momentum1_start[d]);
m_t.mul_(beta1);
m_t.fma_(grad, 1.0 - beta1);
Expand Down
32 changes: 32 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/utils/vec4.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -611,4 +611,36 @@ DEVICE_INLINE Vec4T<scalar_t> vec4_acc(
return s;
}

template <typename T>
DEVICE_INLINE Vec4T<T> vec4_load_unaligned(const T* src) {
// src is not guaranteed to have proper alignment.
// Create a temporary aligned array on the stack.
alignas(16) T temp[4];

// Load values from src into the byte-aligned array
#pragma unroll
for (auto i = 0; i < 4; i++) {
temp[i] = src[i];
}

// Then load the aligned array values into Vec4T
return Vec4T<T>(temp);
}

template <typename T>
DEVICE_INLINE void vec4_store_unaligned(const Vec4T<T>& vec, T* dst) {
// dst is not guaranteed to have proper alignment.
// Create a temporary aligned array on the stack.
alignas(16) T temp[4];

// Store Vec4T values into the byte-aligned array
vec.store(temp);

// Then store the aligned array values into dst
#pragma unroll
for (auto i = 0; i < 4; i++) {
dst[i] = temp[i];
}
}

} // namespace fbgemm_gpu
Loading