
Commit 939fa72

Author: Ralf Waldukat (committed)
fix: prevent KV cache corruption on SWA/ISWA models (e.g. Gemma-4)
SWA/ISWA KV caches maintain global position maps (g_iswa_pos_max/min) that are only cleared by llama_memory_clear(), not by kv_cache_seq_rm(). When generate() finds a prefix match (e.g. a shared BOS token), it calls kv_cache_seq_rm, which returns True for ISWA caches, so the full reset is skipped. The stale position maps then cause batch-allocator inconsistency, and llama_decode fails (returns -1) on subsequent prompts.

Changes:
- Add a _has_swa property via llama_model_n_swa() > 0
- reset() now calls llama_memory_clear() unconditionally
- generate() bypasses the prefix-match optimization for SWA models, forcing a full state reset (same path as recurrent models)
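A minimal reproduction sketch of the failure mode described above, using the high-level llama-cpp-python API. The model path is hypothetical and stands in for any SWA/ISWA model (e.g. a Gemma-style GGUF); the prompts are illustrative only:

from llama_cpp import Llama

# Hypothetical path: any model for which llama_model_n_swa() > 0 hits this code path.
llm = Llama(model_path="./models/gemma-swa.gguf", n_ctx=2048, verbose=False)

# First prompt fills the split KV cache and its internal position tracking.
print(llm("Explain sliding-window attention briefly.", max_tokens=32)["choices"][0]["text"])

# The second prompt shares only the BOS token with the first. Before this
# commit, generate() treated that 1-token match as a usable prefix,
# kv_cache_seq_rm() reported success for the ISWA cache, the full reset was
# skipped, and llama_decode could fail with -1. With this commit, SWA models
# force longest_prefix = 0 and take the full state-reset path instead.
print(llm("List two kinds of KV cache used by llama.cpp.", max_tokens=32)["choices"][0]["text"])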
1 parent 1cb8b9f commit 939fa72

1 file changed

Lines changed: 39 additions & 1 deletion

File tree

llama_cpp/llama.py

@@ -553,6 +553,14 @@ def free_lora_adapter():
 
         self._sampler = None
 
+        # Cache model architecture flags to avoid repeated FFI calls
+        self._is_recurrent_model = llama_cpp.llama_model_is_recurrent(
+            self._model.model
+        ) or llama_cpp.llama_model_is_hybrid(self._model.model)
+        self._has_swa_model = llama_cpp.llama_model_n_swa(
+            self._model.model
+        ) > 0
+
     @property
     def ctx(self) -> llama_cpp.llama_context_p:
         return self._ctx.ctx
@@ -580,6 +588,14 @@ def eval_logits(self) -> Deque[List[float]]:
             maxlen=self._n_ctx if self._logits_all else 1,
         )
 
+    @property
+    def _is_recurrent(self) -> bool:
+        return self._is_recurrent_model
+
+    @property
+    def _has_swa(self) -> bool:
+        return self._has_swa_model
+
     def tokenize(
         self, text: bytes, add_bos: bool = True, special: bool = False
     ) -> List[int]:
@@ -638,6 +654,10 @@ def reset(self):
         """Reset the model state."""
         self.n_tokens = 0
 
+        mem = llama_cpp.llama_get_memory(self._ctx.ctx)
+        if mem is not None:
+            llama_cpp.llama_memory_clear(mem, True)
+
     def eval(self, tokens: Sequence[int]):
         """Evaluate a list of tokens.
 
@@ -889,11 +909,29 @@ def generate(
         # Check for kv cache prefix match
         if reset and self.n_tokens > 0:
             longest_prefix = 0
-            for a, b in zip(self._input_ids, tokens[:-1]):
+            for a, b in zip(self._input_ids, tokens):
                 if a == b:
                     longest_prefix += 1
                 else:
                     break
+
+            # Recurrent models cannot rewind state; reset if needed
+            if self._is_recurrent and longest_prefix < self.n_tokens:
+                longest_prefix = 0
+                reset = True
+                if self.verbose:
+                    print(
+                        "Llama.generate: recurrent model requires full state reset",
+                        file=sys.stderr,
+                    )
+
+            # SWA/ISWA models (e.g. Gemma-4) have split KV caches whose
+            # position-tracking maps are only cleared by a full reset.
+            # Partial seq_rm leaves stale positions and causes decode failure.
+            if self._has_swa and longest_prefix < self.n_tokens:
+                longest_prefix = 0
+                reset = True
+
             if longest_prefix > 0:
                 if self._ctx.kv_cache_seq_rm(-1, longest_prefix, -1):
                     reset = False
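As a usage-level sketch of the reset() change (model path again hypothetical): after this commit, explicitly resetting a reused Llama instance clears the native memory/KV cache via llama_memory_clear() in addition to the Python-side token count, so unrelated prompts on SWA models start from a clean state even without relying on the prefix-match logic in generate():

from llama_cpp import Llama

llm = Llama(model_path="./models/gemma-swa.gguf", n_ctx=2048, verbose=False)  # hypothetical path
first = llm("First prompt.", max_tokens=16)

# reset() now zeroes n_tokens and also calls llama_memory_clear() on the
# context's memory, dropping the SWA/ISWA position maps along with it.
llm.reset()
second = llm("A completely unrelated second prompt.", max_tokens=16)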
