[WIP][HiCache][HA 2/N] feat: Support HiCache warmup for cold start acceleration#20105

Open
alphabetc1 wants to merge 5 commits into sgl-project:main from alphabetc1:feat/hicache_global_warmup

Conversation

@alphabetc1 (Contributor)

Motivation

New sglang instances start with an empty KV cache, so every request recomputes from scratch. Global warmup pre-populates the host-level cache from shared storage on startup, prioritizing pinned prefixes and mitigating the cold-start "cache stampede" problem.

Modifications

  • Extended the storage write path to persist priority and token_ids metadata via record_warmup_metadata.
  • Added a WarmupEntry dataclass and list_warmup_entries to the storage ABC, with a JSONL manifest reference implementation in HiCacheFile.
  • Implemented a background warmup thread in HiCacheController, with queue-based, thread-safe radix tree insertion on the main thread.
  • Added a warmup_ratio config option.
  • Mirrored the changes for the Mamba cache.
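As a rough illustration of the storage-side pieces above, the metadata record and its JSONL append path might look like the following. This is a minimal sketch: field names other than `priority` and `token_ids` are assumptions, not the PR's actual definitions.

```python
import json
import time
from dataclasses import asdict, dataclass, field
from typing import List


@dataclass
class WarmupEntry:
    """Hypothetical shape of one warmup record (names beyond priority/token_ids assumed)."""
    hash_chain: List[str]  # chain of page hashes identifying the cached prefix
    token_ids: List[int]   # tokens needed to rebuild the radix key on warmup
    priority: int = 0      # e.g. 1 for pinned prefixes, 0 otherwise
    timestamp: float = field(default_factory=time.time)


def record_warmup_metadata(manifest_path: str, entry: WarmupEntry) -> None:
    """Append one warmup record to a JSONL manifest (one JSON object per line)."""
    with open(manifest_path, "a") as f:
        f.write(json.dumps(asdict(entry)) + "\n")
```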

Accuracy Tests

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@github-actions github-actions bot added the hicache Hierarchical Caching for SGLang label Mar 7, 2026
@gemini-code-assist (Contributor)

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a global warmup mechanism for HiCache, designed to pre-populate the host-level KV cache from shared storage during startup. This enhancement aims to significantly accelerate cold start performance by reducing the need for recomputation, particularly for frequently accessed or 'pinned' prefixes, thereby mitigating the 'cache stampede' problem.

Highlights

  • Warmup Metadata Persistence: The storage write path was extended to save priority and token_ids metadata, enabling the system to remember which cache entries are important for warmup.
  • Warmup Entry Management: A new WarmupEntry dataclass and corresponding list_warmup_entries method were added to the storage abstraction, with a concrete JSONL manifest implementation for HiCacheFile.
  • Background Warmup Thread: A dedicated background thread was implemented in HiCacheController to query storage for warmup entries and load them.
  • Thread-Safe Radix Tree Insertion: Loaded warmup entries are inserted into the radix tree on the main thread using a queue-based, thread-safe mechanism.
  • Configurable Warmup Ratio: A warmup_ratio configuration parameter was introduced to control the proportion of the host KV pool size allocated for warmup.
  • Mamba Cache Integration: The warmup functionality was mirrored and integrated into the Mamba cache implementation.
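The queue-based hand-off in the highlights above can be sketched roughly as follows. The tuple layout `(token_ids, hash_chain, host_indices, priority)` follows the review snippets later in this thread; `insert_into_radix_tree` is a stand-in for the real `_insert_helper_host`, and the non-blocking drain mirrors the described `_drain_warmup_queue` behavior.

```python
import queue


def warmup_producer(entries, result_queue: queue.Queue) -> None:
    """Background thread: load entries from storage and enqueue them."""
    for entry in entries:
        result_queue.put(entry)


def drain_warmup_queue(result_queue: queue.Queue, insert_into_radix_tree) -> int:
    """Main thread: drain without blocking so scheduling is never stalled."""
    drained = 0
    while True:
        try:
            item = result_queue.get_nowait()
        except queue.Empty:
            break
        insert_into_radix_tree(*item)  # thread-safe: only the main thread mutates the tree
        drained += 1
    return drained
```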


Changelog
  • python/sglang/srt/managers/cache_controller.py
    • The StorageOperation class was updated to include a priority attribute.
    • The write_storage method now accepts a priority argument.
    • The _page_backup method was modified to record warmup metadata, including priority and token_ids, via the storage backend.
    • New methods start_warmup and _warmup_thread_func were added to manage the background warmup process, loading entries from storage and pushing them to a result queue.
  • python/sglang/srt/mem_cache/hi_mamba_radix_cache.py
    • The Queue class was imported from the queue module.
    • A warmup_ratio configuration parameter was added and parsed during initialization.
    • The warmup_from_storage method was introduced to initiate the cache warmup process, creating a queue and starting the cache_controller's warmup thread.
    • The _drain_warmup_queue method was added to process items from the warmup queue and insert them into the radix tree.
    • The _drain_warmup_queue is now called during check_hicache_events to continuously process warmup data.
    • Validation for the warmup_ratio was added to ensure it is a float between 0 and 1.
    • The write_backup_storage method was updated to pass a default priority of 0 to the cache_controller.
  • python/sglang/srt/mem_cache/hicache_storage.py
    • The json, threading, and time modules were imported.
    • A WarmupEntry dataclass was defined to encapsulate warmup metadata.
    • Abstract methods record_warmup_metadata and list_warmup_entries were added to the HiCacheStorage base class.
    • The HiCacheFile implementation now includes a __warmup_manifest__.jsonl file to store warmup metadata.
    • The _append_manifest helper method was added to write warmup records to the manifest file.
    • The record_warmup_metadata method was implemented to append new warmup entries to the manifest.
    • The list_warmup_entries method was implemented to read, deduplicate (by recency), and sort (by priority then timestamp) warmup entries from the manifest.
  • python/sglang/srt/mem_cache/hiradix_cache.py
    • The Queue class was imported from the queue module.
    • A warmup_ratio configuration parameter was added and parsed during initialization.
    • The warmup_from_storage method was introduced to initiate the cache warmup process, creating a queue and starting the cache_controller's warmup thread.
    • The _drain_warmup_queue method was added to process items from the warmup queue and insert them into the radix tree.
    • The _drain_warmup_queue is now called during check_hicache_events to continuously process warmup data.
    • Validation for the warmup_ratio was added to ensure it is a float between 0 and 1.
    • The write_backup_storage method was updated to calculate priority (1 if pinned, 0 otherwise) and pass it to the cache_controller.
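The manifest read path described in the hicache_storage.py changelog entries above, read, dedup by recency, sort by priority then timestamp, could look roughly like this sketch. The sort direction (pinned first, newest first) is an assumption; the PR only states "priority then timestamp".

```python
import json


def list_warmup_entries(manifest_path: str):
    """Sketch: later manifest lines win (dedup by recency), then sort by
    priority and timestamp, highest first. Not the PR's actual implementation."""
    seen = {}
    with open(manifest_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                continue  # tolerate torn or partial writes
            seen[tuple(record["hash_chain"])] = record  # later entries overwrite earlier
    return sorted(
        seen.values(),
        key=lambda r: (r.get("priority", 0), r.get("timestamp", 0.0)),
        reverse=True,
    )
```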
Activity
  • No human activity has been recorded for this pull request yet.



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a HiCache warmup mechanism to accelerate cold starts, a valuable feature implemented with well-structured background threads and queues for I/O operations. While the implementation correctly handles cache entry metadata and high-priority loading, critical security vulnerabilities have been identified. A distributed deadlock can occur during warmup if TP ranks have inconsistent manifests, potentially leading to a server hang. The default use of a shared temporary directory (/tmp/hicache) for storage also poses risks of data leakage and cache poisoning. Additionally, the manifest file's lack of size limits or rotation could be exploited for Denial of Service via memory exhaustion. General improvements are also suggested for the scalability of the warmup manifest file and ensuring warmed-up entry priority informs eviction policies.

Comment on lines +1109 to +1113

```python
torch.distributed.all_reduce(
    t,
    op=torch.distributed.ReduceOp.MIN,
    group=tp_group,
)
```

security-high

The warmup thread performs a torch.distributed.all_reduce operation inside a loop over entries retrieved from the storage backend. Since each rank in a Tensor Parallel (TP) group may have a different set of entries in its local manifest (due to non-deterministic timestamps or previous partial failures), the number of iterations can differ between ranks. This will cause ranks with fewer entries to exit the loop while others are still waiting for an all_reduce, resulting in a permanent deadlock of the warmup process and potentially the entire server during startup.
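One conventional remedy for this class of deadlock (an assumption, not the PR's fix) is to agree on a common iteration count across TP ranks before entering the collective loop, e.g. via `torch.distributed.all_reduce(count, op=ReduceOp.MIN)`, so every rank executes exactly the same number of collectives. The helpers below illustrate the principle with plain Python standing in for the MIN all_reduce; entries beyond the common count can still be fetched lazily on first cache miss.

```python
from typing import Iterable, List


def common_entry_count(local_counts: Iterable[int]) -> int:
    """Stand-in for all_reduce(MIN): every rank ends up with the group minimum."""
    return min(local_counts)


def trim_to_common(entries: List, common: int) -> List:
    """Each rank warms up only the first `common` entries, so no rank is left
    blocked in an all_reduce that its peers never enter."""
    return entries[:common]
```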

Comment on lines +251 to +253

```python
self._manifest_path = os.path.join(
    self.file_path, f"__warmup_manifest__{self.config_suffix}.jsonl"
)
```

security-high

The HiCache storage backend defaults to using /tmp/hicache for storing KV cache data and the warmup manifest. On multi-user systems, the /tmp directory is shared. This allows other users on the same machine to read sensitive KV cache data (which may contain PII from prompts) or poison the cache by modifying the manifest file or injecting malicious cache entries. Using a shared temporary directory for sensitive data without restricted permissions is a significant security risk.
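A common mitigation (sketched here as an assumption, not as the PR's or the backend's actual behavior) is to create the cache directory with owner-only permissions instead of relying on a world-readable path under /tmp:

```python
import os
import tempfile
from typing import Optional


def make_private_cache_dir(path: Optional[str] = None) -> str:
    """Create a cache directory readable only by the current user (mode 0700)."""
    if path is None:
        # mkdtemp already creates the directory with mode 0700 and a random suffix
        return tempfile.mkdtemp(prefix="hicache-")
    os.makedirs(path, mode=0o700, exist_ok=True)
    os.chmod(path, 0o700)  # makedirs mode is masked by umask; enforce explicitly
    return path
```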

Comment on lines +306 to +317

```python
with open(self._manifest_path, "r") as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        try:
            record = json.loads(line)
        except json.JSONDecodeError:
            continue
        chain_key = tuple(record["hash_chain"])
        # Later entries overwrite earlier (dedup by recency)
        seen[chain_key] = record
```

security-medium

The HiCacheFile backend appends records to the JSONL manifest file without size limits or rotation. This can lead to the file growing indefinitely, causing high memory consumption and slow startup times during warmup. Critically, an attacker with write access to the storage directory (defaulting to /tmp/hicache) could exploit this by appending a massive number of entries, leading to excessive memory consumption (OOM) and CPU usage, resulting in a Denial of Service. Consider implementing a mechanism to manage the manifest file's size, such as periodic truncation or rotation, or processing only the most recent portion of the file during warmup to limit memory usage.
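One way to bound the manifest, sketched below under assumptions (the cap value, first-seen ordering, and atomic-rename strategy are not part of the PR), is periodic compaction: rewrite the file keeping only the most recent record per hash_chain, capped at a fixed count.

```python
import json
import os


def compact_manifest(manifest_path: str, max_records: int = 10_000) -> None:
    """Rewrite the JSONL manifest keeping one record per hash_chain, capped."""
    seen = {}
    with open(manifest_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                continue
            seen[tuple(record["hash_chain"])] = record  # most recent record wins
    kept = list(seen.values())[-max_records:]  # cap unique records (first-seen order)
    tmp_path = manifest_path + ".tmp"
    with open(tmp_path, "w") as f:
        for record in kept:
            f.write(json.dumps(record) + "\n")
    os.replace(tmp_path, manifest_path)  # atomic swap on POSIX
```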

Comment on lines +195 to +197

```python
token_ids, hash_chain, host_indices, priority = item
key = RadixKey(token_ids=token_ids)
self._insert_helper_host(self.root_node, key, host_indices, hash_chain)
```

medium

The priority of the warmup entry is retrieved from the queue but is not used when inserting the entry into the radix tree via _insert_helper_host. This priority information should be propagated to the TreeNode to influence priority-based eviction strategies. Consider modifying _insert_helper_host to accept and utilize this priority value.
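A toy illustration of the suggestion above: carry `priority` onto the tree node so eviction can prefer dropping low-priority warmup entries first. `TreeNode` here is a hypothetical stand-in, not sglang's actual class, and the tie-break by last access time is an assumption.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class TreeNode:
    """Toy node carrying the propagated warmup priority."""
    priority: int = 0
    last_access: float = 0.0


def eviction_order(nodes: List[TreeNode]) -> List[TreeNode]:
    """Evict lowest priority first; break ties by least-recently accessed."""
    return sorted(nodes, key=lambda n: (n.priority, n.last_access))
```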

Comment on lines +229 to +231

```python
token_ids, hash_chain, host_indices, priority = item
key = RadixKey(token_ids=token_ids)
self._insert_helper_host(self.root_node, key, host_indices, hash_chain)
```

medium

The priority of the warmup entry is retrieved from the queue but is not used when inserting the entry into the radix tree via _insert_helper_host. This priority information should be propagated to the TreeNode to influence priority-based eviction strategies. Consider modifying _insert_helper_host to accept and utilize this priority value.

@github-actions github-actions bot added the documentation Improvements or additions to documentation label Mar 7, 2026
@alphabetc1 alphabetc1 changed the title [HiCache][HA 3/N] feat: Support HiCache warmup for cold start acceleration [HiCache][HA 2/N] feat: Support HiCache warmup for cold start acceleration Mar 8, 2026
@alphabetc1 alphabetc1 changed the title [HiCache][HA 2/N] feat: Support HiCache warmup for cold start acceleration [WIP][HiCache][HA 2/N] feat: Support HiCache warmup for cold start acceleration Mar 9, 2026
