@@ -746,21 +746,25 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
             0);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * k = layers[ikv].k;
 
     const int64_t n_tokens = k_cur->ne[2];
 
+    if (kv_idxs) {
+        return ggml_set_rows(ctx, k, ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens), kv_idxs);
+    }
+
     ggml_tensor * k_view = ggml_view_1d(ctx, k,
             n_tokens*hparams.n_embd_k_gqa(il),
             ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur);
 
     return ggml_cpy(ctx, k_cur, k_view);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
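
For context, ggml_set_rows(ctx, dst, src, idx) scatters the rows of src into dst at the row positions given by the I64 tensor idx, so the new cpy_k branch above writes each token's K row directly to its cache slot instead of copying into a contiguous view. A minimal host-side sketch of the intended semantics, assuming a plain float, row-major layout (the real op runs inside the graph and also handles conversion into the cache's storage type; all names here are illustrative):

    // Reference-only sketch of the scatter performed by the kv_idxs branch of cpy_k.
    // Assumptions: float cache, row-major [kv_size, n_embd_k] layout.
    static void set_rows_ref(float * k_cache, const float * k_cur, const int64_t * kv_idxs,
                             int64_t n_tokens, int64_t n_embd_k) {
        for (int64_t i = 0; i < n_tokens; ++i) {
            const int64_t row = kv_idxs[i];               // destination cell for token i
            for (int64_t c = 0; c < n_embd_k; ++c) {
                k_cache[row*n_embd_k + c] = k_cur[i*n_embd_k + c];
            }
        }
    }
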
@@ -772,21 +776,48 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     ggml_tensor * v_view = nullptr;
 
     if (!v_trans) {
+        if (kv_idxs) {
+            return ggml_set_rows(ctx, v, v_cur, kv_idxs);
+        }
+
         v_view = ggml_view_1d(ctx, v,
                 n_tokens*hparams.n_embd_v_gqa(il),
                 ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
     } else {
+        v_cur = ggml_transpose(ctx, v_cur);
+
         // note: the V cache is transposed when not using flash attention
+        if (kv_idxs) {
+            // the row becomes a single element and we repeat the KV indices d_head times
+            // TODO: this seems not very optimal - can we do something better?
+            v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
+
+            v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+
+            kv_idxs = ggml_repeat_4d(ctx, kv_idxs, v_cur->ne[1], v_cur->ne[2], 1, 1);
+
+            return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
+        }
+
         v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
                 (v->ne[1])*ggml_element_size(v),
                 (head_cur)*ggml_element_size(v));
-
-        v_cur = ggml_transpose(ctx, v_cur);
     }
 
     return ggml_cpy(ctx, v_cur, v_view);
 }
 
+void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, uint32_t head_cur) const {
+    const uint32_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int64_t * data = (int64_t *) dst->data;
+
+    for (int64_t i = 0; i < n_tokens; ++i) {
+        data[i] = head_cur + i;
+    }
+}
+
 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     const uint32_t n_tokens = ubatch->n_tokens;
 
@@ -1789,18 +1820,22 @@ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t
     return kv->get_v(ctx, il, n_kv);
 }
 
-ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
-    return kv->cpy_k(ctx, k_cur, il, head);
+ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const {
+    return kv->cpy_k(ctx, k_cur, kv_idxs, il, head);
 }
 
-ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
-    return kv->cpy_v(ctx, v_cur, il, head);
+ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il) const {
+    return kv->cpy_v(ctx, v_cur, kv_idxs, il, head);
 }
 
 void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
     kv->set_input_k_shift(dst);
 }
 
+void llama_kv_cache_unified_context::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_kv_idxs(dst, ubatch, head);
+}
+
 void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     kv->set_input_kq_mask(dst, ubatch, causal_attn);
 }
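
The context wrappers simply forward the new kv_idxs tensor together with the cached head offset. A hypothetical sketch of how a graph builder could use the extended interface; the names build_kv_store, cache_ctx and gf are assumptions for illustration, not the PR's actual call sites:

    // Hypothetical usage sketch (not the PR's graph code): create the I64 index input,
    // route it through the new cpy_k/cpy_v overloads, and fill it later on the host.
    static ggml_tensor * build_kv_store(ggml_context * ctx, ggml_cgraph * gf,
                                        const llama_kv_cache_unified_context * cache_ctx,
                                        ggml_tensor * k_cur, ggml_tensor * v_cur,
                                        const llama_ubatch * ubatch, int32_t il) {
        ggml_tensor * kv_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, ubatch->n_tokens);
        ggml_set_input(kv_idxs);

        ggml_build_forward_expand(gf, cache_ctx->cpy_k(ctx, k_cur, kv_idxs, il));
        ggml_build_forward_expand(gf, cache_ctx->cpy_v(ctx, v_cur, kv_idxs, il));

        // at input-setting time: cache_ctx->set_input_kv_idxs(kv_idxs, ubatch);
        return kv_idxs;
    }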