Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 80 additions & 39 deletions csrc/xpu/gdn_attn/causal_conv1d.hpp
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

namespace gdn {

template <typename T, int Width>
template <typename T, int Width, bool ReorderInput>
struct causal_conv1d_kernel {
public:
static constexpr int sub_group_size = 32;
Expand Down Expand Up @@ -105,16 +105,29 @@ struct causal_conv1d_kernel {
int qkvz_dim_id = qkvz_elems_id % qkvz_dim;

// reorder b,a
if (qkvz_dim_id < (num_v_heads / num_k_heads)) {
int step =
token_id * num_v_heads + k_heads_id * num_v_heads / num_k_heads;
const int ba_elems_per_item =
sycl::min(elems_per_item, num_v_heads / num_k_heads);
if constexpr (ReorderInput) {
if (qkvz_elems_id < num_v_heads) {
int step = token_id * num_v_heads;
#pragma unroll
for (int e = 0; e < ba_elems_per_item; ++e) {
b_out[step + qkvz_dim_id + e] = mixed_ba[step * 2 + qkvz_dim_id + e];
a_out[step + qkvz_dim_id + e] =
mixed_ba[step * 2 + num_v_heads / num_k_heads + qkvz_dim_id + e];
for (int e = 0; e < elems_per_item; ++e) {
b_out[step + qkvz_elems_id + e] =
mixed_ba[step * 2 + qkvz_elems_id + e];
a_out[step + qkvz_elems_id + e] =
mixed_ba[step * 2 + num_v_heads + qkvz_dim_id + e];
}
}
} else {
if (qkvz_dim_id < (num_v_heads / num_k_heads)) {
int step =
token_id * num_v_heads + k_heads_id * num_v_heads / num_k_heads;
const int ba_elems_per_item =
sycl::min(elems_per_item, num_v_heads / num_k_heads);
#pragma unroll
for (int e = 0; e < ba_elems_per_item; ++e) {
b_out[step + qkvz_dim_id + e] = mixed_ba[step * 2 + qkvz_dim_id + e];
a_out[step + qkvz_dim_id + e] =
mixed_ba[step * 2 + num_v_heads / num_k_heads + qkvz_dim_id + e];
}
}
}

Expand All @@ -138,19 +151,37 @@ struct causal_conv1d_kernel {
return;
}

int mixed_qkvz_id = qkvz_elems_id;

bool is_q = false;
bool is_k = false;
bool is_v = false;
bool is_z = false;

if (qkvz_dim_id < q_dim) {
is_q = true;
if constexpr (ReorderInput) {
mixed_qkvz_id = k_heads_id * k_dim + qkvz_dim_id;
}
} else if (qkvz_dim_id < q_dim + k_dim) {
is_k = true;
if constexpr (ReorderInput) {
mixed_qkvz_id = num_k_heads * head_k_dim + k_heads_id * k_dim +
qkvz_dim_id - (q_dim);
}
} else if (qkvz_dim_id < q_dim + k_dim + v_dim) {
is_v = true;
if constexpr (ReorderInput) {
mixed_qkvz_id = 2 * num_k_heads * head_k_dim + k_heads_id * v_dim +
qkvz_dim_id - (q_dim + k_dim);
}
} else {
is_z = true;
if constexpr (ReorderInput) {
mixed_qkvz_id = 2 * num_k_heads * head_k_dim +
num_v_heads * head_v_dim + k_heads_id * z_dim +
qkvz_dim_id - (q_dim + k_dim + v_dim);
}
}

// reorder z
Expand All @@ -160,7 +191,7 @@ struct causal_conv1d_kernel {
#pragma unroll
for (int e = 0; e < elems_per_item; ++e) {
z_out[token_id * num_k_heads * z_dim + z_elems_id + e] =
mixed_qkvz[token_id * qkvz_elems + qkvz_elems_id + e];
mixed_qkvz[token_id * qkvz_elems + mixed_qkvz_id + e];
}
return;
}
Expand Down Expand Up @@ -224,7 +255,7 @@ struct causal_conv1d_kernel {
#pragma unroll
for (int e = 0; e < elems_per_item; ++e) {
local_input[Width * e + states_load_len + i] = mixed_qkvz
[(token_id - input_load_len + 1 + i) * qkvz_elems + qkvz_elems_id +
[(token_id - input_load_len + 1 + i) * qkvz_elems + mixed_qkvz_id +
e];
}
}
Expand Down Expand Up @@ -416,7 +447,7 @@ struct update_states_kernel {
const int batch_size;
};

template <typename T, int Width>
template <typename T, int Width, bool ReorderInput>
void kernel_launcher(
sycl::queue& queue,
T* q_out,
Expand Down Expand Up @@ -447,9 +478,10 @@ void kernel_launcher(
const int& conv_elems,
const int& num_prefills,
const int& num_decodes) {
using KERNEL_MAIN = causal_conv1d_kernel<T, Width>;
using KERNEL_MAIN = causal_conv1d_kernel<T, Width, ReorderInput>;
auto range_main = KERNEL_MAIN::get_nd_range(num_actual_tokens, qkvz_elems);
assert(head_k_dim % KERNEL_MAIN::elems_per_item == 0);
assert(num_v_heads % KERNEL_MAIN::elems_per_item == 0);
queue.submit([&](sycl::handler& cgh) {
KERNEL_MAIN task(
q_out,
Expand Down Expand Up @@ -528,7 +560,8 @@ void causal_conv1d(
const ActMode& act_mode, // silu or swish
const int& pad_slot_id, // -1
const int num_prefills,
const int num_decodes) {
const int num_decodes,
const bool reorder_input) {
if (num_prefills == 0 && num_decodes == 0) {
return;
}
Expand All @@ -550,8 +583,8 @@ void causal_conv1d(
{batch_size, width - 1, conv_elems},
torch::dtype(dtype).device(device).requires_grad(false));

#define KERNEL_LAUNCHER(scalar_t, width) \
kernel_launcher<scalar_t, width>( \
#define KERNEL_LAUNCHER(scalar_t, width, reorder_input) \
kernel_launcher<scalar_t, width, reorder_input>( \
queue, \
reinterpret_cast<scalar_t*>(q_out.data_ptr()), \
reinterpret_cast<scalar_t*>(k_out.data_ptr()), \
Expand Down Expand Up @@ -586,37 +619,45 @@ void causal_conv1d(
num_prefills, \
num_decodes);

#define WIDTH_DISPATCH(scalar_t, width) \
switch (width) { \
case 1: \
KERNEL_LAUNCHER(scalar_t, 1) \
break; \
case 2: \
KERNEL_LAUNCHER(scalar_t, 2) \
break; \
case 3: \
KERNEL_LAUNCHER(scalar_t, 3) \
break; \
case 4: \
KERNEL_LAUNCHER(scalar_t, 4) \
break; \
case 5: \
KERNEL_LAUNCHER(scalar_t, 5) \
break; \
default: \
break; \
#define WIDTH_DISPATCH(scalar_t, width, reorder_input) \
switch (width) { \
case 1: \
KERNEL_LAUNCHER(scalar_t, 1, reorder_input) \
break; \
case 2: \
KERNEL_LAUNCHER(scalar_t, 2, reorder_input) \
break; \
case 3: \
KERNEL_LAUNCHER(scalar_t, 3, reorder_input) \
break; \
case 4: \
KERNEL_LAUNCHER(scalar_t, 4, reorder_input) \
break; \
case 5: \
KERNEL_LAUNCHER(scalar_t, 5, reorder_input) \
break; \
default: \
break; \
}

#define SPLIT_DISPATCH(scalar_t, width, reorder_input) \
if (reorder_input) { \
WIDTH_DISPATCH(scalar_t, width, true) \
} else { \
WIDTH_DISPATCH(scalar_t, width, false) \
}

if (mixed_qkvz.scalar_type() == at::kBFloat16) {
using scalar_t = sycl::ext::oneapi::bfloat16;
WIDTH_DISPATCH(scalar_t, width)
SPLIT_DISPATCH(scalar_t, width, reorder_input)
} else if (mixed_qkvz.scalar_type() == at::kHalf) {
using scalar_t = sycl::half;
WIDTH_DISPATCH(scalar_t, width)
SPLIT_DISPATCH(scalar_t, width, reorder_input)
} else {
using scalar_t = float;
WIDTH_DISPATCH(scalar_t, width)
SPLIT_DISPATCH(scalar_t, width, reorder_input)
}
#undef SPLIT_DISPATCH
#undef WIDTH_DISPATCH
#undef KERNEL_LAUNCHER
}
Expand Down
9 changes: 6 additions & 3 deletions csrc/xpu/gdn_attn/gdn_attn_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ void gdn_attention(
const torch::Tensor& non_spec_query_start_loc, // [batch_size + 1]
const torch::Tensor& non_spec_state_indices_tensor, // [batch_size]
const int64_t num_actual_tokens,
const int64_t tp_size) {
const int64_t tp_size,
const bool reorder_input) {
TORCH_CHECK(
core_attn_out.is_contiguous(), "core_attn_out must be contiguous");
TORCH_CHECK(z.is_contiguous(), "z must be contiguous");
Expand Down Expand Up @@ -144,7 +145,8 @@ void gdn_attention(
act_mode, \
pad_slot_id, \
num_prefills, \
num_decodes); \
num_decodes, \
reorder_input); \
gdn::gated_delta_rule( \
queue, \
core_attn_out, \
Expand Down Expand Up @@ -203,7 +205,8 @@ void gdn_attention(
act_mode,
pad_slot_id,
num_prefills,
num_decodes);
num_decodes,
reorder_input);

chunk_gated_delta_rule_xe2(
queue,
Expand Down
Loading
Loading