
Commit 1dc9614

mitmul and CISC authored
llama : fix kq_scale for the attention layers of PLaMo2 (#14892)
* Fix dimensions for expand
* Change dimensions to copy states to cache
* Fix the default value for plamo2 conversion
* Fix scale given to build_attn
* Update src/llama-model.cpp

---------

Co-authored-by: Sigbjørn Skjæret <[email protected]>
1 parent 446595b commit 1dc9614
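
For context on the commit title: the kq_scale argument to build_attn scales the Q·K^T logits before the softmax. Standard scaled dot-product attention is (a reference formula, not code from this patch; d_head denotes the per-head dimension, which the patch takes from n_embd_head_v):

    \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left( \frac{Q K^{\top}}{\sqrt{d_{\mathrm{head}}}} \right) V, \qquad \text{kq\_scale} = \frac{1}{\sqrt{d_{\mathrm{head}}}}

The previous code passed 1.0f, leaving the attention logits unscaled in PLaMo-2's attention layers; the diff below passes 1.0f/sqrtf(float(n_embd_head_v)) instead.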

File tree: 2 files changed, +10 −9 lines

  convert_hf_to_gguf.py
  src/llama-model.cpp

convert_hf_to_gguf.py

Lines changed: 2 additions & 2 deletions

@@ -3791,7 +3791,7 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_block_count(block_count)
         self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 32))
         self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06))
-        self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1000000.0))
+        self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 10000))
 
         # Mamba parameters
         self.gguf_writer.add_ssm_state_size(hparams.get("mamba_d_state", 64))
@@ -3802,7 +3802,7 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_ssm_group_count(0)
 
         # MLP feed forward parameters (for attention layers)
-        self.gguf_writer.add_feed_forward_length(hparams.get("intermediate_size", 16384))
+        self.gguf_writer.add_feed_forward_length(hparams.get("intermediate_size", 13312))
         self.gguf_writer.add_file_type(self.ftype)
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:

src/llama-model.cpp

Lines changed: 8 additions & 7 deletions

@@ -16191,7 +16191,7 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
         {
             // PLaMo-2 uses combined QKV tensor
             ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur);
-            cb(qkv, "qkv", il);
+            cb(qkv, "wqkv", il);
 
             // split QKV tensor into Q, K, V
             const int64_t n_embd_head_q = hparams.n_embd_head_k;
@@ -16231,7 +16231,7 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
                 ext_factor, attn_factor, beta_fast, beta_slow
                 );
 
-            cur = build_attn(inp, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f, il);
+            cur = build_attn(inp, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head_v)), il);
         }
 
         cb(cur, "attn_out", il);
@@ -16306,8 +16306,9 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
         ggml_build_forward_expand(gf,
             ggml_cpy(ctx0, last_conv,
                 ggml_view_1d(ctx0, conv_states_all,
-                    (d_conv - 1)*(d_inner)*(n_seqs),
-                    kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all))));
+                    (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs),
+                    kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all))));
+        cb(conv_states_all, "mamba_conv1d_state", il);
 
         // 1D convolution
         x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
@@ -16370,9 +16371,9 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
         // store last states
         ggml_build_forward_expand(gf,
             ggml_cpy(ctx0,
-                ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]*x->ne[3]),
-                ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs,
-                    kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
+                ggml_view_1d(ctx0, y_ssm, n_heads*head_dim*d_state*n_seqs, n_heads*head_dim*n_seq_tokens*n_seqs*ggml_element_size(y_ssm)),
+                ggml_view_1d(ctx0, ssm_states_all, n_heads*head_dim*d_state*n_seqs, kv_head*n_seqs*n_heads*head_dim*d_state*ggml_element_size(ssm_states_all))));
+        cb(ssm_states_all, "mamba_ssm_states", il);
 
         ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0);
         cb(y, "mamba_y_view", il);
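
As a side note on the two state-cache hunks above, here is a minimal standalone sketch of the element counts those ggml_cpy calls now copy: the conv state keeps d_conv - 1 past steps of d_inner + 2*n_group*d_state channels per sequence (the extra term suggests B and C pass through the conv1d alongside x, as in Mamba-2-style layers), and the recurrent state is n_heads*head_dim*d_state elements per sequence. The hyper-parameter values below are placeholders for illustration, not PLaMo-2's actual configuration.

#include <cinttypes>
#include <cstdint>
#include <cstdio>

int main() {
    // Placeholder hyper-parameters, for illustration only (not PLaMo-2's real config).
    const int64_t d_conv   = 4;     // conv1d kernel size
    const int64_t d_inner  = 2048;  // inner (x) dimension
    const int64_t n_group  = 1;     // number of B/C groups
    const int64_t d_state  = 64;    // SSM state size
    const int64_t n_heads  = 32;    // number of SSM heads
    const int64_t head_dim = d_inner / n_heads;
    const int64_t n_seqs   = 2;     // sequences in the current batch

    // Elements copied into conv_states_all (first state hunk, after the fix):
    // a rolling window of d_conv - 1 steps over x, B and C for every sequence.
    const int64_t conv_elems = (d_conv - 1) * (d_inner + 2*n_group*d_state) * n_seqs;

    // Elements copied into ssm_states_all (second state hunk, after the fix):
    // one d_state-sized recurrent state per head dimension, per sequence.
    const int64_t ssm_elems = n_heads * head_dim * d_state * n_seqs;

    std::printf("conv state elements per copy: %" PRId64 "\n", conv_elems);
    std::printf("ssm  state elements per copy: %" PRId64 "\n", ssm_elems);
    return 0;
}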
