Skip to content

Commit 7569423

Browse files
committed
Revert CUDA: add fused rms norm (ggml-org#14800)
1 parent 13e017e commit 7569423

File tree

10 files changed, with 2 additions (+2) and 247 deletions (-247).

ggml/include/ggml.h

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,6 @@ extern "C" {
590590
GGML_OP_RMS_NORM,
591591
GGML_OP_RMS_NORM_BACK,
592592
GGML_OP_GROUP_NORM,
593-
GGML_OP_FUSED_RMS_NORM,
594593
GGML_OP_FUSED_MUL_UNARY,
595594
GGML_OP_MULTI_ADD,
596595
GGML_OP_L2_NORM,
@@ -1375,18 +1374,6 @@ extern "C" {
13751374
struct ggml_tensor * a,
13761375
float eps);
13771376

1378-
GGML_API struct ggml_tensor * ggml_fused_rms_norm(
1379-
struct ggml_context * ctx,
1380-
struct ggml_tensor * a,
1381-
struct ggml_tensor * b,
1382-
float eps);
1383-
1384-
GGML_API struct ggml_tensor * ggml_fused_rms_norm_inplace(
1385-
struct ggml_context * ctx,
1386-
struct ggml_tensor * a,
1387-
struct ggml_tensor * b,
1388-
float eps);
1389-
13901377
// group normalize along ne0*ne1*n_groups
13911378
// used in stable-diffusion
13921379
GGML_API struct ggml_tensor * ggml_group_norm(

ggml/src/ggml-alloc.c

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
4343
case GGML_OP_SILU_BACK:
4444
case GGML_OP_RMS_NORM:
4545
case GGML_OP_RMS_NORM_BACK:
46-
case GGML_OP_FUSED_RMS_NORM:
4746
case GGML_OP_SOFT_MAX:
4847
case GGML_OP_SOFT_MAX_BACK:
4948
case GGML_OP_SOFTCAP:

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2510,10 +2510,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
25102510
{
25112511
ggml_compute_forward_rms_norm(params, tensor);
25122512
} break;
2513-
case GGML_OP_FUSED_RMS_NORM:
2514-
{
2515-
ggml_compute_forward_fused_rms_norm(params, tensor);
2516-
} break;
25172513
case GGML_OP_FUSED_MUL_UNARY:
25182514
{
25192515
ggml_compute_forward_fused_mul_unary(params, tensor);
@@ -2963,7 +2959,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
29632959
case GGML_OP_DIV:
29642960
case GGML_OP_NORM:
29652961
case GGML_OP_RMS_NORM:
2966-
case GGML_OP_FUSED_RMS_NORM:
29672962
case GGML_OP_FUSED_MUL_UNARY:
29682963
case GGML_OP_RMS_NORM_BACK:
29692964
case GGML_OP_L2_NORM:

ggml/src/ggml-cpu/ops.cpp

Lines changed: 0 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -4706,78 +4706,6 @@ void ggml_compute_forward_rms_norm(
47064706
}
47074707
}
47084708

4709-
static void ggml_compute_forward_fused_rms_norm_f32(
4710-
const struct ggml_compute_params * params,
4711-
struct ggml_tensor * dst) {
4712-
4713-
const struct ggml_tensor * src0 = dst->src[0];
4714-
const struct ggml_tensor * src1 = dst->src[1];
4715-
4716-
if (!src1) {
4717-
ggml_compute_forward_rms_norm_f32(params, dst);
4718-
return;
4719-
}
4720-
4721-
GGML_ASSERT(ggml_are_same_shape(src0, dst));
4722-
4723-
GGML_ASSERT(src0->nb[0] == sizeof(float));
4724-
GGML_ASSERT(src1->nb[0] == sizeof(float));
4725-
GGML_ASSERT(src1->ne[0] == src0->ne[0]);
4726-
GGML_ASSERT(ggml_nrows(src1) == 1);
4727-
4728-
const int ith = params->ith;
4729-
const int nth = params->nth;
4730-
4731-
GGML_TENSOR_UNARY_OP_LOCALS
4732-
4733-
float eps;
4734-
memcpy(&eps, dst->op_params, sizeof(float));
4735-
4736-
GGML_ASSERT(eps > 0.0f);
4737-
4738-
// TODO: optimize
4739-
for (int64_t i03 = 0; i03 < ne03; i03++) {
4740-
for (int64_t i02 = 0; i02 < ne02; i02++) {
4741-
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
4742-
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
4743-
4744-
ggml_float sum = 0.0;
4745-
for (int64_t i00 = 0; i00 < ne00; i00++) {
4746-
sum += (ggml_float)(x[i00] * x[i00]);
4747-
}
4748-
4749-
const float mean = sum/ne00;
4750-
4751-
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
4752-
4753-
const float scale = 1.0f/sqrtf(mean + eps);
4754-
4755-
ggml_vec_mul_f32(ne00, y, x, (const float *)src1->data);
4756-
ggml_vec_scale_f32(ne00, y, scale);
4757-
4758-
}
4759-
}
4760-
}
4761-
}
4762-
4763-
void ggml_compute_forward_fused_rms_norm(
4764-
const struct ggml_compute_params * params,
4765-
struct ggml_tensor * dst) {
4766-
4767-
const struct ggml_tensor * src0 = dst->src[0];
4768-
4769-
switch (src0->type) {
4770-
case GGML_TYPE_F32:
4771-
{
4772-
ggml_compute_forward_fused_rms_norm_f32(params, dst);
4773-
} break;
4774-
default:
4775-
{
4776-
GGML_ABORT("fatal error");
4777-
}
4778-
}
4779-
}
4780-
47814709
// ggml_compute_forward_fused_mul_unary
47824710

47834711
static void ggml_compute_forward_fused_mul_unary_f32(

ggml/src/ggml-cpu/ops.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ void ggml_compute_forward_concat(const struct ggml_compute_params * params, stru
4545
void ggml_compute_forward_silu_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
4646
void ggml_compute_forward_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);
4747
void ggml_compute_forward_rms_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);
48-
void ggml_compute_forward_fused_rms_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);
4948
void ggml_compute_forward_fused_mul_unary(const struct ggml_compute_params * params, struct ggml_tensor * dst);
5049
void ggml_compute_forward_rms_norm_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
5150
void ggml_compute_forward_group_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2612,9 +2612,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
26122612
case GGML_OP_RMS_NORM_BACK:
26132613
ggml_cuda_op_rms_norm_back(ctx, dst);
26142614
break;
2615-
case GGML_OP_FUSED_RMS_NORM:
2616-
ggml_cuda_op_fused_rms_norm(ctx, dst);
2617-
break;
26182615
case GGML_OP_MUL_MAT:
26192616
ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);
26202617
break;
@@ -3663,9 +3660,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
36633660
case GGML_OP_RMS_NORM_BACK:
36643661
return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
36653662
break;
3666-
case GGML_OP_FUSED_RMS_NORM:
3667-
// return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
3668-
// break;
36693663
case GGML_OP_NONE:
36703664
case GGML_OP_RESHAPE:
36713665
case GGML_OP_VIEW:

