Skip to content

Commit 39601a4

Browse files
committed
Update moe kernel
1 parent 7b36b59 commit 39601a4

File tree

2 files changed

+33
-33
lines changed

2 files changed

+33
-33
lines changed

src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_mlp.cl

Lines changed: 32 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,8 @@
33
// SPDX-License-Identifier: Apache-2.0
44
//
55

6+
#define unroll_for __attribute__((opencl_unroll_hint)) for
7+
68
#if GATE_UP_ENABLE
79
inline void gemv_n2x(const __global uchar* weight,
810
__global half* scales,
@@ -16,38 +18,16 @@ inline void gemv_n2x(const __global uchar* weight,
1618
int id_sg = get_sub_group_id();
1719
int id_local = get_sub_group_local_id();
1820

19-
//# interleaving x into x2
20-
half * px = x + id_sg*GROUP_SIZE;
21-
half * px2 = x2 + id_sg*GROUP_SIZE;
22-
for(int i = id_sg; i < HIDDEN_SIZE/GROUP_SIZE; i += num_sg, px += num_sg*GROUP_SIZE, px2 += num_sg*GROUP_SIZE) {
23-
//# quantization group
24-
float x_group_sum = 0;
25-
for(int j = id_local; j < GROUP_SIZE/2; j += SUBGROUP_SIZE) {
26-
half even = px[2*j + 0];
27-
half odd = px[2*j + 1];
28-
px2[j] = even;
29-
px2[j + GROUP_SIZE/2] = odd;
30-
x_group_sum += even + odd;
31-
}
32-
x_group_sum = sub_group_reduce_add(x_group_sum);
33-
if (id_local == 0) {
34-
xg_sum[i] = x_group_sum / SUBGROUP_SIZE;
35-
}
36-
}
37-
38-
barrier(CLK_LOCAL_MEM_FENCE);
39-
4021
int n_start = get_global_id(2) * N_BLOCK;
4122
int n_end = n_start + N_BLOCK;
42-
43-
for (int n = n_start; n < n_end; n+=2) {
23+
unroll_for (int n = n_start; n < n_end; n+=2) {
4424
const __global uchar* B = weight + n * K / 2;
4525
float sum_all0 = 0;
4626
float sum_all1 = 0;
4727
#if SZ_LAYOUT == 0
4828
__global half* S = scales + n;
4929
__global uchar* Z = zps + n / 2;
50-
for (int gk = 0; gk < K / GROUP_SIZE; gk++, S += N, Z += N / 2) {
30+
unroll_for (int gk = 0; gk < K / GROUP_SIZE; gk++, S += N, Z += N / 2) {
5131
half s0 = S[0];
5232
half s1 = S[1];
5333
ushort z = Z[0];
@@ -61,8 +41,7 @@ inline void gemv_n2x(const __global uchar* weight,
6141
uchar zp_values = intel_sub_group_block_read_uc((const __global uchar*)Z);
6242
half zp_even = convert_half(zp_values & 0xF);
6343
half zp_odd = convert_half(zp_values >> 4);
64-
65-
for (int gk = 0; gk < K / GROUP_SIZE; gk++) {
44+
unroll_for (int gk = 0; gk < K / GROUP_SIZE; gk++) {
6645
half s0 = sub_group_broadcast(scale_values, 2*gk + 0);
6746
half s1 = sub_group_broadcast(scale_values, 2*gk + 1);
6847
half z_hf0 = sub_group_broadcast(zp_even, gk);
@@ -166,6 +145,29 @@ KERNEL (mlp_gate_up)(
166145

167146
__local half x2[HIDDEN_SIZE];
168147
__local float xg_sum[HIDDEN_SIZE/32];
148+
//# interleaving x into x2
149+
int id_sg = get_sub_group_id();
150+
int num_sg = get_num_sub_groups();
151+
int id_local = get_sub_group_local_id();
152+
half * px = x + id_sg*GROUP_SIZE;
153+
half * px2 = x2 + id_sg*GROUP_SIZE;
154+
unroll_for(int i = id_sg; i < HIDDEN_SIZE/GROUP_SIZE; i += num_sg, px += num_sg*GROUP_SIZE, px2 += num_sg*GROUP_SIZE) {
155+
//# quantization group
156+
float x_group_sum = 0;
157+
unroll_for(int j = id_local; j < GROUP_SIZE/2; j += SUBGROUP_SIZE) {
158+
half even = px[2*j + 0];
159+
half odd = px[2*j + 1];
160+
px2[j] = even;
161+
px2[j + GROUP_SIZE/2] = odd;
162+
x_group_sum += even + odd;
163+
}
164+
x_group_sum = sub_group_reduce_add(x_group_sum);
165+
if (id_local == 0) {
166+
xg_sum[i] = x_group_sum / SUBGROUP_SIZE;
167+
}
168+
}
169+
barrier(CLK_LOCAL_MEM_FENCE);
170+
169171
gemv_n2x(up_weight, up_scale, up_zp, x, y, INTERMEDIATE_SIZE, HIDDEN_SIZE, x2, xg_sum, false);
170172
gemv_n2x(gate_weight, gate_scale, gate_zp, x, y, INTERMEDIATE_SIZE, HIDDEN_SIZE, x2, xg_sum, true);
171173
}
@@ -207,10 +209,10 @@ KERNEL (mlp_down)(
207209
//# interleaving x into x2
208210
__global half * px = x + id_sg*GROUP_SIZE;
209211
__local half * px2 = x2 + id_sg*GROUP_SIZE;
210-
for(int i = id_sg; i < INTERMEDIATE_SIZE/GROUP_SIZE; i += num_sg, px += num_sg*GROUP_SIZE, px2 += num_sg*GROUP_SIZE) {
212+
unroll_for(int i = id_sg; i < INTERMEDIATE_SIZE/GROUP_SIZE; i += num_sg, px += num_sg*GROUP_SIZE, px2 += num_sg*GROUP_SIZE) {
211213
//# quantization group
212214
float x_group_sum = 0;
213-
for(int j = id_local; j < GROUP_SIZE/2; j += SUBGROUP_SIZE) {
215+
unroll_for(int j = id_local; j < GROUP_SIZE/2; j += SUBGROUP_SIZE) {
214216
half even = px[2*j + 0];
215217
half odd = px[2*j + 1];
216218
px2[j] = even;
@@ -227,13 +229,13 @@ KERNEL (mlp_down)(
227229
int n_start = get_global_id(2) * N_BLOCK;
228230
int n_end = n_start + N_BLOCK;
229231

230-
for (int n = n_start; n < n_end; n+=2) {
232+
unroll_for (int n = n_start; n < n_end; n+=2) {
231233
const __global uchar* B = weight + n * K / 2;
232234
__global half* S = scales + n;
233235
__global uchar* Z = zps + n / 2;
234236
float sum_all0 = 0;
235237
float sum_all1 = 0;
236-
for (int gk = 0; gk < K / GROUP_SIZE; gk++, S += N, Z += N / 2) {
238+
unroll_for (int gk = 0; gk < K / GROUP_SIZE; gk++, S += N, Z += N / 2) {
237239
half s0 = S[0];
238240
half s1 = S[1];
239241
ushort z = Z[0];

src/plugins/intel_gpu/src/plugin/ops/moe.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,11 @@
66
#include "intel_gpu/plugin/common_utils.hpp"
77
#include "intel_gpu/plugin/program_builder.hpp"
88
#include "intel_gpu/primitives/moe_fused_compressed.hpp"
9-
#include "intel_gpu/primitives/moe_fused_compressed.hpp"
10-
119

1210
namespace ov {
1311
namespace op {
1412
namespace internal {
15-
using MOEFusedCompressed = ov::intel_gpu::op::MOEFusedCompressed ;
13+
using MOEFusedCompressed = ov::intel_gpu::op::MOEFusedCompressed;
1614
} // namespace internal
1715
} // namespace op
1816
} // namespace ov

0 commit comments

Comments
 (0)