@@ -16191,7 +16191,7 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
         {
             // PLaMo-2 uses combined QKV tensor
             ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur);
-            cb(qkv, "qkv", il);
+            cb(qkv, "wqkv", il);

             // split QKV tensor into Q, K, V
             const int64_t n_embd_head_q = hparams.n_embd_head_k;
@@ -16231,7 +16231,7 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
                     ext_factor, attn_factor, beta_fast, beta_slow
                     );

-            cur = build_attn(inp, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f, il);
+            cur = build_attn(inp, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head_v)), il);
         }

         cb(cur, "attn_out", il);
@@ -16306,8 +16306,9 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
         ggml_build_forward_expand(gf,
             ggml_cpy(ctx0, last_conv,
                 ggml_view_1d(ctx0, conv_states_all,
-                    (d_conv - 1)*(d_inner)*(n_seqs),
-                    kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all))));
+                    (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs),
+                    kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all))));
+        cb(conv_states_all, "mamba_conv1d_state", il);

         // 1D convolution
         x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
@@ -16370,9 +16371,9 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
         // store last states
         ggml_build_forward_expand(gf,
             ggml_cpy(ctx0,
-                ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]*x->ne[3]),
-                ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs,
-                    kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
+                ggml_view_1d(ctx0, y_ssm, n_heads*head_dim*d_state*n_seqs, n_heads*head_dim*n_seq_tokens*n_seqs*ggml_element_size(y_ssm)),
+                ggml_view_1d(ctx0, ssm_states_all, n_heads*head_dim*d_state*n_seqs, kv_head*n_seqs*n_heads*head_dim*d_state*ggml_element_size(ssm_states_all))));
+        cb(ssm_states_all, "mamba_ssm_states", il);

         ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0);
         cb(y, "mamba_y_view", il);