ggml/src/ggml-cuda/norm.cu

Lines changed: 0 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -215,40 +215,6 @@ static __global__ void rms_norm_back_f32(
215215
}
216216
}
217217

218-
template <int block_size>
219-
static __global__ void fused_rms_norm_f32(const float * x, const float * y, float * dst, const int ncols, const float eps) {
220-
const int row = blockIdx.x*blockDim.y + threadIdx.y;
221-
const int tid = threadIdx.x;
222-
223-
float tmp = 0.0f; // partial sum for thread in warp
224-
225-
for (int col = tid; col < ncols; col += block_size) {
226-
const float xi = x[row*ncols + col];
227-
tmp += xi * xi;
228-
}
229-
230-
// sum up partial sums
231-
tmp = warp_reduce_sum(tmp);
232-
if (block_size > WARP_SIZE) {
233-
__shared__ float s_sum[32];
234-
int warp_id = threadIdx.x / WARP_SIZE;
235-
int lane_id = threadIdx.x % WARP_SIZE;
236-
if (lane_id == 0) {
237-
s_sum[warp_id] = tmp;
238-
}
239-
__syncthreads();
240-
tmp = s_sum[lane_id];
241-
tmp = warp_reduce_sum(tmp);
242-
}
243-
244-
const float mean = tmp / ncols;
245-
const float scale = rsqrtf(mean + eps);
246-
247-
for (int col = tid; col < ncols; col += block_size) {
248-
dst[row*ncols + col] = scale * y[col] * x[row*ncols + col];
249-
}
250-
}
251-
252218
// template <int block_size>
253219
// static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
254220
// const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -395,19 +361,6 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float *
395361
}
396362
}
397363

