@@ -421,11 +421,38 @@ struct llm_graph_params {
421
421
// TODO: temporary
422
422
llm_graph_result_i * res;
423
423
424
- bool is_same (const llm_graph_params & other) const {
424
+ // return true if the "other" params would result in a graph with the same topology as with the current params
425
+ // having the same topology allows us to reuse the graph in some cases
426
+ bool allow_reuse (const llm_graph_params & other) const {
427
+ // first check the ubatch
428
+ bool can_reuse_ubatch =
429
+ ubatch.equal_seqs == other.ubatch .equal_seqs &&
430
+ ubatch.n_tokens == other.ubatch .n_tokens &&
431
+ ubatch.n_seq_tokens == other.ubatch .n_seq_tokens &&
432
+ ubatch.n_seqs == other.ubatch .n_seqs &&
433
+ ubatch.n_seqs_unq == other.ubatch .n_seqs_unq &&
434
+ (
435
+ (!ubatch.token && !other.ubatch .token ) ||
436
+ (!ubatch.embd && !other.ubatch .embd )
437
+ );
438
+
439
+ // TODO: this won't work because seq_id_unq ptr can point to an old balloc that has
440
+ // been freed by this point. find a way to fix this
441
+ // for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
442
+ // can_reuse_ubatch &= ubatch.seq_id_unq[s] == other.ubatch.seq_id_unq[s];
443
+ // }
444
+
445
+ // for now, conservatively disallow reuse whenever ubatch.equal_seqs is set, until the issue above is resolved
446
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14363
447
+ can_reuse_ubatch = can_reuse_ubatch && !ubatch.equal_seqs ;
448
+
449
+ if (!can_reuse_ubatch) {
450
+ return false ;
451
+ }
452
+
425
453
return
426
- hparams.is_same (other.hparams ) &&
427
- cparams.is_same (other.cparams ) &&
428
- ubatch .is_same (other.ubatch ) &&
454
+ cparams.embeddings == other.cparams .embeddings &&
455
+ cparams.causal_attn == other.cparams .causal_attn &&
429
456
arch == other.arch &&
430
457
gtype == other.gtype &&
431
458
cvec == other.cvec &&
@@ -488,7 +515,7 @@ class llm_graph_result : public llm_graph_result_i {
488
515
// contexts of the input tensors of the graph and we can reuse it for another computation
489
516
// return true if the graph was updated and can be reused
490
517
bool can_reuse (const llm_graph_params & params) override {
491
- if (!this ->params .is_same (params)) {
518
+ if (!this ->params .allow_reuse (params)) {
492
519
return false ;
493
520
}
494
521
0 commit comments