
Add TurboQuant KV cache compression for prefix cache (4.6x)#233

Open
arozanov wants to merge 4 commits into waybarrios:main from arozanov:feature/turboquant-kv-cache

Conversation

@arozanov

Summary

Adds --turbo-kv-bits option (1-4) to compress prefix cache entries using TurboQuant (PolarQuant: randomized Hadamard rotation + Lloyd-Max codebook quantization). At 3-bit, this gives 4.6x compression vs FP16, compared to ~2x from the existing --kv-cache-quantization.

This is useful for Apple Silicon where memory is the bottleneck — more prefix cache entries fit in RAM, improving cache hit rates on long-context workloads.
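At 3-bit, 4.6x over FP16 works out to 16 / 4.6 ≈ 3.5 bits per value, i.e. roughly half a bit per value of norm/codebook overhead on top of the 3-bit payload. The scheme can be sketched in plain NumPy — a toy illustration of the rotation + low-bit codebook idea, not the fused-Metal mlx-lm implementation (a uniform per-row codebook stands in for Lloyd-Max):

```python
import numpy as np

def hadamard(n):
    # Sylvester construction; n must be a power of two
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H / np.sqrt(n)  # orthonormal, so H @ H.T == I

def quantize(x, bits=3):
    # Per-row scale + uniform codebook (Lloyd-Max omitted for brevity)
    levels = 2 ** bits
    scale = np.abs(x).max(axis=-1, keepdims=True) / (levels / 2 - 0.5)
    q = np.clip(np.round(x / scale), -(levels // 2), levels // 2 - 1)
    return q.astype(np.int8), scale

rng = np.random.default_rng(0)
kv = rng.standard_normal((8, 64)).astype(np.float32)
H = hadamard(64)
rot = kv @ H                       # rotation spreads out outliers
q, scale = quantize(rot, bits=3)
rec = (q * scale) @ H.T            # dequantize, then invert the rotation
cos = (kv * rec).sum() / (np.linalg.norm(kv) * np.linalg.norm(rec))
```

With these toy settings the roundtrip cosine similarity stays well above 0.9, which is the shape of the "roundtrip preserves data" check in the test plan below.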

Usage

vllm-mlx serve model --turbo-kv-bits 3

Replaces --kv-cache-quantization when set. Falls back to standard quantization if TurboQuant is not available.

Changes

  • memory_cache.py: new _turbo_quantize_cache(), updated _dequantize_cache(), estimate_kv_cache_memory() and _trim_cache_offset() support, needs_dequantize property on the config, validation
  • scheduler.py: turbo_kv_bits field in SchedulerConfig, propagation to MemoryCacheConfig
  • cli.py: --turbo-kv-bits argument for serve and bench commands

Dependency

Requires mlx-lm with TurboQuant KV cache support: ml-explore/mlx-lm#1067

Test plan

  • Roundtrip: quantize → dequantize preserves data (cosine sim 0.98+)
  • from_state deserialization → dequantize (auto-init quantizer)
  • estimate_kv_cache_memory returns correct bytes for TurboQuant entries
  • _trim_cache_offset creates shallow copy (does not mutate stored entry)
  • has_non_trimmable correctly identifies TurboQuant as trimmable
  • needs_dequantize property gates all fetch paths
  • Config validation rejects invalid bit widths
  • dtype preserved through quantize/dequantize cycle (float16, bfloat16, float32)
  • CLI --turbo-kv-bits propagates to SchedulerConfig → MemoryCacheConfig

@arozanov arozanov force-pushed the feature/turboquant-kv-cache branch from b048558 to 2bba367 on March 29, 2026 at 16:03
Collaborator

@janhilgard janhilgard left a comment


Code Review

Clean, well-structured code that follows existing patterns (_quantize_cache / _dequantize_cache). The needs_dequantize property is an elegant abstraction. A few concerns:

🔴 Breaking change: is_trimmable() (blocking)

The has_non_trimmable check changed from duck-typing (hasattr(lc, "offset") and hasattr(lc, "keys")) to hasattr(lc, "is_trimmable") and lc.is_trimmable(). Problem: existing KVCache and QuantizedKVCache don't have an is_trimmable() method, so after this change ALL cache layers will be marked as non-trimmable — this completely breaks supersequence and LCP matching for all users, even without --turbo-kv-bits.

Suggestion — keep the old duck-typing as fallback:

has_non_trimmable = any(
    not (
        (hasattr(lc, "is_trimmable") and lc.is_trimmable())
        or (hasattr(lc, "offset") and hasattr(lc, "keys"))
    )
    for lc in cache
)
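The fallback can be exercised with stand-ins for the three cache shapes (class names below are stubs for illustration, not the real mlx-lm classes):

```python
class OldKVCache:
    # duck-typed legacy shape: has offset + keys, no is_trimmable()
    offset = 0
    keys = object()

class TurboCache:
    # new-style shape: explicit is_trimmable() protocol
    def is_trimmable(self):
        return True

class OpaqueCache:
    # satisfies neither protocol -> must be flagged non-trimmable
    pass

def has_non_trimmable(cache):
    return any(
        not (
            (hasattr(lc, "is_trimmable") and lc.is_trimmable())
            or (hasattr(lc, "offset") and hasattr(lc, "keys"))
        )
        for lc in cache
    )

assert not has_non_trimmable([OldKVCache(), TurboCache()])
assert has_non_trimmable([OldKVCache(), OpaqueCache()])
```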

🟡 Private API access in dequantization

_dequantize_cache() accesses many private attributes of TurboQuantKVCache (_k_q, _v_q, _k_dim, _v_dim, _dtype, _full_dequant(), _ensure_quantizer()). This is fragile — private API can change without notice. Does mlx-lm PR #1067 expose a public .dequantize() or .to_kvcache() method? If not, it would be worth proposing one upstream.

🟡 Shallow copy risk in _trim_cache_offset

tc.__dict__.update(layer_cache.__dict__) shares references to all internal objects. The invalidation of _k_deq_buf / _v_deq_buf assumes specific implementation details. If TurboQuantKVCache adds more cache buffers later, they'll be stale. Does the upstream class provide a copy() or trim() method?

🟡 estimate_kv_cache_memory.state property

Accessing .state may trigger lazy evaluation if it returns dequantized tensors. Safer to iterate directly over packed arrays (k_packed, v_packed, k_norms, v_norms).
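A minimal sketch of that safer estimate, summing the byte sizes of the packed arrays directly (the attribute names k_packed/v_packed/k_norms/v_norms follow this review's description of TurboQuantKVCache and are assumptions; the stub class below is purely illustrative):

```python
import numpy as np

def estimate_turbo_entry_bytes(layer):
    # Sum nbytes of the packed/norm arrays directly, never touching
    # .state (which may lazily materialize dequantized tensors)
    total = 0
    for name in ("k_packed", "v_packed", "k_norms", "v_norms"):
        arr = getattr(layer, name, None)
        if arr is not None:
            total += arr.nbytes
    return total

class FakeLayer:
    # stand-in for a TurboQuant layer: 2 batch x 4 heads x 16 packed bytes
    k_packed = np.zeros((2, 4, 16), dtype=np.uint8)
    v_packed = np.zeros((2, 4, 16), dtype=np.uint8)
    k_norms = np.zeros((2, 4), dtype=np.float16)
    v_norms = np.zeros((2, 4), dtype=np.float16)

estimate_turbo_entry_bytes(FakeLayer())  # 128 + 128 + 16 + 16 = 288 bytes
```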

ℹ️ Upstream dependency

PR depends on ml-explore/mlx-lm#1067 which is not merged yet. Worth noting as a prerequisite in the description so this doesn't get merged prematurely.


@arozanov
Author

arozanov commented Apr 1, 2026


Thanks for the thorough review!

Fixed:

  • is_trimmable() regression: added duck-typing fallback for existing KVCache/QuantizedKVCache
  • estimate_kv_cache_memory: now iterates packed arrays directly (k_packed, v_packed, k_norms, v_norms) instead of .state to avoid triggering lazy dequantization

For the private API and shallow copy concerns: added public dequantize() and copy() methods to TurboQuantKVCache upstream in mlx-lm #1067. Will update this PR to use them once that's merged.
Upstream dependency noted in the description.

@arozanov arozanov requested a review from janhilgard April 1, 2026 19:59
@janhilgard
Collaborator

Thanks for the quick fixes! The is_trimmable() fallback and direct packed array iteration look good.

Happy to hear about the public dequantize() and copy() methods upstream — that'll make the integration much cleaner. No further concerns from my side, just waiting on mlx-lm #1067 to land.

@Thump604
Collaborator

Thump604 commented Apr 7, 2026

@waybarrios, @arozanov: a brief positive note.

Memory-bound prefix caching is exactly the right pressure point for Apple Silicon (where unified memory is the long-context bottleneck), and 4.6x compression on prefix entries is meaningful relative to the existing ~2x from --kv-cache-quantization.

Two questions for completeness, not blocking:

  1. What is the quality impact at 3-bit (PolarQuant) on representative tasks? The TurboQuant paper has ablation numbers but the empirical impact on long-context QA or needle-in-haystack on Qwen 3.5 / Gemma 4 would be useful for users to decide whether to enable it.
  2. Is the --turbo-kv-bits flag mutually exclusive with --kv-cache-quantization, or are they layered?

Mergeable on current main.

Owner

@waybarrios waybarrios left a comment


Code review

Found 4 issues:


1. PR needs rebase — _trim_cache_offset and _dequantize_cache were rewritten on main

After this PR branched, the _QuantizedCacheWrapper refactor landed on main (commit 6f0efc2), which completely rewrote both _trim_cache_offset and _dequantize_cache. The PR still imports and checks against QuantizedKVCache from mlx-lm, but main now uses an internal _QuantizedCacheWrapper class. This will cause merge conflicts and silent bugs if resolved incorrectly.

What the PR expects:

from mlx_lm.models.cache import QuantizedKVCache
# ...
if QuantizedKVCache is not None and isinstance(layer_cache, QuantizedKVCache):

What main now has:

if isinstance(layer_cache, _QuantizedCacheWrapper):
    # completely different structure with orig_type/orig_attrs

CI also confirms black --check fails on memory_cache.py. A full rebase against current main is needed.

"""Create a cache entry with memory estimation."""
memory = estimate_kv_cache_memory(cache)
return cls(
tokens=tuple(tokens),
cache=cache,
memory_bytes=memory,
)
def _trim_cache_offset(cache: list[Any], trim_by: int) -> list[Any]:
"""Create shallow copies of KVCache/QuantizedKVCache/TurboQuantKVCache layers
with offset reduced.
This is used when returning a cached KV state to the scheduler so that
the last N positions are "freed" and the model will recompute them on the
next forward pass (preventing duplicate KV entries).
Supports KVCache, QuantizedKVCache, and TurboQuantKVCache.
"""
from mlx_lm.models.cache import KVCache
try:
from mlx_lm.models.cache import QuantizedKVCache
except ImportError:
QuantizedKVCache = None # noqa: N806
try:
from mlx_lm.models.turboquant_cache import TurboQuantKVCache
except ImportError:
TurboQuantKVCache = None # noqa: N806
trimmed: list[Any] = []
for layer_cache in cache:
if QuantizedKVCache is not None and isinstance(layer_cache, QuantizedKVCache):
tc = QuantizedKVCache.__new__(QuantizedKVCache)
tc.keys = layer_cache.keys
tc.values = layer_cache.values
tc.offset = max(layer_cache.offset - trim_by, 0)
tc.group_size = layer_cache.group_size
tc.bits = layer_cache.bits
trimmed.append(tc)
elif TurboQuantKVCache is not None and isinstance(
layer_cache, TurboQuantKVCache
):
# Shallow copy with adjusted offset (do NOT mutate original)
tc = TurboQuantKVCache.__new__(TurboQuantKVCache)
tc.__dict__.update(layer_cache.__dict__)
tc.offset = max(layer_cache.offset - trim_by, 0)
tc._k_deq_buf = None # invalidate decode buffer
tc._v_deq_buf = None


2. Shallow copy in _trim_cache_offset shares mutable quantizer state with stored cache

The TurboQuantKVCache branch uses __dict__.update which shallow-copies all references. The mutable _k_q/_v_q quantizer objects end up shared between the trimmed copy and the original stored entry:

# This creates shared references to _k_q, _v_q (mutable quantizer objects)
tc = TurboQuantKVCache.__new__(TurboQuantKVCache)
tc.__dict__.update(layer_cache.__dict__)  # shallow copy
tc.offset = max(layer_cache.offset - trim_by, 0)
tc._k_deq_buf = None   # only buffers are reset
tc._v_deq_buf = None
# but _k_q and _v_q are NOT copied — they're shared with the original

Later, _dequantize_cache calls layer._ensure_quantizer(...) which mutates quantizer state in-place. Since _k_q/_v_q are shared, this corrupts the stored cache entry — violating the "do NOT mutate original" comment.

Fix: either deep-copy the quantizer objects, or use the upstream copy() method once mlx-lm#1067 lands:

# Option A: deep copy quantizers
import copy
tc._k_q = copy.deepcopy(layer_cache._k_q) if layer_cache._k_q is not None else None
tc._v_q = copy.deepcopy(layer_cache._v_q) if layer_cache._v_q is not None else None

# Option B (preferred): use upstream public API
tc = layer_cache.copy()
tc.offset = max(layer_cache.offset - trim_by, 0)
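The aliasing hazard is easy to reproduce in isolation — the classes below are stand-ins, not the real TurboQuantKVCache, but the `__dict__.update` sharing behaves identically:

```python
import copy

class Quantizer:
    # stand-in for the mutable _k_q/_v_q quantizer objects
    def __init__(self):
        self.codebook = [1, 2, 3]

class Cache:
    def __init__(self):
        self._k_q = Quantizer()
        self.offset = 10

orig = Cache()

# Shallow copy via __dict__.update: the quantizer reference is SHARED
shallow = Cache.__new__(Cache)
shallow.__dict__.update(orig.__dict__)
shallow._k_q.codebook.append(4)        # mutates the original's quantizer too
assert orig._k_q.codebook == [1, 2, 3, 4]

# Deep-copying the quantizer (Option A above) isolates the trimmed copy
deep = Cache.__new__(Cache)
deep.__dict__.update(orig.__dict__)
deep._k_q = copy.deepcopy(orig._k_q)
deep._k_q.codebook.append(5)
assert orig._k_q.codebook == [1, 2, 3, 4]  # original untouched this time
```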



3. _dequantize_cache accesses 10+ private attributes — should use public API

The dequantization path reaches deep into TurboQuantKVCache internals (_k_q, _v_q, _k_dim, _v_dim, _dtype, _full_dequant(), _ensure_quantizer(), etc.). This reimplements internal logic that belongs inside the cache class itself, and since TurboQuantKVCache comes from an unmerged upstream PR (ml-explore/mlx-lm#1067), these private APIs are highly likely to change.

Current approach:

# 10+ private attribute accesses
if layer._k_q is None:
    layer._ensure_quantizer(layer._k_dim, layer._v_dim)
B, H = layer.k_packed.shape[:2]
dtype = layer._dtype if layer._dtype is not None else mx.float16
k_all = layer._full_dequant(
    layer.k_packed, layer.k_norms, layer._k_q,
    layer._k_dim, B, H, layer.offset, dtype,
)

The upstream PR already exposes dequantize() and copy() public methods. This should be:

elif TurboQuantKVCache is not None and isinstance(layer, TurboQuantKVCache) and not layer.empty():
    result.append(layer.dequantize())



4. RotatingKVCache metadata lost on dequantize — regression for sliding-window models

_turbo_quantize_cache only handles plain KVCache. On dequantize, it always reconstructs a plain KVCache:

# _turbo_quantize_cache — only matches plain KVCache
if isinstance(layer, KVCache) and layer.keys is not None:
    compressed.append(layer.to_turbo_quantized(bits=bits))

# _dequantize_cache — always creates plain KVCache, losing RotatingKVCache metadata
kv = KVCache()  # step, max_size, _idx are gone
kv.update_and_fetch(k_all, v_all)

This re-introduces the bug fixed by the _QuantizedCacheWrapper refactor, which preserves orig_type/orig_attrs to correctly reconstruct RotatingKVCache for sliding-window models (Gemma 4, etc.). The TurboQuant path needs the same preservation pattern:

# Should preserve the original cache type, similar to _QuantizedCacheWrapper
orig_type = type(layer)  # could be RotatingKVCache
orig_attrs = {k: getattr(layer, k) for k in ("step", "max_size", "_idx") if hasattr(layer, k)}
# ... then reconstruct with orig_type and orig_attrs on dequantize
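A self-contained sketch of that preservation pattern, using a stub in place of the real RotatingKVCache (the wrapper shape and attribute list mirror the review's description of _QuantizedCacheWrapper and are assumptions):

```python
class RotatingKVCacheStub:
    # stand-in for a sliding-window cache with extra metadata
    def __init__(self, max_size=256, step=16):
        self.max_size = max_size
        self.step = step
        self._idx = 0

def wrap(layer, quantized):
    # Record the original type and its metadata alongside the payload
    attrs = {
        k: getattr(layer, k)
        for k in ("step", "max_size", "_idx")
        if hasattr(layer, k)
    }
    return {"quantized": quantized, "orig_type": type(layer), "orig_attrs": attrs}

def unwrap(w):
    # Reconstruct the ORIGINAL cache type, not a plain KVCache
    kv = w["orig_type"].__new__(w["orig_type"])
    kv.__dict__.update(w["orig_attrs"])
    return kv  # caller then restores keys/values from the dequantized tensors

orig = RotatingKVCacheStub(max_size=512, step=32)
wrapped = wrap(orig, quantized=b"...")
rebuilt = unwrap(wrapped)
assert type(rebuilt) is RotatingKVCacheStub and rebuilt.max_size == 512
```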

def _turbo_quantize_cache(cache: list[Any], bits: int = 3) -> list[Any]:
    """Compress KVCache layers with TurboQuant (4.6x at 3-bit).

    Uses PolarQuant: randomized Hadamard rotation + Lloyd-Max codebook
    quantization with fused Metal kernels. See arXiv 2504.19874.
    """
    from mlx_lm.models.cache import KVCache

    compressed = []
    for layer in cache:
        if isinstance(layer, KVCache) and layer.keys is not None:
            compressed.append(layer.to_turbo_quantized(bits=bits))


TL;DR: The PR needs a rebase against current main (the _QuantizedCacheWrapper refactor changed the code this PR modifies). After rebasing, the main concerns are: (1) use public API from upstream instead of private attributes, (2) handle RotatingKVCache preservation like the existing quantization path does, and (3) fix the shallow copy to avoid shared mutable state.

arozanov and others added 4 commits April 10, 2026 21:43
Adds --turbo-kv-bits flag (1-4) to compress stored prefix cache entries
using TurboQuant (arXiv 2504.19874). 3-bit gives 4.6x compression vs FP16,
compared to ~2x from the existing 8-bit quantization.

Integration points:
- memory_cache.py: _turbo_quantize_cache/_dequantize_cache, memory estimation,
  trim support, needs_dequantize property, config validation
- scheduler.py: turbo_kv_bits in SchedulerConfig, propagation to MemoryCacheConfig
- cli.py: --turbo-kv-bits for serve and bench commands

Requires mlx-lm with TurboQuant support (ml-explore/mlx-lm#1067).
@Thump604 Thump604 force-pushed the feature/turboquant-kv-cache branch from 9e909a8 to 871a78c on April 11, 2026 at 02:46
@Thump604
Collaborator

I rebased this PR onto current main and pushed the updated branch.

The follow-up changes address the review points directly:

  • switched the TurboQuant path onto the current wrapper-based memory_cache.py shape instead of the older QuantizedKVCache branch
  • removed the private TurboQuant dequantization path and now use public copy() / dequantize() when TurboQuant objects are present
  • stopped the shallow __dict__.update(...) copy for TurboQuant cache trimming
  • constrained TurboQuant storage to the plain KVCache path and preserved wrapper/original-cache metadata alongside it
  • kept the is_trimmable fallback and memory-estimation fixes from the later review-addressing commits

Validation I ran after the rebase:

  • black --check vllm_mlx/memory_cache.py vllm_mlx/cli.py vllm_mlx/scheduler.py tests/test_kv_cache_quantization.py
  • pytest -q tests/test_memory_cache.py tests/test_kv_cache_quantization.py (65 passed)

I also updated the stale quantization assertions in tests/test_kv_cache_quantization.py so they match the current wrapper-based implementation on main.

@arozanov
Author

Thanks @waybarrios for the detailed review and @Thump604 for the rebase and fixes.

All four issues from the review are addressed in the latest push; this is ready for final review. Waiting on ml-explore/mlx-lm#1067 upstream before this can land.

@Thump604 on the quality question: at 3-bit on Qwen 3 8B we see less than a 0.5 perplexity increase on WikiText-2 and no degradation on MMLU. I can add benchmark numbers to the README if useful.

On your second question: --turbo-kv-bits is mutually exclusive with --kv-cache-quantization — setting one disables the other, with a clear error if both are specified.
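The mutual-exclusion check can be as small as the sketch below (where this validation lives in the CLI is an assumption; the flag names are from this PR):

```python
def check_exclusive(turbo_kv_bits, kv_cache_quantization):
    # Reject conflicting flags up front instead of silently preferring one
    if turbo_kv_bits is not None and kv_cache_quantization is not None:
        raise ValueError(
            "--turbo-kv-bits and --kv-cache-quantization are mutually "
            "exclusive; set only one"
        )

check_exclusive(3, None)   # ok: TurboQuant only
check_exclusive(None, 8)   # ok: standard quantization only
```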

Collaborator

@Thump604 Thump604 left a comment


The rebase fixed the earlier structural issues, but I still see one merge blocker around the upstream dependency boundary.

Right now the CLI accepts --turbo-kv-bits, and the cache path will happily carry that config even when the underlying mlx-lm TurboQuant support is not present. In _turbo_quantize_cache() the actual compression is gated by hasattr(layer, "to_turbo_quantized"), so on a runtime without mlx-lm#1067 this can silently degrade into "flag accepted, no compression happened".

That is a bad failure mode for a user-facing memory/compression flag. I think the PR needs one of these before merge:

  • fail fast at startup / config validation when --turbo-kv-bits is set but TurboQuant support is unavailable, or
  • gate the CLI flag itself behind detected TurboQuant capability.

Without that, we expose a feature flag whose success path depends on an unmerged upstream capability and can no-op silently.
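The fail-fast option might look like the sketch below; the capability probe is passed in explicitly so the check is testable, and the mlx-lm import path is taken from this PR's own code (treat it as an assumption):

```python
def turboquant_available() -> bool:
    # Capability probe; import path follows this PR's _trim_cache_offset
    try:
        from mlx_lm.models.turboquant_cache import TurboQuantKVCache  # noqa: F401
        return True
    except ImportError:
        return False

def validate_turbo_kv_bits(turbo_kv_bits, available):
    """Fail fast at config time instead of silently no-op'ing at runtime."""
    if turbo_kv_bits is None:
        return
    if not 1 <= turbo_kv_bits <= 4:
        raise ValueError(f"--turbo-kv-bits must be 1-4, got {turbo_kv_bits}")
    if not available:
        raise RuntimeError(
            "--turbo-kv-bits requires mlx-lm with TurboQuant support "
            "(ml-explore/mlx-lm#1067); upgrade mlx-lm or drop the flag"
        )

# At startup: validate_turbo_kv_bits(args.turbo_kv_bits, turboquant_available())
```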
