@@ -3,9 +3,7 @@
 typedef void (*set_rows_kernel_t)(const char * src, char * dst);
 
 template <typename src_t, typename dst_t>
-__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {
-    GGML_ABORT("unsupport type for set_rows");
-}
+__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {}
 
 template <>
 __device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, half * dst_h) {
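With the abort removed, the primary template is now an empty no-op and every supported conversion lives in an explicit specialization; unsupported type pairs are expected to be rejected on the host side before launch. A minimal standalone sketch of the dispatch pattern (the float-to-half body below is an assumption based on the specialization's signature, which this hunk cuts off):

#include <cuda_fp16.h>

// Primary template: compiles to a no-op for type pairs with no specialization.
template <typename src_t, typename dst_t>
__device__ void set_rows_1(const src_t * /*src_f*/, dst_t * /*dst_f*/) {}

// Explicit specialization selected at compile time for float -> half.
template <>
__device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, half * dst_h) {
    *dst_h = __float2half(*src_f);  // standard CUDA float-to-half conversion intrinsic
}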
@@ -17,33 +15,34 @@ __device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) {
     *dst_f = *src_f;
 }
 
-// TODO: consolidate kernels from cpy.cu, get_rows etc to make this function generic
 template <typename src_t, typename dst_t>
 static __global__ void k_set_rows(
         const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst,
         const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
         const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
         const size_t nb01, const size_t nb02, const size_t nb03,
         const size_t nb10, const size_t nb11, const size_t nb12,
-        const size_t nb1, const size_t nb2, const size_t nb3,
-        const size_t src_type_size, const size_t dst_type_size) {
+        const size_t nb1, const size_t nb2, const size_t nb3) {
 
-    const int i03 = blockIdx.z / ne02;
-    const int i02 = blockIdx.z % ne02;
-    const int i01 = blockDim.x * blockIdx.x + threadIdx.x;
-    const int i00 = blockIdx.y;
+    const int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
+    const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
 
-    if (i01 >= ne01) {
+    if (i >= ne_total) {
         return;
     }
 
-    const int i12 = i03 % ne12;
-    const int i11 = i02 % ne11;
-    const int i10 = i01;
+    const int64_t i03 = i / (ne00 * ne01 * ne02);
+    const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
+    const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
+    const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
+
+    const int64_t i12 = i03 % ne12;
+    const int64_t i11 = i02 % ne11;
+    const int64_t i10 = i01;
 
     const int64_t dst_row = *(src1 + i10*nb10 + i11*nb11 + i12*nb12);
 
-    const src_t * src0_row = (const src_t *)src0 + i01*nb01 + i02*nb02 + i03*nb03;
+    const src_t * src0_row = src0 + i01*nb01 + i02*nb02 + i03*nb03;
     dst_t * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3;
 
     const src_t * src_elem = src0_row + i00;
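The kernel now launches one thread per element and recovers (i00, i01, i02, i03) from the flat index by successive division, with i11/i12 wrapped modulo ne11/ne12 to broadcast src1 across the outer dimensions. A quick host-side sketch (a hypothetical standalone check, shapes illustrative) confirms the decomposition inverts the usual row-major flattening:

#include <cassert>
#include <cstdint>

int main() {
    // i00 is the fastest-varying coordinate, matching ggml's layout.
    const int64_t ne00 = 4, ne01 = 3, ne02 = 2, ne03 = 2;
    for (int64_t i = 0; i < ne00*ne01*ne02*ne03; ++i) {
        const int64_t i03 = i / (ne00*ne01*ne02);
        const int64_t i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
        const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne00*ne01) / ne00;
        const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne00*ne01 - i01*ne00;
        // Re-flattening must give back the original index.
        assert(((i03*ne02 + i02)*ne01 + i01)*ne00 + i00 == i);
    }
    return 0;
}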
@@ -59,38 +58,32 @@ static void set_rows_cuda(
         const size_t nb01, const size_t nb02, const size_t nb03,
         const size_t nb10, const size_t nb11, const size_t nb12,
         const size_t nb1, const size_t nb2, const size_t nb3,
-        const size_t src_type_size, const size_t dst_type_size,
         cudaStream_t stream) {
 
+    const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
+    const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
     const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
-    const dim3 grid_size(
-        (ne01 + CUDA_SET_ROWS_BLOCK_SIZE - 1)/CUDA_SET_ROWS_BLOCK_SIZE,
-        ne00,
-        ne03*ne02
-    );
-
-    const int s1 = nb01 / sizeof(src_t);
-    const int s2 = nb02 / sizeof(src_t);
-    const int s3 = nb03 / sizeof(src_t);
+    const dim3 grid_size(num_blocks);
 
-    const int s10 = nb10 / sizeof(int64_t);
-    const int s11 = nb11 / sizeof(int64_t);
-    const int s12 = nb12 / sizeof(int64_t);
 
-    const int s_dst  = nb1 / sizeof(dst_t);
-    const int s_dst2 = nb2 / sizeof(dst_t);
-    const int s_dst3 = nb3 / sizeof(dst_t);
+    const int64_t s01 = nb01/sizeof(src_t);
+    const int64_t s02 = nb02/sizeof(src_t);
+    const int64_t s03 = nb03/sizeof(src_t);
+    const int64_t s10 = nb10/sizeof(int64_t);
+    const int64_t s11 = nb11/sizeof(int64_t);
+    const int64_t s12 = nb12/sizeof(int64_t);
+    const int64_t s1 = nb1/sizeof(dst_t);
+    const int64_t s2 = nb2/sizeof(dst_t);
+    const int64_t s3 = nb3/sizeof(dst_t);
 
-
-    if (ne01 > 0 && ne00 > 0) {
+    if (ne_total > 0) {
         k_set_rows<<<grid_size, block_size, 0, stream>>>(
             src0_d, src1_d, dst_d,
             ne00, ne01, ne02, ne03,
             ne10, ne11, ne12, ne13,
-            s1, s2, s3,
+            s01, s02, s03,
             s10, s11, s12,
-            s_dst, s_dst2, s_dst3,
-            src_type_size, dst_type_size);
+            s1, s2, s3);
     }
 }
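The launch geometry shrinks from a 3D grid (rows x columns x planes) to a flat 1D grid covering ne_total elements, so a single ceiling division determines the block count. A small sketch of the arithmetic (a block size of 256 is assumed here purely for illustration; the real value is the file's CUDA_SET_ROWS_BLOCK_SIZE):

#include <cstdint>
#include <cstdio>

int main() {
    const int     block_size = 256;        // stand-in for CUDA_SET_ROWS_BLOCK_SIZE
    const int64_t ne_total   = 4096 * 32;  // e.g. 32 rows of 4096 elements
    const int     num_blocks = (ne_total + block_size - 1) / block_size;
    printf("%d blocks x %d threads = %lld slots\n",
           num_blocks, block_size, (long long)num_blocks * block_size);
    // -> 512 blocks x 256 threads = 131072 slots
    return 0;
}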
@@ -109,6 +102,8 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     cudaStream_t stream = ctx.stream();
 
+
+
     if (dst->type == GGML_TYPE_F32) {
         set_rows_cuda(
             src0_d, src1_d, (float *)dst->data,
@@ -117,7 +112,6 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
             nb01, nb02, nb03,
             nb10, nb11, nb12,
             nb1, nb2, nb3,
-            sizeof(float), sizeof(float),
             stream
         );
     } else if (dst->type == GGML_TYPE_F16) {
@@ -128,7 +122,6 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
             nb01, nb02, nb03,
             nb10, nb11, nb12,
             nb1, nb2, nb3,
-            sizeof(float), sizeof(half),
             stream
         );
     } else {
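Dropping src_type_size and dst_type_size from these call sites works because set_rows_cuda is templated on src_t and dst_t, so the element sizes are compile-time constants inside the function. A minimal sketch of the idea (hypothetical helper name, not from the patch):

#include <cstddef>
#include <cstdint>

// Byte stride -> element stride; sizeof(T) is known statically,
// so no runtime type-size parameter is needed.
template <typename T>
constexpr int64_t byte_stride_to_elems(size_t nb) {
    return (int64_t)(nb / sizeof(T));
}

static_assert(byte_stride_to_elems<float>(64) == 16, "64 bytes = 16 floats");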