Skip to content

Commit 946f71e

Browse files
authored
llama : fix shapes for bert/mpt q/k norm (#16409)
1 parent 638d330 commit 946f71e

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

src/llama-model.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7843,6 +7843,8 @@ struct llm_build_bert : public llm_graph_context {
78437843
}
78447844

78457845
if (model.layers[il].attn_q_norm) {
7846+
Qcur = ggml_reshape_2d(ctx0, Qcur, n_embd_head*n_head, n_tokens);
7847+
78467848
Qcur = build_norm(Qcur,
78477849
model.layers[il].attn_q_norm,
78487850
model.layers[il].attn_q_norm_b,
@@ -7852,6 +7854,8 @@ struct llm_build_bert : public llm_graph_context {
78527854
}
78537855

78547856
if (model.layers[il].attn_k_norm) {
7857+
Kcur = ggml_reshape_2d(ctx0, Kcur, n_embd_head*n_head_kv, n_tokens);
7858+
78557859
Kcur = build_norm(Kcur,
78567860
model.layers[il].attn_k_norm,
78577861
model.layers[il].attn_k_norm_b,
@@ -8234,6 +8238,9 @@ struct llm_build_mpt : public llm_graph_context {
82348238

82358239
// Q/K Layernorm
82368240
if (model.layers[il].attn_q_norm) {
8241+
Qcur = ggml_reshape_2d(ctx0, Qcur, n_embd_head*n_head, n_tokens);
8242+
Kcur = ggml_reshape_2d(ctx0, Kcur, n_embd_head*n_head_kv, n_tokens);
8243+
82378244
Qcur = build_norm(Qcur,
82388245
model.layers[il].attn_q_norm,
82398246
model.layers[il].attn_q_norm_b,

0 commit comments

Comments
 (0)