398-
399-
static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst,
400-
const int ncols, const int nrows, const float eps, cudaStream_t stream) {
401-
GGML_ASSERT(ncols % WARP_SIZE == 0);
402-
if (ncols < 1024) {
403-
const dim3 block_dims(WARP_SIZE, 1, 1);
404-
fused_rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
405-
} else {
406-
const dim3 block_dims(1024, 1, 1);
407-
fused_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
408-
}
409-
}
410-
411364
static void l2_norm_f32_cuda(
412365
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
413366
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
@@ -567,36 +520,6 @@ void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * d
567520
rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
568521
}
569522

570-
571-
void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
572-
if (!dst->src[1]) {
573-
ggml_cuda_op_rms_norm(ctx, dst);
574-
return;
575-
}
576-
const ggml_tensor * src0 = dst->src[0];
577-
const ggml_tensor * src1 = dst->src[1];
578-
const float * src0_d = (const float *)src0->data;
579-
const float * src1_d = (const float *)src1->data;
580-
float * dst_d = (float *)dst->data;
581-
cudaStream_t stream = ctx.stream();
582-
583-
GGML_ASSERT(ggml_is_contiguous(src0));
584-
585-
GGML_ASSERT(src0->type == GGML_TYPE_F32);
586-
GGML_ASSERT(src1->type == GGML_TYPE_F32);
587-
GGML_ASSERT( dst->type == GGML_TYPE_F32);
588-
GGML_ASSERT(src0->ne[0] == src1->ne[0]);
589-
GGML_ASSERT(ggml_nrows(src1) == 1);
590-
591-
const int64_t ne00 = src0->ne[0];
592-
const int64_t nrows = ggml_nrows(src0);
593-
594-
float eps;
595-
memcpy(&eps, dst->op_params, sizeof(float));
596-
597-
fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
598-
}
599-
600523
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
601524
const ggml_tensor * src0 = dst->src[0];
602525
const float * src0_d = (const float *) src0->data;

ggml/src/ggml-cuda/norm.cuh

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,4 @@ void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor *
1010

1111
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
1212

13-
void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
14-
1513
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml.c

Lines changed: 2 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1858,7 +1858,6 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
18581858
"RMS_NORM",
18591859
"RMS_NORM_BACK",
18601860
"GROUP_NORM",
1861-
"FUSED_RMS_NORM",
18621861
"FUSED_MUL_UNARY",
18631862
"MULTI_ADD",
18641863
"L2_NORM",
@@ -1933,7 +1932,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
19331932
"GLU",
19341933
};
19351934

1936-
static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91");
1935+
static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
19371936

19381937
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
19391938
"none",
@@ -1963,7 +1962,6 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
19631962
"rms_norm(x)",
19641963
"rms_norm_back(x)",
19651964
"group_norm(x)",
1966-
"fused_rms_norm(x)",
19671965
"fused_mul_unary(x)",
19681966
"x1+x2+x3+...",
19691967
"l2_norm(x)",
@@ -2038,7 +2036,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
20382036
"glu(x)",
20392037
};
20402038

