@@ -18098,21 +18098,17 @@ struct llm_build_cogvlm : public llm_graph_context {
18098
18098
18099
18099
for (int il = 0; il < n_layer; ++il) {
18100
18100
// get either the text or image weight tensors
18101
- ggml_tensor * wq, * wk, * wv , * wo;
18101
+ ggml_tensor * wqkv , * wo;
18102
18102
ggml_tensor * ffn_gate, * ffn_down, * ffn_up;
18103
18103
18104
18104
if (is_text) {
18105
- wq = model.layers[il].wq;
18106
- wk = model.layers[il].wk;
18107
- wv = model.layers[il].wv;
18105
+ wqkv = model.layers[il].wqkv;
18108
18106
wo = model.layers[il].wo;
18109
18107
ffn_gate = model.layers[il].ffn_gate;
18110
18108
ffn_down = model.layers[il].ffn_down;
18111
18109
ffn_up = model.layers[il].ffn_up;
18112
18110
} else {
18113
- wq = model.layers[il].visexp_attn_wq;
18114
- wk = model.layers[il].visexp_attn_wk;
18115
- wv = model.layers[il].visexp_attn_wv;
18111
+ wqkv = model.layers[il].visexp_attn_wqkv;
18116
18112
wo = model.layers[il].visexp_attn_wo;
18117
18113
ffn_gate = model.layers[il].visexp_ffn_gate;
18118
18114
ffn_down = model.layers[il].visexp_ffn_down;
@@ -18124,14 +18120,19 @@ struct llm_build_cogvlm : public llm_graph_context {
18124
18120
18125
18121
// build self attention
18126
18122
{
18127
- ggml_tensor * Qcur = build_lora_mm(wq, cur);
18128
- cb(Qcur, "Qcur", il);
18129
-
18130
- ggml_tensor * Kcur = build_lora_mm(wk, cur);
18131
- cb(Kcur, "Kcur", il);
18132
-
18133
- ggml_tensor * Vcur = build_lora_mm(wv, cur);
18134
- cb(Vcur, "Vcur", il);
18123
+ ggml_tensor * qkv = build_lora_mm(wqkv, cur);
18124
+ cb(qkv, "qkv", il);
18125
+
18126
+ // split qkv into Q, K, V along the first dimension
18127
+ ggml_tensor * Qcur = ggml_view_2d(ctx0, qkv,
18128
+ n_embd, n_tokens,
18129
+ ggml_row_size(qkv->type, n_embd), 0);
18130
+ ggml_tensor * Kcur = ggml_view_2d(ctx0, qkv,
18131
+ n_embd, n_tokens,
18132
+ ggml_row_size(qkv->type, n_embd), n_embd * ggml_element_size(qkv));
18133
+ ggml_tensor * Vcur = ggml_view_2d(ctx0, qkv,
18134
+ n_embd, n_tokens,
18135
+ ggml_row_size(qkv->type, n_embd), 2 * n_embd * ggml_element_size(qkv));
18135
18136
18136
18137
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
18137
18138
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
0 commit comments