diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 6419d739bd8a2..b928e9e16ead8 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1376,7 +1376,7 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // [TAG_NO_CACHE_PAD]
     // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
-    assert(!ubatch.equal_seqs());
+    assert(!ubatch.equal_seqs() || (k_cur->ne[3] == 1 && k_cur->ne[3] == ubatch.n_seqs_unq));
 
     ggml_tensor * q = q_cur;
     ggml_tensor * k = k_cur;