33// SPDX-License-Identifier: Apache-2.0
44//
55
6+ #define unroll_for __attribute__((opencl_unroll_hint)) for
7+
68#if GATE_UP_ENABLE
79inline void gemv_n2x (const __global uchar * weight ,
810 __global half * scales ,
@@ -16,38 +18,16 @@ inline void gemv_n2x(const __global uchar* weight,
1618 int id_sg = get_sub_group_id ();
1719 int id_local = get_sub_group_local_id ();
1820
19- //# interleaving x into x2
20- half * px = x + id_sg * GROUP_SIZE ;
21- half * px2 = x2 + id_sg * GROUP_SIZE ;
22- for (int i = id_sg ; i < HIDDEN_SIZE /GROUP_SIZE ; i += num_sg , px += num_sg * GROUP_SIZE , px2 += num_sg * GROUP_SIZE ) {
23- //# quantization group
24- float x_group_sum = 0 ;
25- for (int j = id_local ; j < GROUP_SIZE /2 ; j += SUBGROUP_SIZE ) {
26- half even = px [2 * j + 0 ];
27- half odd = px [2 * j + 1 ];
28- px2 [j ] = even ;
29- px2 [j + GROUP_SIZE /2 ] = odd ;
30- x_group_sum += even + odd ;
31- }
32- x_group_sum = sub_group_reduce_add (x_group_sum );
33- if (id_local == 0 ) {
34- xg_sum [i ] = x_group_sum / SUBGROUP_SIZE ;
35- }
36- }
37-
38- barrier (CLK_LOCAL_MEM_FENCE );
39-
4021 int n_start = get_global_id (2 ) * N_BLOCK ;
4122 int n_end = n_start + N_BLOCK ;
42-
43- for (int n = n_start ; n < n_end ; n += 2 ) {
23+ unroll_for (int n = n_start ; n < n_end ; n += 2 ) {
4424 const __global uchar * B = weight + n * K / 2 ;
4525 float sum_all0 = 0 ;
4626 float sum_all1 = 0 ;
4727#if SZ_LAYOUT == 0
4828 __global half * S = scales + n ;
4929 __global uchar * Z = zps + n / 2 ;
50- for (int gk = 0 ; gk < K / GROUP_SIZE ; gk ++ , S += N , Z += N / 2 ) {
30+ unroll_for (int gk = 0 ; gk < K / GROUP_SIZE ; gk ++ , S += N , Z += N / 2 ) {
5131 half s0 = S [0 ];
5232 half s1 = S [1 ];
5333 ushort z = Z [0 ];
@@ -61,8 +41,7 @@ inline void gemv_n2x(const __global uchar* weight,
6141 uchar zp_values = intel_sub_group_block_read_uc ((const __global uchar * )Z );
6242 half zp_even = convert_half (zp_values & 0xF );
6343 half zp_odd = convert_half (zp_values >> 4 );
64-
65- for (int gk = 0 ; gk < K / GROUP_SIZE ; gk ++ ) {
44+ unroll_for (int gk = 0 ; gk < K / GROUP_SIZE ; gk ++ ) {
6645 half s0 = sub_group_broadcast (scale_values , 2 * gk + 0 );
6746 half s1 = sub_group_broadcast (scale_values , 2 * gk + 1 );
6847 half z_hf0 = sub_group_broadcast (zp_even , gk );
@@ -166,6 +145,29 @@ KERNEL (mlp_gate_up)(
166145
167146 __local half x2 [HIDDEN_SIZE ];
168147 __local float xg_sum [HIDDEN_SIZE /32 ];
148+ //# interleaving x into x2
149+ int id_sg = get_sub_group_id ();
150+ int num_sg = get_num_sub_groups ();
151+ int id_local = get_sub_group_local_id ();
152+ half * px = x + id_sg * GROUP_SIZE ;
153+ half * px2 = x2 + id_sg * GROUP_SIZE ;
154+ unroll_for (int i = id_sg ; i < HIDDEN_SIZE /GROUP_SIZE ; i += num_sg , px += num_sg * GROUP_SIZE , px2 += num_sg * GROUP_SIZE ) {
155+ //# quantization group
156+ float x_group_sum = 0 ;
157+ unroll_for (int j = id_local ; j < GROUP_SIZE /2 ; j += SUBGROUP_SIZE ) {
158+ half even = px [2 * j + 0 ];
159+ half odd = px [2 * j + 1 ];
160+ px2 [j ] = even ;
161+ px2 [j + GROUP_SIZE /2 ] = odd ;
162+ x_group_sum += even + odd ;
163+ }
164+ x_group_sum = sub_group_reduce_add (x_group_sum );
165+ if (id_local == 0 ) {
166+ xg_sum [i ] = x_group_sum / SUBGROUP_SIZE ;
167+ }
168+ }
169+ barrier (CLK_LOCAL_MEM_FENCE );
170+
169171 gemv_n2x (up_weight , up_scale , up_zp , x , y , INTERMEDIATE_SIZE , HIDDEN_SIZE , x2 , xg_sum , false);
170172 gemv_n2x (gate_weight , gate_scale , gate_zp , x , y , INTERMEDIATE_SIZE , HIDDEN_SIZE , x2 , xg_sum , true);
171173}
@@ -207,10 +209,10 @@ KERNEL (mlp_down)(
207209 //# interleaving x into x2
208210 __global half * px = x + id_sg * GROUP_SIZE ;
209211 __local half * px2 = x2 + id_sg * GROUP_SIZE ;
210- for (int i = id_sg ; i < INTERMEDIATE_SIZE /GROUP_SIZE ; i += num_sg , px += num_sg * GROUP_SIZE , px2 += num_sg * GROUP_SIZE ) {
212+ unroll_for (int i = id_sg ; i < INTERMEDIATE_SIZE /GROUP_SIZE ; i += num_sg , px += num_sg * GROUP_SIZE , px2 += num_sg * GROUP_SIZE ) {
211213 //# quantization group
212214 float x_group_sum = 0 ;
213- for (int j = id_local ; j < GROUP_SIZE /2 ; j += SUBGROUP_SIZE ) {
215+ unroll_for (int j = id_local ; j < GROUP_SIZE /2 ; j += SUBGROUP_SIZE ) {
214216 half even = px [2 * j + 0 ];
215217 half odd = px [2 * j + 1 ];
216218 px2 [j ] = even ;
@@ -227,13 +229,13 @@ KERNEL (mlp_down)(
227229 int n_start = get_global_id (2 ) * N_BLOCK ;
228230 int n_end = n_start + N_BLOCK ;
229231
230- for (int n = n_start ; n < n_end ; n += 2 ) {
232+ unroll_for (int n = n_start ; n < n_end ; n += 2 ) {
231233 const __global uchar * B = weight + n * K / 2 ;
232234 __global half * S = scales + n ;
233235 __global uchar * Z = zps + n / 2 ;
234236 float sum_all0 = 0 ;
235237 float sum_all1 = 0 ;
236- for (int gk = 0 ; gk < K / GROUP_SIZE ; gk ++ , S += N , Z += N / 2 ) {
238+ unroll_for (int gk = 0 ; gk < K / GROUP_SIZE ; gk ++ , S += N , Z += N / 2 ) {
237239 half s0 = S [0 ];
238240 half s1 = S [1 ];
239241 ushort z = Z [0 ];
0 commit comments