2041-
static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91");
2039+
static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
20422040

20432041
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
20442042

@@ -3948,57 +3946,6 @@ struct ggml_tensor * ggml_rms_norm_inplace(
39483946
return ggml_rms_norm_impl(ctx, a, eps, true);
39493947
}
39503948

3951-
static struct ggml_tensor * ggml_fused_rms_norm_impl(
3952-
struct ggml_context * ctx,
3953-
struct ggml_tensor * a,
3954-
struct ggml_tensor * b,
3955-
float eps,
3956-
bool inplace) {
3957-
3958-
if (!b) {
3959-
return ggml_rms_norm_impl(ctx, a, eps, inplace);
3960-
}
3961-
3962-
if (ggml_nrows(b) > 1 || a->ne[0] != b->ne[0]) {
3963-
struct ggml_tensor * result = ggml_rms_norm_impl(ctx, a, eps, inplace);
3964-
result = ggml_mul_impl(ctx, result, b, inplace);
3965-
return result;
3966-
}
3967-
3968-
// bool is_node = false;
3969-
3970-
// if (!inplace && (a->grad)) {
3971-
// is_node = true;
3972-
// }
3973-
3974-
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
3975-
3976-
ggml_set_op_params(result, &eps, sizeof(eps));
3977-
3978-
result->op = GGML_OP_FUSED_RMS_NORM;
3979-
// result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
3980-
result->src[0] = a;
3981-
result->src[1] = b;
3982-
3983-
return result;
3984-
}
3985-
3986-
struct ggml_tensor * ggml_fused_rms_norm(
3987-
struct ggml_context * ctx,
3988-
struct ggml_tensor * a,
3989-
struct ggml_tensor * b,
3990-
float eps) {
3991-
return ggml_fused_rms_norm_impl(ctx, a, b, eps, false);
3992-
}
3993-
3994-
struct ggml_tensor * ggml_fused_rms_norm_inplace(
3995-
struct ggml_context * ctx,
3996-
struct ggml_tensor * a,
3997-
struct ggml_tensor * b,
3998-
float eps) {
3999-
return ggml_fused_rms_norm_impl(ctx, a, b, eps, true);
4000-
}
4001-
40023949
// ggml_rms_norm_back
40033950

40043951
struct ggml_tensor * ggml_rms_norm_back(
@@ -7054,11 +7001,6 @@ static void ggml_compute_backward(
70547001
ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, grad, src0, eps));
70557002
}
70567003
} break;
7057-
// case GGML_OP_FUSED_RMS_NORM:
7058-
// {
7059-
// GGML_ABORT("fatal error"); // TODO: not implemented
7060-
// }
7061-
// } break;
70627004
// case GGML_OP_FUSED_MUL_UNARY:
70637005
// {
70647006
// GGML_ABORT("fatal error"); // TODO: implement

src/llama-graph.cpp

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -611,16 +611,6 @@ ggml_tensor * llm_graph_context::build_norm(
611611
ggml_tensor * mb,
612612
llm_norm_type type,
613613
int il) const {
614-
615-
if (type == LLM_NORM_RMS && mw) {
616-
cur = ggml_fused_rms_norm(ctx0, cur, mw, hparams.f_norm_rms_eps);
617-
if (mb) {
618-
cb(cur, "fused_norm", il);
619-
cur = ggml_add(ctx0, cur, mb);
620-
}
621-
return cur;
622-
}
623-
624614
switch (type) {
625615
case LLM_NORM: cur = ggml_norm (ctx0, cur, hparams.f_norm_eps); break;
626616
case LLM_NORM_RMS: cur = ggml_rms_norm(ctx0, cur, hparams.f_norm_rms_eps); break;

0 commit comments

Comments
 (0)