1 change: 1 addition & 0 deletions README.md
@@ -130,6 +130,7 @@ LMDeploy is a toolkit for compressing, deploying, and serving LLM, developed by
<li>Qwen2-MoE (57BA14B)</li>
<li>Qwen2.5 (0.5B - 32B)</li>
<li>Qwen3, Qwen3-MoE</li>
<li>Qwen3-Next (80B)</li>
<li>Baichuan (7B)</li>
<li>Baichuan2 (7B-13B)</li>
<li>Code Llama (7B - 34B)</li>
1 change: 1 addition & 0 deletions README_ja.md
@@ -117,6 +117,7 @@ LMDeploy TurboMindエンジンは卓越した推論能力を持ち、さまざ
<li>Qwen2-MoE (57BA14B)</li>
<li>Qwen2.5 (0.5B - 32B)</li>
<li>Qwen3, Qwen3-MoE</li>
<li>Qwen3-Next (80B)</li>
<li>Baichuan (7B)</li>
<li>Baichuan2 (7B-13B)</li>
<li>Code Llama (7B - 34B)</li>
1 change: 1 addition & 0 deletions README_zh-CN.md
@@ -131,6 +131,7 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力，在各种规模的模型
<li>Qwen2-MoE (57BA14B)</li>
<li>Qwen2.5 (0.5B - 32B)</li>
<li>Qwen3, Qwen3-MoE</li>
<li>Qwen3-Next (80B)</li>
<li>Baichuan (7B)</li>
<li>Baichuan2 (7B-13B)</li>
<li>Code Llama (7B - 34B)</li>
1 change: 1 addition & 0 deletions docs/en/supported_models/supported_models.md
@@ -85,6 +85,7 @@ The following tables detail the models supported by LMDeploy's TurboMind engine
| QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
| Qwen3 | 0.6B - 235B | LLM | Yes | Yes | Yes\* | - | Yes\* |
| QWen3-Next | 80B | LLM | Yes | No | No | No | No |
| QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | Yes |
| QWen2.5-VL | 3B - 72B | MLLM | Yes | No | No | No | No |
| DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No |
1 change: 1 addition & 0 deletions docs/zh_cn/supported_models/supported_models.md
@@ -85,6 +85,7 @@
| QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
| Qwen3 | 0.6B - 235B | LLM | Yes | Yes | Yes\* | - | Yes |
| QWen3-Next | 80B | LLM | Yes | No | No | No | No |
| QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | Yes |
| QWen2.5-VL | 3B - 72B | MLLM | Yes | No | No | No | No |
| DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No |
8 changes: 7 additions & 1 deletion lmdeploy/pytorch/check_env/model.py
@@ -57,7 +57,13 @@ def check_dtype(self, config):
if not is_bf16_supported(device_type):
logger.warning('Device does not support bfloat16.')
except Exception as e:
message = (f'Checking failed with error {e}', 'Please send issue to LMDeploy with error logs.')
message = f'Checking failed with error {e}. Please report this issue to LMDeploy with the error logs.'
self.log_and_exit(e, 'Model', message=message)

try:
model_config.check_env_func(device_type)
except Exception as e:
message = f'Checking failed with error {e}.'
self.log_and_exit(e, 'Model', message=message)

def check(self):
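
The new `check_env_func` field turns model-device environment checks into a per-model hook: the builder attaches a check function to the model config, and the checker above invokes it and reports any failure. The following is a standalone sketch of that wiring, not lmdeploy code — `SketchModelConfig`, `_check_env_needs_fla`, and `run_env_check` are made-up stand-ins for `ModelConfig`, a builder-supplied check, and the `try/except` block added here.

```python
import importlib.util
from dataclasses import dataclass
from typing import Callable


def _default_check_env(device: str):
    """No-op check for models without special runtime requirements."""
    pass


def _check_env_needs_fla(device: str):
    """Example model-specific check: require an optional package on CUDA."""
    if device == 'cuda' and importlib.util.find_spec('fla') is None:
        raise ImportError('This model requires flash-linear-attention on CUDA.')


@dataclass
class SketchModelConfig:
    """Stand-in for ModelConfig; only the hook field is modeled."""
    model_type: str
    check_env_func: Callable = _default_check_env


def run_env_check(cfg: SketchModelConfig, device_type: str):
    """Mirror of the try/except added to check_env/model.py."""
    try:
        cfg.check_env_func(device_type)
    except Exception as e:
        # the real checker calls self.log_and_exit(e, 'Model', message=...) here
        print(f'Checking failed with error {e}.')


run_env_check(SketchModelConfig('qwen3_next', _check_env_needs_fla), 'cuda')
```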
17 changes: 15 additions & 2 deletions lmdeploy/pytorch/config.py
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import enum
from dataclasses import dataclass
from typing import Any, Dict, List, Literal
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Literal, Tuple

import torch

@@ -86,6 +86,8 @@ class CacheConfig:
enable_prefix_caching: bool = False
quant_policy: Literal[0, 4, 8] = 0
device_type: str = 'cuda'
num_state_caches: int = None
states_shapes: List[Tuple] = field(default_factory=list)

# For PD Disaggregation
role: EngineRole = EngineRole.Hybrid
@@ -183,6 +185,10 @@ def override_hf_config(hf_config: Any, hf_overrides: Dict[str, Any]):
_override_hf_config(hf_config, k, v)


def _default_check_env(device: str):
pass


@dataclass
class ModelConfig:
"""Config of model."""
@@ -208,6 +214,13 @@ class ModelConfig:
dllm_mask_token: int = 0
dllm_block_length: int = None

# added for qwen3_next
# could be used for any SSM model.
states_shapes: List[Tuple[Tuple[int], torch.dtype]] = field(default_factory=list)

# check env for model-device combination
check_env_func: Callable = _default_check_env

def get_head_size(self):
"""Get head size."""
return self.head_dim
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/configurations/default.py
@@ -37,6 +37,8 @@ def build(cls, hf_config, model_path: str = None, **kwargs):
eos_token_id=hf_config.eos_token_id,
sliding_window=sliding_window,
head_dim=head_dim,
k_head_dim=head_dim,
v_head_dim=head_dim,
vocab_size=hf_config.vocab_size,
llm_config=hf_config,
)
58 changes: 58 additions & 0 deletions lmdeploy/pytorch/configurations/qwen3_next.py
@@ -0,0 +1,58 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from .builder import AutoModelConfigBuilder
from .default import DefaultModelConfigBuilder


def _check_env_qwen3_next(device: str):
"""Check env for qwen3 next."""
if device != 'cuda':
return

# check cuda
try:
import causal_conv1d # noqa: F401
except ImportError:
raise ImportError('Qwen3-Next cuda support requires https://github.com/Dao-AILab/causal-conv1d.')

try:
import fla # noqa: F401
except ImportError:
raise ImportError('Qwen3-Next cuda support requires https://github.com/fla-org/flash-linear-attention.')


class Qwen3NextModelConfigBuilder(AutoModelConfigBuilder):

@classmethod
def condition(cls, hf_config):
"""config."""
return hf_config.model_type == 'qwen3_next'

@classmethod
def build(cls, hf_config, model_path: str = None, tp: int = 1, **kwargs):
"""build."""
cfg = DefaultModelConfigBuilder.build(hf_config, model_path, tp=tp, **kwargs)

# update num layers
num_layers = cfg.num_layers
num_full_layers = num_layers // hf_config.full_attention_interval
num_delta_layers = num_full_layers * (hf_config.full_attention_interval - 1)
cfg.num_layers = num_full_layers

# set state shapes
head_k_dim = hf_config.linear_key_head_dim
head_v_dim = hf_config.linear_value_head_dim
num_v_heads = hf_config.linear_num_value_heads // tp
num_k_heads = hf_config.linear_num_key_heads // tp
key_dim = head_k_dim * num_k_heads
value_dim = head_v_dim * num_v_heads
conv_dim = key_dim * 2 + value_dim
conv_kernel_size = hf_config.linear_conv_kernel_dim

conv_state_shape = (num_delta_layers, conv_dim, conv_kernel_size)
recurrent_state_shape = (num_delta_layers, num_v_heads, head_k_dim, head_v_dim)
dtype = torch.bfloat16
cfg.states_shapes = [(conv_state_shape, dtype), (recurrent_state_shape, dtype)]
cfg.check_env_func = _check_env_qwen3_next
return cfg
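
To make the layer split and the per-sequence state footprint concrete, here is the same arithmetic as a worked sketch. The hyperparameter values are illustrative placeholders, not numbers read from a released Qwen3-Next checkpoint config.

```python
import math

import torch

# illustrative hyperparameters (placeholders, not a real checkpoint config)
num_layers = 48
full_attention_interval = 4      # one full-attention layer every 4 layers
linear_num_key_heads = 16
linear_num_value_heads = 32
linear_key_head_dim = 128
linear_value_head_dim = 128
linear_conv_kernel_dim = 4
tp = 1

# layer split: full-attention layers keep the paged KV cache,
# the remaining linear-attention layers use the new state cache
num_full_layers = num_layers // full_attention_interval              # 12
num_delta_layers = num_full_layers * (full_attention_interval - 1)   # 36

num_k_heads = linear_num_key_heads // tp
num_v_heads = linear_num_value_heads // tp
key_dim = linear_key_head_dim * num_k_heads        # 2048
value_dim = linear_value_head_dim * num_v_heads    # 4096
conv_dim = key_dim * 2 + value_dim                 # 8192

conv_state_shape = (num_delta_layers, conv_dim, linear_conv_kernel_dim)
recurrent_state_shape = (num_delta_layers, num_v_heads, linear_key_head_dim, linear_value_head_dim)

dtype = torch.bfloat16
for name, shape in [('conv', conv_state_shape), ('recurrent', recurrent_state_shape)]:
    nbytes = math.prod(shape) * dtype.itemsize
    print(f'{name} state per sequence: {shape}, {nbytes / (1 << 20):.2f} MiB')
# conv state per sequence: (36, 8192, 4), 2.25 MiB
# recurrent state per sequence: (36, 32, 128, 128), 36.00 MiB
```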
93 changes: 93 additions & 0 deletions lmdeploy/pytorch/engine/cache_engine.py
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
# modify from: https://github.com/vllm-project/vllm
import json
import math
from dataclasses import dataclass
from typing import Dict, List, Literal, Optional, Tuple

import torch
@@ -20,6 +22,23 @@
logger = get_logger('lmdeploy')


def round_up(x: int, alignment: int) -> int:
"""Round up x to the nearest multiple of alignment."""
return ((x + alignment - 1) // alignment) * alignment


@dataclass
class CacheDesc:
"""Cache description."""
shape: List[int]
dtype: torch.dtype
alignment: int = 256

def __post_init__(self):
self.size = math.prod(self.shape) * self.dtype.itemsize
self.aligned_size = round_up(self.size, self.alignment)


class CacheEngine:
"""Host and Device memory maintainer.

@@ -384,3 +403,77 @@ def get_assignment_batch(mr_key, block_ids, assignment_len, layer_stride, remote
))

""" Metheds for PD Disaggregation End. """


class StateCacheEngine:
"""Cache engine for state cache."""

def __init__(self, cache_config: CacheConfig):
self.cache_config = cache_config
self.mem_pool, self._state_caches = self.allocate_caches(num_caches=cache_config.num_state_caches,
state_shapes=cache_config.states_shapes,
device='cuda')

@staticmethod
def allocate_caches(num_caches: int, state_shapes: List[Tuple[Tuple[int], torch.dtype]], device: torch.device):
"""Allocate cache implement."""

if len(state_shapes) == 0 or num_caches == 0:
return torch.empty((0, 0), dtype=torch.uint8, device=device), []

cache_descs = [CacheDesc(shape, dtype) for shape, dtype in state_shapes]

# get mempool size
mem_pool_size = 0
for desc in cache_descs:
mem_pool_size += desc.aligned_size

# create pool
mem_pool = torch.zeros((num_caches, mem_pool_size), dtype=torch.uint8, device=device)

# slice caches
caches = []
remain_pool = mem_pool
for desc in cache_descs:
cache = remain_pool[:, :desc.size].view(desc.dtype).view((num_caches, *desc.shape))
remain_pool = remain_pool[:, desc.aligned_size:]
caches.append(cache)
return mem_pool, caches

@staticmethod
def get_cache_state_size(state_shapes: List[Tuple[Tuple[int], torch.dtype]]) -> int:
"""Get the required cache size of the state cache.

Args:
state_shapes (List[Tuple[Tuple[int], torch.dtype]]): The shapes and dtypes of the states.

Return:
int: Required memory size in bytes.
"""
mem_pool, _ = StateCacheEngine.allocate_caches(num_caches=1, state_shapes=state_shapes, device='meta')
return mem_pool.numel() * mem_pool.element_size()

@property
def state_caches(self):
"""State caches."""
return self._state_caches

def init_caches(self, idx: torch.Tensor, mask: torch.Tensor):
"""Initialize state caches.

idx: indices of caches to be initialized.
mask: mask indicating which of the given indices should be initialized.
"""
if idx is None:
return

if len(self._state_caches) <= 0:
return

num_caches = self.cache_config.num_state_caches

# get mask of all caches so we can perform inplace mask fill
cache_masks = torch.zeros((num_caches, ), dtype=torch.bool, device=idx.device)
cache_masks.index_copy_(0, idx, mask)
reshaped_mask = cache_masks.view((-1, ) + (1, ) * (self.mem_pool.dim() - 1))
self.mem_pool.masked_fill_(reshaped_mask, 0)
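
A quick usage sketch of the pooled allocation above, calling the static helpers exactly as defined in this diff. It runs on CPU (assuming this branch of lmdeploy is installed); the state shapes are made up for illustration.

```python
import torch

from lmdeploy.pytorch.engine.cache_engine import CacheDesc, StateCacheEngine

# made-up per-sequence SSM state shapes, in (shape, dtype) form
state_shapes = [
    ((36, 8192, 4), torch.bfloat16),        # conv state
    ((36, 32, 128, 128), torch.bfloat16),   # recurrent state
]

# CacheDesc records the raw size and the 256-byte-aligned size used for slicing
desc = CacheDesc(shape=[36, 8192, 4], dtype=torch.bfloat16)
print(desc.size, desc.aligned_size)

# total bytes needed per cache slot (sum of aligned sizes)
print(StateCacheEngine.get_cache_state_size(state_shapes))

# allocate 4 slots from one flat uint8 pool and view each state as a typed tensor
mem_pool, caches = StateCacheEngine.allocate_caches(num_caches=4,
                                                    state_shapes=state_shapes,
                                                    device='cpu')
print(mem_pool.shape)                     # (4, bytes_per_slot)
print([tuple(c.shape) for c in caches])   # [(4, 36, 8192, 4), (4, 36, 32, 128, 128)]
```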
5 changes: 5 additions & 0 deletions lmdeploy/pytorch/engine/engine.py
@@ -807,6 +807,11 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool):
vision_model_inputs = self._create_vision_model_inputs(messages, model_inputs)
model_inputs.vision_inputs = vision_model_inputs

# ssm
if len(self.cache_config.states_shapes) > 0:
state_offsets = torch.tensor([msg.logical_state for msg in messages])
model_inputs.state_offsets = state_offsets

return model_inputs

def update_running_migration(self, running: SeqList, next_token_ids: np.ndarray, stopped: torch.Tensor,
36 changes: 34 additions & 2 deletions lmdeploy/pytorch/engine/executor/base.py
@@ -135,7 +135,7 @@ def _get_runtime_size(self, num_free_gpu_mem: int, cache_block_size: int, vocal_
# estimate runtime mem size
runtime_cache_size = int((max_prefill_token_num + max_batches * 2) * vocal_size * 2)
num_available = (num_free_gpu_mem - runtime_cache_size) * cache_max_entry_count
if int(num_available) // cache_block_size >= 16:
if cache_block_size == 0 or int(num_available) // cache_block_size >= 16:
break
max_prefill_token_num = max_prefill_token_num // 2
return runtime_cache_size, max_prefill_token_num
@@ -153,16 +153,48 @@ def _adjust_block_size(self):
f'Update `block_size={self.cache_config.block_size}` for large `head_dim={self.model_config.k_head_dim}`.' # noqa
)

def _get_state_cache_mem(self):
"""Get state cache mem usage."""
cache_config = self.cache_config
if len(cache_config.states_shapes) == 0:
return 0

from lmdeploy.pytorch.engine.cache_engine import StateCacheEngine

num_state_caches = cache_config.num_state_caches
if num_state_caches is None:
# add more caches for eviction
# TODO: Share memory between state cache and pageable cache
num_state_caches = int(cache_config.max_batches + 8)
cache_config.num_state_caches = num_state_caches

mems = StateCacheEngine.get_cache_state_size(cache_config.states_shapes)
mems *= num_state_caches

if cache_config.enable_prefix_caching:
cache_config.enable_prefix_caching = False
logger.warning('Prefix caching is not supported for state space models; disabling it.')

return mems

def update_configs(self):
"""Update cache config."""
self._adjust_block_size()
cache_config = self.cache_config
model_config = self.model_config
cache_config.states_shapes = model_config.states_shapes

# get free mems
free_mems = self.gather_free_mem()
free_mem = min(free_mems)
logger.debug(f'minimal free gpu memory: {free_mem >> 20} mb')
vocal_size = self.model_config.vocab_size

# get state cache size
state_cache_mem = self._get_state_cache_mem()
free_mem = free_mem - state_cache_mem
assert free_mem > 0, 'Not enough GPU memory for the state cache. Please reduce max_batch_size.'

vocal_size = self.model_config.vocab_size
tp = self.dist_config.attn_config.tp
cache_block_size = CacheEngine.get_cache_block_size(cache_config.block_size, model_config, tp,
cache_config.quant_policy)
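
As a rough illustration of the budgeting above — the SSM state pool is carved out of free GPU memory before the paged KV cache is sized — here is a minimal standalone sketch with made-up numbers; `state_cache_budget` is not an lmdeploy function.

```python
def state_cache_budget(max_batches: int, per_cache_bytes: int, free_mem: int,
                       num_state_caches: int = None):
    """Mirror the reservation logic: state caches first, KV cache gets the rest."""
    if num_state_caches is None:
        # a few spare slots beyond max_batches, reserved for eviction
        num_state_caches = max_batches + 8
    state_cache_mem = num_state_caches * per_cache_bytes
    remaining_for_kv = free_mem - state_cache_mem
    assert remaining_for_kv > 0, 'Not enough GPU memory for the state cache. Please reduce max_batch_size.'
    return num_state_caches, state_cache_mem, remaining_for_kv


# e.g. 256 concurrent sequences, ~38 MiB of SSM state each, 60 GiB free:
# reserves 264 slots (~9.8 GiB) and leaves ~50 GiB for the paged KV cache.
print(state_cache_budget(max_batches=256, per_cache_bytes=38 << 20, free_mem=60 << 30))
```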