
Commit 60739b8

Change to combined QKV for CogVLM LLM portion
1 parent: 40b07b2

File tree

1 file changed: +16 −15 lines

src/llama-model.cpp

Lines changed: 16 additions & 15 deletions
@@ -18098,21 +18098,17 @@ struct llm_build_cogvlm : public llm_graph_context {
 
     for (int il = 0; il < n_layer; ++il) {
         // get either the text or image weight tensors
-        ggml_tensor * wq, * wk, * wv, * wo;
+        ggml_tensor * wqkv, * wo;
         ggml_tensor * ffn_gate, * ffn_down, * ffn_up;
 
         if (is_text) {
-            wq = model.layers[il].wq;
-            wk = model.layers[il].wk;
-            wv = model.layers[il].wv;
+            wqkv = model.layers[il].wqkv;
             wo = model.layers[il].wo;
             ffn_gate = model.layers[il].ffn_gate;
             ffn_down = model.layers[il].ffn_down;
             ffn_up = model.layers[il].ffn_up;
         } else {
-            wq = model.layers[il].visexp_attn_wq;
-            wk = model.layers[il].visexp_attn_wk;
-            wv = model.layers[il].visexp_attn_wv;
+            wqkv = model.layers[il].visexp_attn_wqkv;
             wo = model.layers[il].visexp_attn_wo;
             ffn_gate = model.layers[il].visexp_ffn_gate;
             ffn_down = model.layers[il].visexp_ffn_down;
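The hunk above assumes the loader now exposes a single fused wqkv (and visexp_attn_wqkv) weight per layer; the conversion-side change is not part of this diff. As a purely hypothetical sketch of the layout such a fused tensor needs, assuming float weights stored row-major as [n_in, n_out] the way ggml matmul weights are, fusing amounts to stacking the Q, K, and V output rows in that order:

#include <string.h>

// Hypothetical helper (not from this commit): stack wq, wk, wv into one
// wqkv buffer of n_in * 3*n_out floats. Q rows come first, then K, then V,
// which is the order the views in the second hunk slice back out.
static void fuse_qkv(const float * wq, const float * wk, const float * wv,
                     float * wqkv, size_t n_in, size_t n_out) {
    const size_t block = n_in * n_out;          // floats per projection
    memcpy(wqkv + 0 * block, wq, block * sizeof(float));
    memcpy(wqkv + 1 * block, wk, block * sizeof(float));
    memcpy(wqkv + 2 * block, wv, block * sizeof(float));
}

Because the Q rows come first, elements 0..n_embd-1 of each output column of build_lora_mm(wqkv, cur) belong to Q, which is the ordering the views in the next hunk rely on.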
@@ -18124,14 +18120,19 @@ struct llm_build_cogvlm : public llm_graph_context {
 
         // build self attention
         {
-            ggml_tensor * Qcur = build_lora_mm(wq, cur);
-            cb(Qcur, "Qcur", il);
-
-            ggml_tensor * Kcur = build_lora_mm(wk, cur);
-            cb(Kcur, "Kcur", il);
-
-            ggml_tensor * Vcur = build_lora_mm(wv, cur);
-            cb(Vcur, "Vcur", il);
+            ggml_tensor * qkv = build_lora_mm(wqkv, cur);
+            cb(qkv, "qkv", il);
+
+            // split qkv into Q, K, V along the first dimension
+            ggml_tensor * Qcur = ggml_view_2d(ctx0, qkv,
+                    n_embd, n_tokens,
+                    ggml_row_size(qkv->type, n_embd), 0);
+            ggml_tensor * Kcur = ggml_view_2d(ctx0, qkv,
+                    n_embd, n_tokens,
+                    ggml_row_size(qkv->type, n_embd), n_embd * ggml_element_size(qkv));
+            ggml_tensor * Vcur = ggml_view_2d(ctx0, qkv,
+                    n_embd, n_tokens,
+                    ggml_row_size(qkv->type, n_embd), 2 * n_embd * ggml_element_size(qkv));
 
             Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
             Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
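For readers unfamiliar with ggml views, here is a standalone sketch (toy sizes, not code from this commit) of how ggml_view_2d can carve Q, K, and V out of a fused [3*n_embd, n_tokens] activation. In this sketch the nb1 argument is the byte stride between consecutive tokens, so it passes the source's full fused row stride, qkv->nb[1]; the byte offset then selects which third of each row the view starts at:

#include "ggml.h"
#include <stdio.h>

int main(void) {
    const int n_embd   = 4;   // toy sizes, purely illustrative
    const int n_tokens = 2;

    struct ggml_init_params params = { 16u*1024*1024, NULL, false };
    struct ggml_context * ctx = ggml_init(params);

    // stand-in for the fused projection output: 3*n_embd values per token
    struct ggml_tensor * qkv = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 3*n_embd, n_tokens);
    float * d = (float *) qkv->data;
    for (int t = 0; t < n_tokens; ++t)
        for (int i = 0; i < 3*n_embd; ++i)
            d[t*3*n_embd + i] = (float)(100*t + i);

    // each view is n_embd x n_tokens; nb1 keeps the fused row stride so
    // stepping one token in the view steps one full 3*n_embd row in qkv
    struct ggml_tensor * q = ggml_view_2d(ctx, qkv, n_embd, n_tokens,
            qkv->nb[1], 0);
    struct ggml_tensor * k = ggml_view_2d(ctx, qkv, n_embd, n_tokens,
            qkv->nb[1], 1*n_embd*ggml_element_size(qkv));
    struct ggml_tensor * v = ggml_view_2d(ctx, qkv, n_embd, n_tokens,
            qkv->nb[1], 2*n_embd*ggml_element_size(qkv));

    // K for token 1 should read 104 105 106 107
    const float * krow = (const float *)((const char *) k->data + 1*k->nb[1]);
    for (int i = 0; i < n_embd; ++i) printf("%g ", krow[i]);
    printf("\n");

    (void) q; (void) v;
    ggml_free(ctx);
    return 0;
}

Built against ggml, this prints 104 105 106 107, i.e. the K slice of token 1.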
