
Commit 01fc359

Use 1-d kernel
1 parent feaeda0 commit 01fc359

2 files changed: +32 −39 lines changed

ggml/src/ggml-cuda/set-rows.cu

Lines changed: 31 additions & 38 deletions
@@ -3,9 +3,7 @@
 typedef void (*set_rows_kernel_t)(const char * src, char * dst);
 
 template<typename src_t, typename dst_t>
-__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {
-    GGML_ABORT("unsupport type for set_rows");
-}
+__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {}
 
 template<>
 __device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, half * dst_h) {
@@ -17,33 +15,34 @@ __device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, fl
     *dst_f = *src_f;
 }
 
-//TODO: consolidate kernels from cpy.cu, get_rows etc to make this function generic
 template<typename src_t, typename dst_t>
 static __global__ void k_set_rows(
         const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst,
         const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
         const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
         const size_t nb01, const size_t nb02, const size_t nb03,
         const size_t nb10, const size_t nb11, const size_t nb12,
-        const size_t nb1, const size_t nb2, const size_t nb3,
-        const size_t src_type_size, const size_t dst_type_size) {
+        const size_t nb1, const size_t nb2, const size_t nb3) {
 
-    const int i03 = blockIdx.z / ne02;
-    const int i02 = blockIdx.z % ne02;
-    const int i01 = blockDim.x * blockIdx.x + threadIdx.x;
-    const int i00 = blockIdx.y;
+    const int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
+    const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
 
-    if (i01 >= ne01) {
+    if (i >= ne_total) {
         return;
     }
 
-    const int i12 = i03 % ne12;
-    const int i11 = i02 % ne11;
-    const int i10 = i01;
+    const int64_t i03 = i / (ne00 * ne01 * ne02);
+    const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
+    const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
+    const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
+
+    const int64_t i12 = i03 % ne12;
+    const int64_t i11 = i02 % ne11;
+    const int64_t i10 = i01;
 
     const int64_t dst_row = *(src1 + i10*nb10 + i11*nb11 + i12*nb12);
 
-    const src_t * src0_row = (const src_t *)src0 + i01*nb01 + i02*nb02 + i03*nb03;
+    const src_t * src0_row = src0 + i01*nb01 + i02*nb02 + i03*nb03;
     dst_t * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3;
 
     const src_t* src_elem = src0_row + i00;
@@ -59,38 +58,32 @@ static void set_rows_cuda(
         const size_t nb01, const size_t nb02, const size_t nb03,
         const size_t nb10, const size_t nb11, const size_t nb12,
         const size_t nb1, const size_t nb2, const size_t nb3,
-        const size_t src_type_size, const size_t dst_type_size,
         cudaStream_t stream) {
 
+    const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
+    const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
     const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
-    const dim3 grid_size(
-        (ne01 + CUDA_SET_ROWS_BLOCK_SIZE - 1)/CUDA_SET_ROWS_BLOCK_SIZE,
-        ne00,
-        ne03*ne02
-    );
-
-    const int s1 = nb01 / sizeof(src_t);
-    const int s2 = nb02 / sizeof(src_t);
-    const int s3 = nb03 / sizeof(src_t);
+    const dim3 grid_size(num_blocks);
 
-    const int s10 = nb10 / sizeof(int64_t);
-    const int s11 = nb11 / sizeof(int64_t);
-    const int s12 = nb12 / sizeof(int64_t);
 
-    const int s_dst = nb1 / sizeof(dst_t);
-    const int s_dst2 = nb2 / sizeof(dst_t);
-    const int s_dst3 = nb3 / sizeof(dst_t);
+    const int64_t s01 = nb01/sizeof(src_t);
+    const int64_t s02 = nb02/sizeof(src_t);
+    const int64_t s03 = nb03/sizeof(src_t);
+    const int64_t s10 = nb10/sizeof(int64_t);
+    const int64_t s11 = nb11/sizeof(int64_t);
+    const int64_t s12 = nb12/sizeof(int64_t);
+    const int64_t s1 = nb1/sizeof(dst_t);
+    const int64_t s2 = nb2/sizeof(dst_t);
+    const int64_t s3 = nb3/sizeof(dst_t);
 
-
-    if(ne01 > 0 && ne00 > 0) {
+    if (ne_total > 0) {
        k_set_rows<<<grid_size, block_size, 0, stream>>>(
            src0_d, src1_d, dst_d,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
-           s1, s2, s3,
+           s01, s02, s03,
            s10, s11, s12,
-           s_dst, s_dst2, s_dst3,
-           src_type_size, dst_type_size);
+           s1, s2, s3);
     }
 }
 
@@ -109,6 +102,8 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     cudaStream_t stream = ctx.stream();
 
+
+
     if (dst->type == GGML_TYPE_F32) {
         set_rows_cuda(
             src0_d, src1_d, (float*)dst->data,
@@ -117,7 +112,6 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
             nb01, nb02, nb03,
             nb10, nb11, nb12,
             nb1, nb2, nb3,
-            sizeof(float), sizeof(float),
             stream
         );
     } else if (dst->type == GGML_TYPE_F16) {
@@ -128,7 +122,6 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
-           sizeof(float), sizeof(half),
            stream
        );
    } else {
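
Editor's note (not part of the commit): after this change, k_set_rows runs one thread per element and recovers the four coordinates (i00, i01, i02, i03) from the flat thread index with the divisions shown in the diff above. Below is a minimal host-side sketch of that same arithmetic, useful for checking that the decomposition round-trips; the function name and test shape are illustrative, not from the commit.

#include <cassert>
#include <cstdint>
#include <cstdio>

// Recover (i00, i01, i02, i03) from a flat index i, with i00 fastest-varying,
// mirroring the index math in the new k_set_rows.
static void unflatten(int64_t i, int64_t ne00, int64_t ne01, int64_t ne02,
                      int64_t & i00, int64_t & i01, int64_t & i02, int64_t & i03) {
    i03 = i / (ne00 * ne01 * ne02);
    i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
    i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
    i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
}

int main() {
    const int64_t ne00 = 5, ne01 = 3, ne02 = 2, ne03 = 4;  // arbitrary test shape
    for (int64_t i = 0; i < ne00 * ne01 * ne02 * ne03; ++i) {
        int64_t i00, i01, i02, i03;
        unflatten(i, ne00, ne01, ne02, i00, i01, i02, i03);
        // Re-flattening must reproduce the original index.
        assert(i == ((i03 * ne02 + i02) * ne01 + i01) * ne00 + i00);
    }
    printf("decomposition round-trips for all %lld indices\n",
           (long long)(ne00 * ne01 * ne02 * ne03));
    return 0;
}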

ggml/src/ggml-cuda/set-rows.cuh

Lines changed: 1 addition & 1 deletion
@@ -2,6 +2,6 @@
 
 #include "common.cuh"
 
-#define CUDA_SET_ROWS_BLOCK_SIZE 64
+#define CUDA_SET_ROWS_BLOCK_SIZE 256
 
 void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
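
Editor's note (not part of the commit): with the 1-D kernel, set_rows_cuda launches a flat grid of ceil(ne_total / CUDA_SET_ROWS_BLOCK_SIZE) blocks, where ne_total = ne00*ne01*ne02*ne03, and the block size defined here goes from 64 to 256 threads. A small sketch of that grid-size calculation follows; the tensor shape is made up purely for illustration.

#include <cstdint>
#include <cstdio>

#define CUDA_SET_ROWS_BLOCK_SIZE 256

int main() {
    // Hypothetical tensor shape, for illustration only.
    const int64_t ne00 = 4096, ne01 = 7, ne02 = 1, ne03 = 1;
    const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
    // Same rounding-up division as in set_rows_cuda.
    const int num_blocks = (int) ((ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE);
    printf("ne_total = %lld -> %d blocks x %d threads\n",
           (long long) ne_total, num_blocks, CUDA_SET_ROWS_BLOCK_SIZE);
    return 0;
}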
