diff --git a/README.md b/README.md
index d204fddff6..de70ceca39 100644
--- a/README.md
+++ b/README.md
@@ -130,6 +130,7 @@ LMDeploy is a toolkit for compressing, deploying, and serving LLM, developed by
Qwen2-MoE (57BA14B)
Qwen2.5 (0.5B - 32B)
Qwen3, Qwen3-MoE
+ Qwen3-Next (80B)
Baichuan (7B)
Baichuan2 (7B-13B)
Code Llama (7B - 34B)
diff --git a/README_ja.md b/README_ja.md
index 75d05390ad..4c7acc4792 100644
--- a/README_ja.md
+++ b/README_ja.md
@@ -117,6 +117,7 @@ LMDeploy TurboMindエンジンは卓越した推論能力を持ち、さまざ
Qwen2-MoE (57BA14B)
Qwen2.5 (0.5B - 32B)
Qwen3, Qwen3-MoE
+ Qwen3-Next (80B)
Baichuan (7B)
Baichuan2 (7B-13B)
Code Llama (7B - 34B)
diff --git a/README_zh-CN.md b/README_zh-CN.md
index f6f10a5b42..aa8f6183b4 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -131,6 +131,7 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力,在各种规模的模型
Qwen2-MoE (57BA14B)
Qwen2.5 (0.5B - 32B)
Qwen3, Qwen3-MoE
+ Qwen3-Next (80B)
Baichuan (7B)
Baichuan2 (7B-13B)
Code Llama (7B - 34B)
diff --git a/docs/en/supported_models/supported_models.md b/docs/en/supported_models/supported_models.md
index aa28854d8a..6a9b66b73d 100644
--- a/docs/en/supported_models/supported_models.md
+++ b/docs/en/supported_models/supported_models.md
@@ -85,6 +85,7 @@ The following tables detail the models supported by LMDeploy's TurboMind engine
| QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
| Qwen3 | 0.6B - 235B | LLM | Yes | Yes | Yes\* | - | Yes\* |
+| Qwen3-Next | 80B | LLM | Yes | No | No | No | No |
| QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | Yes |
| QWen2.5-VL | 3B - 72B | MLLM | Yes | No | No | No | No |
| DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No |
diff --git a/docs/zh_cn/supported_models/supported_models.md b/docs/zh_cn/supported_models/supported_models.md
index 8e9e3fef20..7e9cfa648c 100644
--- a/docs/zh_cn/supported_models/supported_models.md
+++ b/docs/zh_cn/supported_models/supported_models.md
@@ -85,6 +85,7 @@
| QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
| Qwen3 | 0.6B - 235B | LLM | Yes | Yes | Yes\* | - | Yes |
+| Qwen3-Next | 80B | LLM | Yes | No | No | No | No |
| QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | Yes |
| QWen2.5-VL | 3B - 72B | MLLM | Yes | No | No | No | No |
| DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No |
diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py
index deb6c66bfd..88a4f07098 100644
--- a/lmdeploy/pytorch/backends/cuda/graph_runner.py
+++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py
@@ -91,6 +91,16 @@ def __init__(
self.pool = pool
self._graph: torch.cuda.CUDAGraph = None
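+    # NOTE: output buffer handling is factored into make_output_buffers/slice_output
+    # so the capture-time warmup output and the replayed graph output share one
+    # slicing path; capture() returns the sliced warmup output so callers can avoid
+    # replaying the graph right after capture.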
+ def make_output_buffers(self, output):
+ """Make output buffers."""
+ output_buffers = dict(logits=output)
+ return output_buffers
+
+ def slice_output(self, output_buffers: Dict[str, Any], inputs: Dict[str, Any]):
+ """Slice output."""
+ num_tokens = inputs['input_ids'].size(-1)
+ return output_buffers['logits'][:, :num_tokens]
+
@record_function('capture_cudagraph')
def capture(self, **kwargs):
"""Capture graph."""
@@ -102,7 +112,8 @@ def capture(self, **kwargs):
current_stream = torch.cuda.current_stream()
# warmup
- self.model(**padded_kwargs)
+ warmup_output = self.model(**padded_kwargs)
+ warmup_buffers = self.make_output_buffers(warmup_output)
self._graph = torch.cuda.CUDAGraph()
# unsafe kernel call in other thread might invalid the capture
@@ -110,21 +121,22 @@ def capture(self, **kwargs):
with torch.cuda.graph(self._graph, pool=self.pool, stream=current_stream, capture_error_mode='thread_local'):
output = self.model(**padded_kwargs)
- output_buffers = dict(logits=output)
+ output_buffers = self.make_output_buffers(output)
self.meta.output_buffers = output_buffers
+ output = self.slice_output(warmup_buffers, kwargs)
return output
@record_function('forward_cudagraph')
def forward(self, **kwargs):
"""forward."""
- num_tokens = kwargs['input_ids'].size(-1)
assert self._graph is not None
self.model.fill_buffers_cudagraph(self.meta, **kwargs)
context = self.ctx_mgr.current_context()
self.model.update_context_cudagraph(self.meta, context)
self._graph.replay()
- output = self.meta.output_buffers['logits'][:, :num_tokens]
+ output_buffers = self.meta.output_buffers
+ output = self.slice_output(output_buffers, kwargs)
return output
def __del__(self):
@@ -223,12 +235,14 @@ def __call__(self, **kwargs):
pool=self.graph_pool_handle,
model_config=self.model_config,
device=self.device)
- runner.capture(**kwargs)
+ output = runner.capture(**kwargs)
self._runner_map[graph_key] = runner
+            # SSM models update their state during the capture warmup; replaying the
+            # graph here would trigger an unexpected extra state update.
+ return output
else:
runner = self._runner_map[graph_key]
- output = runner.forward(**kwargs)
- return output
+ output = runner.forward(**kwargs)
+ return output
@record_function('prepare_inputs_for_generation')
def prepare_inputs_for_generation(
diff --git a/lmdeploy/pytorch/check_env/model.py b/lmdeploy/pytorch/check_env/model.py
index d1d096dbfe..3bd7f8cea1 100644
--- a/lmdeploy/pytorch/check_env/model.py
+++ b/lmdeploy/pytorch/check_env/model.py
@@ -57,7 +57,13 @@ def check_dtype(self, config):
if not is_bf16_supported(device_type):
logger.warning('Device does not support bfloat16.')
except Exception as e:
- message = (f'Checking failed with error {e}', 'Please send issue to LMDeploy with error logs.')
+            message = f'Checking failed with error {e}. Please open an issue to LMDeploy with the error logs.'
+ self.log_and_exit(e, 'Model', message=message)
+
+ try:
+ model_config.check_env_func(device_type)
+ except Exception as e:
+            message = f'Checking failed with error {e}.'
self.log_and_exit(e, 'Model', message=message)
def check(self):
diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py
index ac3459e045..1a275cbf55 100644
--- a/lmdeploy/pytorch/config.py
+++ b/lmdeploy/pytorch/config.py
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import enum
-from dataclasses import dataclass
-from typing import Any, Dict, List, Literal
+from dataclasses import dataclass, field
+from typing import Any, Callable, Dict, List, Literal, Tuple
import torch
@@ -86,6 +86,8 @@ class CacheConfig:
enable_prefix_caching: bool = False
quant_policy: Literal[0, 4, 8] = 0
device_type: str = 'cuda'
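+    # state caches for SSM models: each states_shapes entry is a (shape, dtype)
+    # pair, copied from ModelConfig.states_shapes by the executor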
+ num_state_caches: int = None
+ states_shapes: List[Tuple] = field(default_factory=list)
# For PD Disaggregation
role: EngineRole = EngineRole.Hybrid
@@ -183,6 +185,10 @@ def override_hf_config(hf_config: Any, hf_overrides: Dict[str, Any]):
_override_hf_config(hf_config, k, v)
+def _default_check_env(device: str):
+ pass
+
+
@dataclass
class ModelConfig:
"""Config of model."""
@@ -208,6 +214,13 @@ class ModelConfig:
dllm_mask_token: int = 0
dllm_block_length: int = None
+    # added for qwen3_next, but could be used by any SSM model
+ states_shapes: List[Tuple[Tuple[int], torch.dtype]] = field(default_factory=list)
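+    # e.g. for Qwen3-Next (see configurations/qwen3_next.py):
+    #   [((num_delta_layers, conv_dim, conv_kernel_size), torch.bfloat16),
+    #    ((num_delta_layers, num_v_heads, head_k_dim, head_v_dim), torch.bfloat16)]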
+
+ # check env for model-device combination
+ check_env_func: Callable = _default_check_env
+
def get_head_size(self):
"""Get head size."""
return self.head_dim
diff --git a/lmdeploy/pytorch/configurations/default.py b/lmdeploy/pytorch/configurations/default.py
index e30ae7c089..ce3aece9df 100644
--- a/lmdeploy/pytorch/configurations/default.py
+++ b/lmdeploy/pytorch/configurations/default.py
@@ -37,6 +37,8 @@ def build(cls, hf_config, model_path: str = None, **kwargs):
eos_token_id=hf_config.eos_token_id,
sliding_window=sliding_window,
head_dim=head_dim,
+ k_head_dim=head_dim,
+ v_head_dim=head_dim,
vocab_size=hf_config.vocab_size,
llm_config=hf_config,
)
diff --git a/lmdeploy/pytorch/configurations/qwen3_next.py b/lmdeploy/pytorch/configurations/qwen3_next.py
new file mode 100644
index 0000000000..e5bf81e80d
--- /dev/null
+++ b/lmdeploy/pytorch/configurations/qwen3_next.py
@@ -0,0 +1,58 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from .builder import AutoModelConfigBuilder
+from .default import DefaultModelConfigBuilder
+
+
+def _check_env_qwen3_next(device: str):
+ """Check env for qwen3 next."""
+ if device != 'cuda':
+ return
+
+ # check cuda
+ try:
+ import causal_conv1d # noqa: F401
+ except ImportError:
+        raise ImportError('Qwen3-Next CUDA support requires https://github.com/Dao-AILab/causal-conv1d.')
+
+ try:
+ import fla # noqa: F401
+ except ImportError:
+        raise ImportError('Qwen3-Next CUDA support requires https://github.com/fla-org/flash-linear-attention.')
+
+
+class Qwen3NextModelConfigBuilder(AutoModelConfigBuilder):
+
+ @classmethod
+ def condition(cls, hf_config):
+ """config."""
+ return hf_config.model_type == 'qwen3_next'
+
+ @classmethod
+ def build(cls, hf_config, model_path: str = None, tp: int = 1, **kwargs):
+ """build."""
+ cfg = DefaultModelConfigBuilder.build(hf_config, model_path, tp=tp, **kwargs)
+
+ # update num layers
+ num_layers = cfg.num_layers
+ num_full_layers = num_layers // hf_config.full_attention_interval
+ num_delta_layers = num_full_layers * (hf_config.full_attention_interval - 1)
+ cfg.num_layers = num_full_layers
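+        # illustrative example (hypothetical values): num_layers=48 with
+        # full_attention_interval=4 gives 12 full-attention layers that keep the
+        # paged KV cache and 36 linear-attention layers that use the state cache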
+
+ # set state shapes
+ head_k_dim = hf_config.linear_key_head_dim
+ head_v_dim = hf_config.linear_value_head_dim
+ num_v_heads = hf_config.linear_num_value_heads // tp
+ num_k_heads = hf_config.linear_num_key_heads // tp
+ key_dim = head_k_dim * num_k_heads
+ value_dim = head_v_dim * num_v_heads
+ conv_dim = key_dim * 2 + value_dim
+ conv_kernel_size = hf_config.linear_conv_kernel_dim
+
+ conv_state_shape = (num_delta_layers, conv_dim, conv_kernel_size)
+ recurrent_state_shape = (num_delta_layers, num_v_heads, head_k_dim, head_v_dim)
+ dtype = torch.bfloat16
+ cfg.states_shapes = [(conv_state_shape, dtype), (recurrent_state_shape, dtype)]
+ cfg.check_env_func = _check_env_qwen3_next
+ return cfg
diff --git a/lmdeploy/pytorch/engine/cache_engine.py b/lmdeploy/pytorch/engine/cache_engine.py
index d8ec198349..96c0cb63c3 100644
--- a/lmdeploy/pytorch/engine/cache_engine.py
+++ b/lmdeploy/pytorch/engine/cache_engine.py
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
# modify from: https://github.com/vllm-project/vllm
import json
+import math
+from dataclasses import dataclass
from typing import Dict, List, Literal, Optional, Tuple
import torch
@@ -20,6 +22,23 @@
logger = get_logger('lmdeploy')
+def round_up(x: int, alignment: int) -> int:
+ """Round up x to the nearest multiple of alignment."""
+ return ((x + alignment - 1) // alignment) * alignment
+
+
+@dataclass
+class CacheDesc:
+ """Cache description."""
+ shape: List[int]
+ dtype: torch.dtype
+ alignment: int = 256
+
+ def __post_init__(self):
+ self.size = math.prod(self.shape) * self.dtype.itemsize
+ self.aligned_size = round_up(self.size, self.alignment)
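+        # e.g. shape=(3, 7), dtype=torch.float32 -> size = 21 * 4 = 84 bytes,
+        # aligned_size = 256 after rounding up to the 256-byte alignment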
+
+
class CacheEngine:
"""Host and Device memory maintainer.
@@ -384,3 +403,77 @@ def get_assignment_batch(mr_key, block_ids, assignment_len, layer_stride, remote
))
""" Metheds for PD Disaggregation End. """
+
+
+class StateCacheEngine:
+ """Cache engine for state cache."""
+
+ def __init__(self, cache_config: CacheConfig):
+ self.cache_config = cache_config
+ self.mem_pool, self._state_caches = self.allocate_caches(num_caches=cache_config.num_state_caches,
+ state_shapes=cache_config.states_shapes,
+ device='cuda')
+
+ @staticmethod
+ def allocate_caches(num_caches: int, state_shapes: List[Tuple[Tuple[int], torch.dtype]], device: torch.device):
+ """Allocate cache implement."""
+
+ if len(state_shapes) == 0 or num_caches == 0:
+ return torch.empty((0, 0), dtype=torch.uint8, device=device), []
+
+ cache_descs = [CacheDesc(shape, dtype) for shape, dtype in state_shapes]
+
+ # get mempool size
+ mem_pool_size = 0
+ for desc in cache_descs:
+ mem_pool_size += desc.aligned_size
+
+ # create pool
+ mem_pool = torch.zeros((num_caches, mem_pool_size), dtype=torch.uint8, device=device)
+
+ # slice caches
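+        # each cache is a zero-copy view into mem_pool: take the first desc.size
+        # bytes of every row, reinterpret them as desc.dtype, reshape to
+        # (num_caches, *desc.shape), then advance by the aligned stride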
+ caches = []
+ remain_pool = mem_pool
+ for desc in cache_descs:
+ cache = remain_pool[:, :desc.size].view(desc.dtype).view((num_caches, *desc.shape))
+ remain_pool = remain_pool[:, desc.aligned_size:]
+ caches.append(cache)
+ return mem_pool, caches
+
+ @staticmethod
+ def get_cache_state_size(state_shapes: List[Tuple[Tuple[int], torch.dtype]]) -> int:
+ """Get the required cache size of the state cache.
+
+ Args:
+ state_shapes (List[Tuple[Tuple[int], torch.dtype]]): The shapes and dtypes of the states.
+
+ Return:
+ int: Required memory size in bytes.
+ """
+ mem_pool, _ = StateCacheEngine.allocate_caches(num_caches=1, state_shapes=state_shapes, device='meta')
+ return mem_pool.numel() * mem_pool.element_size()
+
+ @property
+ def state_caches(self):
+ """State caches."""
+ return self._state_caches
+
+ def init_caches(self, idx: torch.Tensor, mask: torch.Tensor):
+ """Initialize state caches.
+
+        Args:
+            idx: indices of the caches to be initialized.
+            mask: mask indicating which entries of idx should be initialized.
+ """
+ if idx is None:
+ return
+
+ if len(self._state_caches) <= 0:
+ return
+
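+        # zero the selected rows of the pooled state buffer so those sequences
+        # restart from an all-zero conv/recurrent state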
+ num_caches = self.cache_config.num_state_caches
+
+ # get mask of all caches so we can perform inplace mask fill
+ cache_masks = torch.zeros((num_caches, ), dtype=torch.bool, device=idx.device)
+ cache_masks.index_copy_(0, idx, mask)
+ reshaped_mask = cache_masks.view((-1, ) + (1, ) * (self.mem_pool.dim() - 1))
+ self.mem_pool.masked_fill_(reshaped_mask, 0)
diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py
index 26badf256e..2b65230c9d 100644
--- a/lmdeploy/pytorch/engine/engine.py
+++ b/lmdeploy/pytorch/engine/engine.py
@@ -807,6 +807,11 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool):
vision_model_inputs = self._create_vision_model_inputs(messages, model_inputs)
model_inputs.vision_inputs = vision_model_inputs
+ # ssm
+ if len(self.cache_config.states_shapes) > 0:
+ state_offsets = torch.tensor([msg.logical_state for msg in messages])
+ model_inputs.state_offsets = state_offsets
+
return model_inputs
def update_running_migration(self, running: SeqList, next_token_ids: np.ndarray, stopped: torch.Tensor,
diff --git a/lmdeploy/pytorch/engine/executor/base.py b/lmdeploy/pytorch/engine/executor/base.py
index 9e50843a80..ba8333520b 100644
--- a/lmdeploy/pytorch/engine/executor/base.py
+++ b/lmdeploy/pytorch/engine/executor/base.py
@@ -135,7 +135,7 @@ def _get_runtime_size(self, num_free_gpu_mem: int, cache_block_size: int, vocal_
# estimate runtime mem size
runtime_cache_size = int((max_prefill_token_num + max_batches * 2) * vocal_size * 2)
num_available = (num_free_gpu_mem - runtime_cache_size) * cache_max_entry_count
- if int(num_available) // cache_block_size >= 16:
+ if cache_block_size == 0 or int(num_available) // cache_block_size >= 16:
break
max_prefill_token_num = max_prefill_token_num // 2
return runtime_cache_size, max_prefill_token_num
@@ -153,16 +153,48 @@ def _adjust_block_size(self):
f'Update `block_size={self.cache_config.block_size}` for large `head_dim={self.model_config.k_head_dim}`.' # noqa
)
+ def _get_state_cache_mem(self):
+ """Get state cache mem usage."""
+ cache_config = self.cache_config
+ if len(cache_config.states_shapes) == 0:
+ return 0
+
+ from lmdeploy.pytorch.engine.cache_engine import StateCacheEngine
+
+ num_state_caches = cache_config.num_state_caches
+ if num_state_caches is None:
+ # add more caches for eviction
+            # TODO: share memory between the state cache and the paged KV cache
+ num_state_caches = int(cache_config.max_batches + 8)
+ cache_config.num_state_caches = num_state_caches
+
+ mems = StateCacheEngine.get_cache_state_size(cache_config.states_shapes)
+ mems *= num_state_caches
+
+ if cache_config.enable_prefix_caching:
+ cache_config.enable_prefix_caching = False
+            logger.warning('Prefix caching is not supported for state space models; disabling it.')
+
+ return mems
+
def update_configs(self):
"""Update cache config."""
self._adjust_block_size()
cache_config = self.cache_config
model_config = self.model_config
+ cache_config.states_shapes = model_config.states_shapes
+
+ # get free mems
free_mems = self.gather_free_mem()
free_mem = min(free_mems)
logger.debug(f'minimal free gpu memory: {free_mem >> 20} mb')
- vocal_size = self.model_config.vocab_size
+ # get state cache size
+ state_cache_mem = self._get_state_cache_mem()
+ free_mem = free_mem - state_cache_mem
+        assert free_mem > 0, 'Not enough GPU memory for the state cache. Please reduce max_batch_size.'
+
+ vocal_size = self.model_config.vocab_size
tp = self.dist_config.attn_config.tp
cache_block_size = CacheEngine.get_cache_block_size(cache_config.block_size, model_config, tp,
cache_config.quant_policy)
diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py
index 5dd6b561cb..f1df38e43c 100644
--- a/lmdeploy/pytorch/engine/model_agent.py
+++ b/lmdeploy/pytorch/engine/model_agent.py
@@ -29,7 +29,7 @@
from ..strategies.base.model_agent import ExtraInputs, ExtraOutputs, StoppingCriteria
from ..utils import get_gpu_memory
from ..weight_loader.model_weight_loader import ModelWeightLoader, load_model_weights
-from .cache_engine import CacheEngine
+from .cache_engine import CacheEngine, StateCacheEngine
from .guided_process import GuidedDecodingManager
from .logits_process import FusedLogitsProcessor, SamplingInputs
@@ -222,6 +222,7 @@ def model_forward(
model: torch.nn.Module,
inputs: ModelInputs,
cache_engine: CacheEngine,
+ state_cache_engine: StateCacheEngine,
stream: torch.cuda.Stream = None,
):
"""Perform model forward."""
@@ -233,6 +234,7 @@ def model_forward(
inputs=inputs,
model_config=cache_engine.model_config,
kv_caches=cache_engine.gpu_cache,
+ state_caches=state_cache_engine.state_caches,
kv_quant_policy=cache_engine.cache_config.quant_policy,
)
with ctx_mgr.context(context):
@@ -349,6 +351,7 @@ def __init__(self,
self.patched_model = None
self.cache_engine = None
+ self.state_cache_engine = None
self.profiler: AgentProfiler = None
try:
self.guided_decoding_manager = GuidedDecodingManager(self.tokenizer, model_config.vocab_size)
@@ -394,7 +397,10 @@ def get_free_mem(self):
def warmup(self):
"""warmup."""
- # TODO: disable for now, do not remove the comments.
+ from lmdeploy.pytorch.envs import skip_warmup
+ if skip_warmup:
+ return
+
with self.all_context():
max_batches = self.cache_config.max_batches
num_tokens = max_batches
@@ -689,6 +695,10 @@ async def __prepare_dp():
logger.debug(f' rank[{rank}]: all inputs are dummy, skip forward.')
return
+ if not is_decoding:
+            # init state cache for sequences running their first prefill
+            # TODO: confirm whether this re-initialization is strictly necessary
+ self.state_cache_engine.init_caches(inputs.state_offsets, inputs.history_lengths == 0)
cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map)
for idx in range(loop_count):
# inference
@@ -968,12 +978,14 @@ def build_cache_engine(self):
tp_rank=self.tp_rank,
world_size=tp,
cache_stream=self.cache_stream)
+ self.state_cache_engine = StateCacheEngine(self.cache_config)
def _forward_impl(self, inputs: ModelInputs):
output = model_forward(
self.patched_model,
inputs,
self.cache_engine,
+ state_cache_engine=self.state_cache_engine,
stream=self.stream,
)
return output
diff --git a/lmdeploy/pytorch/envs.py b/lmdeploy/pytorch/envs.py
index d5d08362d9..f36aeced3e 100644
--- a/lmdeploy/pytorch/envs.py
+++ b/lmdeploy/pytorch/envs.py
@@ -131,6 +131,9 @@ def _patched_get_env(
os.getenv('DG_JIT_DEBUG', '0')
os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', '0')
+ # model agent
+ skip_warmup = env_to_bool('LMD_SKIP_WARMUP', False)
+
def get_all_envs():
"""Get all environment variables."""
diff --git a/lmdeploy/pytorch/kernels/cuda/pagedattention.py b/lmdeploy/pytorch/kernels/cuda/pagedattention.py
index a4941c36af..0c61e77d4a 100644
--- a/lmdeploy/pytorch/kernels/cuda/pagedattention.py
+++ b/lmdeploy/pytorch/kernels/cuda/pagedattention.py
@@ -506,19 +506,27 @@ def _kernel_meta_sm8x(BLOCK_DMODEL: int, BLOCK_H: int):
def _kernel_meta_sm9x(BLOCK_DMODEL: int, BLOCK_H: int):
"""Kernel meta default."""
- return _kernel_meta_default(BLOCK_DMODEL, BLOCK_H)
+ num_warps = 4
+ if BLOCK_DMODEL * BLOCK_H > 4096:
+ num_stages = 2
+ else:
+ num_stages = 3
+ return num_warps, num_stages
-def _get_split_k(device_idx: int, head_grid: int, batch_size: int):
+def _get_split_k(device_idx: int, head_grid: int, batch_size: int, num_warps: int):
"""Get split k."""
props = get_device_props(device_idx)
num_sm = props['multi_processor_count']
# estimated occupancy 12.5%
warps_per_sm = props['warps_per_sm'] // 8
+ cta_per_sm = triton.cdiv(warps_per_sm, num_warps)
+ cta_per_device = num_sm * cta_per_sm
- SPLIT_K = triton.cdiv(num_sm * warps_per_sm // head_grid, triton.next_power_of_2(batch_size))
+ SPLIT_K = triton.cdiv(cta_per_device // head_grid, triton.next_power_of_2(batch_size))
SPLIT_K = 1 << (SPLIT_K.bit_length() - 1)
- SPLIT_K = max(min(SPLIT_K, 64), 4)
+ max_split = 1 << (num_sm.bit_length() - 1)
+ SPLIT_K = max(min(SPLIT_K, max_split), 4)
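+    # clamp SPLIT_K to [4, largest power of two <= num_sm] so the split-k reduction
+    # never uses more splits than the device has SMs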
return SPLIT_K
@@ -616,7 +624,14 @@ def _get_block_d(Lk):
TILES_PER_GROUP = triton.cdiv(HEADS_PER_REQ, BLOCK_H)
grid_1 = TILES_PER_GROUP * num_kv_heads
- SPLIT_K = _get_split_k(q.device.index, grid_1, batch)
+ if _nv_cap[0] < 8:
+ num_warps, num_stages = _kernel_meta_default(BLOCK_DMODEL, BLOCK_H)
+ elif _nv_cap[0] < 9:
+ num_warps, num_stages = _kernel_meta_sm8x(BLOCK_DMODEL, BLOCK_H)
+ else:
+ num_warps, num_stages = _kernel_meta_sm9x(BLOCK_DMODEL, BLOCK_H)
+
+ SPLIT_K = _get_split_k(q.device.index, grid_1, batch, num_warps)
if quant_policy != 4:
acc = q.new_empty(num_tokens, head, SPLIT_K, Lv + 2, dtype=torch.float32)
@@ -629,13 +644,6 @@ def _get_block_d(Lk):
batch,
)
- if _nv_cap[0] < 8:
- num_warps, num_stages = _kernel_meta_default(BLOCK_DMODEL, BLOCK_H)
- elif _nv_cap[0] < 9:
- num_warps, num_stages = _kernel_meta_sm8x(BLOCK_DMODEL, BLOCK_H)
- else:
- num_warps, num_stages = _kernel_meta_sm9x(BLOCK_DMODEL, BLOCK_H)
-
if quant_policy > 0:
_fwd_grouped_split_quant_kernel[grid](q,
k,
diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py
index 101ed62546..ec5e6098b8 100644
--- a/lmdeploy/pytorch/messages.py
+++ b/lmdeploy/pytorch/messages.py
@@ -483,6 +483,7 @@ class SchedulerSequence:
num_new_tokens: int = 0
sampling_param: SamplingParam = field(default_factory=SamplingParam)
logical_blocks: LogicalTokenBlocks = field(default_factory=LogicalTokenBlocks)
+ logical_state: int = -1
adapter_name: str = None
arrive_time: float = 0.0
output_start_pos: int = 0
diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py
index 13e35fd1ae..7b3a09de3d 100644
--- a/lmdeploy/pytorch/model_inputs.py
+++ b/lmdeploy/pytorch/model_inputs.py
@@ -144,6 +144,7 @@ class ModelInputs:
model_metas: List[Dict[str, Any]] = None
dp_meta: 'DPMeta' = None
enable_microbatch: bool = False
+ state_offsets: torch.LongTensor = None
def step(self, input_ids: torch.LongTensor, step_seqlens: torch.Tensor = None):
"""Update input ids."""
@@ -257,6 +258,7 @@ def __make_next_vision_inputs(flatten_mms: List, start: int):
model_metas=self.model_metas,
cross_length=cross_length,
history_cross_length=history_cross_length,
+ state_offsets=self.state_offsets,
)
ret.append(inp)
history_cross_length = cross_length
@@ -322,6 +324,10 @@ class StepContext:
dp_meta: DPMeta = None
enable_microbatch: bool = False
+ # states for ssm
+ state_caches: List = None
+ state_offsets: torch.LongTensor = None
+
_outputs: Dict = field(default_factory=dict)
@classmethod
@@ -330,6 +336,7 @@ def new(
inputs: ModelInputs,
model_config: ModelConfig,
kv_caches: List = None,
+ state_caches: List = None,
kv_quant_policy: Literal[0, 4, 8] = 0,
):
"""Build step context.
@@ -389,6 +396,8 @@ def new(
cross_kv_seqlens=cross_kv_seqlens,
dp_meta=inputs.dp_meta,
enable_microbatch=inputs.enable_microbatch,
+ state_caches=state_caches,
+ state_offsets=inputs.state_offsets,
)
ret = get_backend().update_step_context(ret)
@@ -454,6 +463,7 @@ def build_context(
inputs: ModelInputs,
model_config: ModelConfig,
kv_caches: List = None,
+ state_caches: List = None,
kv_quant_policy: Literal[0, 4, 8] = 0,
):
"""Build context."""
@@ -461,6 +471,7 @@ def build_context(
inputs,
model_config,
kv_caches,
+ state_caches,
kv_quant_policy,
)
diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py
index 498e2c6554..181418adf7 100644
--- a/lmdeploy/pytorch/models/module_map.py
+++ b/lmdeploy/pytorch/models/module_map.py
@@ -225,6 +225,11 @@
'GptOssForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gpt_oss.GptOssForCausalLM',
})
+# qwen3 next model
+MODULE_MAP.update({
+ 'Qwen3NextForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_next.Qwen3NextForCausalLM',
+})
+
# SDAR
MODULE_MAP.update({
'SDARForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.sdar.SDARForCausalLM',
diff --git a/lmdeploy/pytorch/models/qwen3_next.py b/lmdeploy/pytorch/models/qwen3_next.py
new file mode 100644
index 0000000000..128afe72b4
--- /dev/null
+++ b/lmdeploy/pytorch/models/qwen3_next.py
@@ -0,0 +1,1085 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers.configuration_utils import PretrainedConfig
+
+import lmdeploy.pytorch.distributed as dist
+from lmdeploy.pytorch.distributed import get_dist_manager, get_tp_world_rank
+from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
+from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
+from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_merged_colwise_linear, build_o_proj, build_qkv_proj,
+ build_rowwise_linear)
+from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe
+from lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader, load_weight
+
+from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin
+
+
+class GatedDeltaMeta:
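+    """Per-forward metadata shared by all linear-attention (gated delta) layers.
+
+    Precomputes the per-token sequence index, the gather indices of the last
+    `conv_kernel_size` tokens of each sequence (used to refresh the conv state),
+    and sanitized state-cache slot ids (invalid -1 entries are redirected to a
+    harmless slot).
+    """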
+
+ def __init__(self, num_tokens: int, conv_kernel_size: int, state_ids: torch.Tensor, attn_metadata: Any):
+ self.num_tokens = num_tokens
+ self.is_decoding = attn_metadata.is_decoding
+ self.cu_seqlens = attn_metadata.cu_seqlens_q
+ device = self.cu_seqlens.device
+
+ # get seq_idx (1, num_tokens)
+ seqlens = attn_metadata.q_seqlens
+ batch_size = seqlens.numel()
+ batch_idx = torch.arange(0, batch_size, dtype=torch.int32, device=device)
+ self.seq_idx = torch.repeat_interleave(batch_idx, seqlens, output_size=num_tokens)[None]
+
+ # conv_idx
+ range_idx = torch.arange(-conv_kernel_size, 0, device=device)
+ self.conv_idx = self.cu_seqlens[1:, None] + range_idx[None]
+ self.conv_idx = self.conv_idx.clamp_min(0)
+
+ # state_ids, fill invalid state with state_ids[0]
+ self.valid_state = state_ids >= 0
+ self.state_ids = torch.where(self.valid_state, state_ids, state_ids[0])
+ self.state_ids = self.state_ids.clamp(0)
+
+
+class CausalConv1dFunc:
+
+ def __init__(self, activation: str = 'silu'):
+ try:
+ import causal_conv1d
+ self.causal_conv1d_fn = causal_conv1d.causal_conv1d_fn
+ self.causal_conv1d_update = causal_conv1d.causal_conv1d_update
+ except Exception:
+ raise RuntimeError(
+ 'causal_conv1d is not installed, please refer to https://github.com/Dao-AILab/causal-conv1d')
+ self.activation = activation
+
+ def conv1d_func(self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, conv_state: torch.Tensor,
+ gated_delta_meta: GatedDeltaMeta):
+ """
+ x: (b, seqlen, dim)
+ seqlen: (b)
+ out: (b, seqlen, dim)
+ conv_state: (b, dim, kernel_size)
+ """
+ seq_idx = gated_delta_meta.seq_idx
+ conv_idx = gated_delta_meta.conv_idx
+
+ assert x.dim() == 3
+ x = x.transpose(-2, -1)
+ if weight.dim() == 3:
+ assert weight.size(1) == 1
+ weight = weight[:, 0]
+
+ # fill conv state
+ batch_size = conv_state.size(0)
+ conv_idx = conv_idx[:, None].expand(-1, x.size(1), -1)
+ torch.gather(x.expand(batch_size, -1, -1), -1, conv_idx, out=conv_state)
+
+ out = self.causal_conv1d_fn(
+ x,
+ weight,
+ bias,
+ seq_idx,
+ return_final_states=False,
+ activation=self.activation,
+ )
+
+ out = out.transpose(-2, -1)
+
+ # store conv_state
+ return out, conv_state
+
+ def conv1d_update(self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, conv_state: torch.Tensor):
+ if weight.dim() == 3:
+ assert weight.size(1) == 1
+ weight = weight[:, 0]
+ out = self.causal_conv1d_update(x[0], conv_state, weight, bias, activation=self.activation)
+ return out[None], conv_state
+
+ def __call__(
+ self,
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor,
+ conv_state: torch.Tensor,
+ gated_delta_meta: GatedDeltaMeta,
+ ):
+ if gated_delta_meta.is_decoding:
+ return self.conv1d_update(x, weight, bias, conv_state)
+ return self.conv1d_func(x, weight, bias, conv_state, gated_delta_meta=gated_delta_meta)
+
+
+class GatedDelta:
+
+ def __init__(self, use_qk_l2norm_in_kernel: bool = True):
+ try:
+ from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
+ self.chunk_gated_delta_rule = chunk_gated_delta_rule
+ self.fused_recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule
+ except Exception:
+ raise RuntimeError(
+ 'fla is not installed, please refer to https://github.com/fla-org/flash-linear-attention')
+ self.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
+
+ def __call__(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ g: torch.Tensor,
+ beta: torch.Tensor,
+ recurrent_state: torch.Tensor,
+ gated_delta_meta: GatedDeltaMeta,
+ ):
+ """call."""
+ is_decoding = gated_delta_meta.is_decoding
+ cu_seqlens = gated_delta_meta.cu_seqlens
+
+ if not is_decoding:
+ core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
+ query,
+ key,
+ value,
+ g=g,
+ beta=beta,
+ initial_state=recurrent_state,
+ output_final_state=True,
+ use_qk_l2norm_in_kernel=self.use_qk_l2norm_in_kernel,
+ cu_seqlens=cu_seqlens,
+ )
+ else:
+ # qkvgb (1, seqlen, ...) -> (seqlen, 1, ...)
+ core_attn_out, last_recurrent_state = self.fused_recurrent_gated_delta_rule(
+ query[0, :, None],
+ key[0, :, None],
+ value[0, :, None],
+ g=g[0, :, None],
+ beta=beta[0, :, None],
+ initial_state=recurrent_state,
+ output_final_state=True,
+ use_qk_l2norm_in_kernel=self.use_qk_l2norm_in_kernel,
+ )
+ # out (seqlen, 1, ...) -> (1, seqlen, ...)
+ core_attn_out = core_attn_out[None, :, 0]
+ return core_attn_out, last_recurrent_state
+
+
+def build_rmsnorm_gated(hidden_size: int, eps=1e-6, **kwargs):
+ from fla.modules import FusedRMSNormGated
+ return FusedRMSNormGated(hidden_size, eps=eps, **kwargs)
+
+
+class CausalConv1d(nn.Module):
+ """Causal conv1d wrapper."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int]],
+ groups: int = 1,
+ bias: bool = True,
+ split=None,
+ device=None,
+ dtype=None,
+ ):
+ super().__init__()
+ tp, rank = get_tp_world_rank()
+ self.tp = tp
+ self.rank = rank
+ in_channels = in_channels // tp
+ out_channels = out_channels // tp
+ groups = groups // tp
+ assert len(split) == 3
+ self.split = split
+
+ weight, bias = self.make_weight(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ groups=groups,
+ bias=bias,
+ device=device,
+ dtype=dtype,
+ )
+
+ self.register_weight(weight, bias)
+ self.causal_conv1d_func = CausalConv1dFunc(activation='silu')
+
+ @staticmethod
+ def make_weight(
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int]],
+ groups: int = 1,
+ bias: bool = True,
+ device=None,
+ dtype=None,
+ ):
+ weight_shape = (out_channels, in_channels // groups,
+ kernel_size if isinstance(kernel_size, int) else kernel_size[0])
+ bias_shape = (out_channels, ) if bias else None
+
+ weight = torch.empty(weight_shape, device=device, dtype=dtype)
+ if bias_shape is not None:
+ bias = torch.empty(bias_shape, device=device, dtype=dtype)
+ else:
+ bias = None
+ return weight, bias
+
+ def register_weight(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
+ self.register_parameter('weight', nn.Parameter(weight))
+ self.weight.weight_loader = self.weight_loader
+ if bias is not None:
+ self.register_parameter('bias', nn.Parameter(bias))
+ self.bias.weight_loader = self.weight_loader
+ else:
+ self.register_parameter('bias', None)
+
+ def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):
+ """Weight loader."""
+ q, k, v = loaded_weight.split(self.split, dim=0)
+ q = q.chunk(self.tp, dim=0)[self.rank]
+ k = k.chunk(self.tp, dim=0)[self.rank]
+ v = v.chunk(self.tp, dim=0)[self.rank]
+ loaded_weight = torch.cat([q, k, v], dim=0)
+ default_weight_loader(param, loaded_weight)
+
+ def forward(self, x: torch.Tensor, conv_state: torch.Tensor, gated_delta_meta: GatedDeltaMeta):
+ """forward."""
+ return self.causal_conv1d_func(x, self.weight, self.bias, conv_state, gated_delta_meta=gated_delta_meta)
+
+
+class Qwen3NextGatedDeltaNet(nn.Module):
+ """Gated deltanet."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ layer_idx: int,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.num_v_heads = config.linear_num_value_heads
+ self.num_k_heads = config.linear_num_key_heads
+ self.head_k_dim = config.linear_key_head_dim
+ self.head_v_dim = config.linear_value_head_dim
+ self.key_dim = self.head_k_dim * self.num_k_heads
+ self.value_dim = self.head_v_dim * self.num_v_heads
+ self.kv_ratio = self.num_v_heads // self.num_k_heads
+
+ self.conv_kernel_size = config.linear_conv_kernel_dim
+ self.layer_idx = layer_idx
+ self.activation = config.hidden_act
+ self.layer_norm_epsilon = config.rms_norm_eps
+
+ # QKV
+ self.conv_dim = self.key_dim * 2 + self.value_dim
+ self.conv1d = CausalConv1d(
+ in_channels=self.conv_dim,
+ out_channels=self.conv_dim,
+ bias=False,
+ kernel_size=self.conv_kernel_size,
+ groups=self.conv_dim,
+ split=[self.key_dim, self.key_dim, self.value_dim],
+ dtype=dtype,
+ device=device,
+ )
+
+ # projection of the input hidden states
+ projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
+ projection_size_ba = self.num_v_heads * 2
+ self.in_proj_qkvz = build_colwise_linear(self.hidden_size,
+ projection_size_qkvz,
+ bias=False,
+ dtype=dtype,
+ device=device,
+ is_tp=True)
+        # dirty patch: the packed qkvz weight needs a custom TP-aware loader
+ self.in_proj_qkvz.weight.weight_loader = self.weight_loader_qkvz
+ self.in_proj_ba = build_colwise_linear(self.hidden_size,
+ projection_size_ba,
+ bias=False,
+ dtype=dtype,
+ device=device,
+ is_tp=True)
+
+ # time step projection (discretization)
+ # instantiate once and copy inv_dt in init_weights of PretrainedModel
+ self.make_params(self.num_v_heads, device=device)
+ self.A_log_exp = None
+
+ self.norm = build_rmsnorm_gated(self.head_v_dim,
+ eps=self.layer_norm_epsilon,
+ activation=self.activation,
+ dtype=dtype,
+ device=device)
+ self.out_proj = build_o_proj(self.value_dim,
+ self.hidden_size,
+ bias=False,
+ dtype=dtype,
+ device=device,
+ is_tp=True)
+
+ self.gated_delta = GatedDelta()
+
+ def get_A_log_exp(self):
+ if self.A_log_exp is None:
+ self.A_log_exp = -self.A_log.float().exp()
+
+ return self.A_log_exp
+
+ def make_params(self, num_v_heads: int, device: torch.device):
+ tp, _ = get_tp_world_rank()
+ num_v_heads = num_v_heads // tp
+ A = torch.empty(num_v_heads, device=device).uniform_(0, 16)
+ dt_bias = torch.empty(num_v_heads, device=device).uniform_(0, 1)
+
+ self.register_parameter('A_log', nn.Parameter(torch.log(A)))
+ self.register_parameter('dt_bias', nn.Parameter(dt_bias))
+ self.A_log.weight_loader = self.weight_loader_a_dt
+ self.dt_bias.weight_loader = self.weight_loader_a_dt
+
+ def weight_loader_qkvz(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):
+ """Weight loader qkvz."""
+ tp, rank = get_tp_world_rank()
+ split_arg_list_qkvz = [
+ self.head_k_dim,
+ self.head_k_dim,
+ (self.kv_ratio * self.head_v_dim),
+ (self.kv_ratio * self.head_v_dim),
+ ]
+ sum_split = sum(split_arg_list_qkvz)
+ loaded_weight = loaded_weight.unflatten(0, (-1, sum_split))
+ q, k, v, z = loaded_weight.split(split_arg_list_qkvz, dim=1)
+ q = q.chunk(tp, dim=0)[rank]
+ k = k.chunk(tp, dim=0)[rank]
+ v = v.chunk(tp, dim=0)[rank]
+ z = z.chunk(tp, dim=0)[rank]
+
+ loaded_weight = torch.cat([q, k, v, z], dim=1)
+ loaded_weight = loaded_weight.flatten(0, 1)
+ default_weight_loader(param, loaded_weight)
+
+ def weight_loader_a_dt(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):
+ """Weight loader."""
+ tp, rank = get_tp_world_rank()
+ loaded_weight = loaded_weight.chunk(tp, dim=0)[rank]
+ default_weight_loader(param, loaded_weight)
+
+ def fix_query_key_value_ordering(self, mixed_qkvz: torch.Tensor, mixed_ba: torch.Tensor):
+ """Derives `query`, `key` and `value` tensors from `mixed_qkvz` and
+ `mixed_ba`."""
+ new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + (
+ -1,
+ 2 * self.head_k_dim + 2 * self.head_v_dim * self.kv_ratio,
+ )
+ new_tensor_shape_ba = mixed_ba.size()[:-1] + (-1, 2 * self.kv_ratio)
+
+ mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz)
+ mixed_ba = mixed_ba.view(*new_tensor_shape_ba)
+ split_arg_list_qkvz = [
+ self.head_k_dim,
+ self.head_k_dim,
+ (self.kv_ratio * self.head_v_dim),
+ (self.kv_ratio * self.head_v_dim),
+ ]
+ split_arg_list_ba = [self.kv_ratio, self.kv_ratio]
+ query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=-1)
+ b, a = torch.split(mixed_ba, split_arg_list_ba, dim=-1)
+ # [..., ng, np/ng * hn] -> [..., np, hn]
+ value = value.reshape(*value.shape[:-2], -1, self.head_v_dim)
+ z = z.reshape(*z.shape[:-2], -1, self.head_v_dim)
+ b = b.reshape(*b.shape[:-2], -1)
+ a = a.reshape(*a.shape[:-2], -1)
+ return query, key, value, z, b, a
+
+ def _load_state(self, past_key_value: Tuple[torch.Tensor, torch.Tensor], gated_delta_meta: GatedDeltaMeta):
+ """Load states from cache."""
+ state_ids = gated_delta_meta.state_ids
+ conv_cache, recurrent_cache = past_key_value[:2]
+
+ return conv_cache.index_select(0, state_ids), recurrent_cache.index_select(0, state_ids)
+
+ def _store_state(self, conv_state: torch.Tensor, recurrent_state: torch.Tensor,
+ past_key_value: Tuple[torch.Tensor, torch.Tensor], gated_delta_meta: GatedDeltaMeta):
+ """Store states to cache."""
+ conv_cache, recurrent_cache = past_key_value[:2]
+ state_ids = gated_delta_meta.state_ids
+ valid_state = gated_delta_meta.valid_state
+
+ # fill invalid state with state[0]
+ conv_dim = conv_state.dim()
+ recurrent_dim = recurrent_state.dim()
+ conv_state = torch.where(valid_state.view(-1, *[1] * (conv_dim - 1)), conv_state, conv_state[:1])
+ recurrent_state = torch.where(valid_state.view(-1, *[1] * (recurrent_dim - 1)), recurrent_state,
+ recurrent_state[:1])
+
+ conv_cache = conv_cache.index_copy_(0, state_ids, conv_state)
+ recurrent_cache = recurrent_cache.index_copy_(0, state_ids, recurrent_state.to(recurrent_cache.dtype))
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ past_key_value: Tuple[torch.Tensor, torch.Tensor],
+ gated_delta_meta: GatedDeltaMeta,
+ ):
+ """forward."""
+
+ # load states
+ conv_state, recurrent_state = self._load_state(past_key_value, gated_delta_meta)
+
+ # inputs proj
+ projected_states_qkvz = self.in_proj_qkvz(hidden_states)
+ projected_states_ba = self.in_proj_ba(hidden_states)
+ query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba)
+ query, key, value = (x.reshape(*x.shape[:-2], -1) for x in (query, key, value))
+
+ mixed_qkv = torch.cat((query, key, value), dim=-1)
+ mixed_qkv, conv_state = self.conv1d(mixed_qkv, conv_state, gated_delta_meta=gated_delta_meta)
+
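+        # recover the TP degree from the sharded width: each rank only holds
+        # 1/tp of the full (2*key_dim + value_dim) conv channels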
+ tp = (self.key_dim * 2 + self.value_dim) // mixed_qkv.size(-1)
+ query, key, value = torch.split(
+ mixed_qkv,
+ [
+ self.key_dim // tp,
+ self.key_dim // tp,
+ self.value_dim // tp,
+ ],
+ dim=-1,
+ )
+ query = query.unflatten(-1, (-1, self.head_k_dim))
+ key = key.unflatten(-1, (-1, self.head_k_dim))
+ value = value.unflatten(-1, (-1, self.head_v_dim))
+
+ beta = b.sigmoid()
+ # If the model is loaded in fp16, without the .float() here, A might be -inf
+ g = self.get_A_log_exp() * F.softplus(a.float() + self.dt_bias)
+ if self.kv_ratio > 1:
+ query = query.repeat_interleave(self.kv_ratio, dim=-2)
+ key = key.repeat_interleave(self.kv_ratio, dim=-2)
+
+ core_attn_out, recurrent_state = self.gated_delta(
+ query,
+ key,
+ value,
+ g=g,
+ beta=beta,
+ recurrent_state=recurrent_state,
+ gated_delta_meta=gated_delta_meta,
+ )
+
+ # store states
+ self._store_state(conv_state, recurrent_state, past_key_value, gated_delta_meta)
+
+ z_shape_og = z.shape
+ core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
+ z = z.reshape(-1, z.shape[-1])
+ core_attn_out = self.norm(core_attn_out, z)
+ core_attn_out = core_attn_out.reshape(z_shape_og)
+ core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1)
+
+ output = self.out_proj(core_attn_out)
+ return output
+
+
+class Qwen3NextAttention(nn.Module):
+ """Rewrite module of Qwen3MoeAttention."""
+
+ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
+ super().__init__()
+ quantization_config = getattr(config, 'quantization_config', None)
+ num_heads = config.num_attention_heads
+ num_key_value_heads = config.num_key_value_heads
+ hidden_size = config.hidden_size
+ head_dim = getattr(config, 'head_dim', hidden_size // num_heads)
+ self.head_dim = head_dim
+ num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)
+
+ # packed qkv
+ # Qwen3 uses 'config.attention_bias = False' for q/k/o projections
+ self.qkv_proj = build_qkv_proj(
+ hidden_size,
+ num_q_heads=num_heads * 2,
+ num_kv_heads=num_key_value_heads,
+ head_size=head_dim,
+ bias=config.attention_bias,
+ quant_config=quantization_config,
+ num_replicate_kv_heads=num_replicate_kv_heads,
+ dtype=dtype,
+ device=device,
+ )
+
+ # rotary embedding
+ self.apply_rotary_pos_emb = ApplyRotaryEmb()
+
+ # attention
+ self.attn_fwd = Attention(
+ num_heads,
+ head_dim,
+ num_kv_heads=num_key_value_heads,
+ v_head_size=head_dim,
+ )
+
+ # o_proj
+ self.o_proj = build_o_proj(num_heads * head_dim,
+ hidden_size,
+ bias=config.attention_bias,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ is_tp=True)
+
+ # q, k norm
+ self.q_norm = RMSNorm(head_dim,
+ config.rms_norm_eps,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device)
+ self.k_norm = RMSNorm(head_dim,
+ config.rms_norm_eps,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attn_metadata: Any = None,
+ ):
+ """Rewrite of LlamaAttention.forward."""
+ # qkv proj
+ qkv_states = self.qkv_proj(hidden_states)
+ # (-1, heads, head_dim)
+ qkv_states = qkv_states.flatten(0, -2)
+ query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)
+ query_states, gate = query_states.view(*query_states.shape[:-2], -1, 2 * self.head_dim).chunk(2, dim=-1)
+
+ # apply q, k norm
+ query_states = self.q_norm(query_states)
+ key_states = self.k_norm(key_states)
+
+ # apply rotary embedding
+ cos, sin = rotary_pos_emb
+ query_states, key_states = self.apply_rotary_pos_emb(
+ query_states,
+ key_states,
+ cos,
+ sin,
+ inplace=True,
+ )
+
+ # attention
+ attn_output = self.attn_fwd(
+ query_states,
+ key_states,
+ value_states,
+ past_key_value[0],
+ past_key_value[1],
+ attn_metadata,
+ k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],
+ v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],
+ inplace=True,
+ )
+ attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)
+ gate = gate.reshape(*hidden_states.shape[:-1], -1)
+ attn_output = attn_output * gate.sigmoid()
+
+ # o proj
+ attn_output = self.o_proj(attn_output)
+ return attn_output
+
+
+class Qwen3NextMLP(nn.Module):
+ """mlp."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ intermediate_size: int = None,
+ dtype: torch.dtype = None,
+ device: torch.device = None,
+ is_tp: bool = True,
+ all_reduce: bool = True):
+ super().__init__()
+ quantization_config = getattr(config, 'quantization_config', None)
+ if intermediate_size is None:
+ intermediate_size = config.intermediate_size
+ # gate up
+ self.gate_up_proj = build_merged_colwise_linear(
+ config.hidden_size,
+ [intermediate_size, intermediate_size],
+ bias=False,
+ dtype=dtype,
+ device=device,
+ quant_config=quantization_config,
+ is_tp=is_tp,
+ )
+
+ # silu and mul
+ self.act_fn = SiluAndMul(inplace=True)
+
+ # down
+ self.down_proj = build_rowwise_linear(intermediate_size,
+ config.hidden_size,
+ bias=False,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ is_tp=is_tp,
+ all_reduce=all_reduce)
+
+ def forward(self, x):
+ """forward."""
+ gate_up = self.gate_up_proj(x)
+ act = self.act_fn(gate_up)
+ return self.down_proj(act)
+
+
+class Qwen3NextSparseMoeBlock(nn.Module):
+ """Moe block."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ layer_idx: int,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ # TODO: zhouxinyu, determine modules_to_not_convert from config file
+ quantization_config = getattr(config, 'quantization_config', None)
+ self.layer_idx = layer_idx
+ self.hidden_dim = config.hidden_size
+ self.ffn_dim = config.moe_intermediate_size
+ self.num_experts = config.num_experts
+ self.top_k = config.num_experts_per_tok
+ self.norm_topk_prob = config.norm_topk_prob
+ self.renormalize = self.norm_topk_prob
+
+ self.gate = build_rowwise_linear(
+ self.hidden_dim,
+ self.num_experts,
+ bias=False,
+ dtype=dtype,
+ device=device,
+ is_tp=False,
+ )
+
+ self.softmax_topk = SoftmaxTopK(self.top_k)
+
+ self.experts = build_fused_moe(
+ self.hidden_dim,
+ self.ffn_dim,
+ self.num_experts,
+ top_k=self.top_k,
+ renormalize=self.renormalize,
+ dtype=dtype,
+ device=device,
+ quant_config=quantization_config,
+ all_reduce=False,
+ layer_idx=layer_idx,
+ )
+
+ self.shared_expert = Qwen3NextMLP(
+ config=config,
+ intermediate_size=config.shared_expert_intermediate_size,
+ dtype=dtype,
+ device=device,
+ is_tp=True,
+ all_reduce=False,
+ )
+ self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False, device=device, dtype=dtype)
+
+ # get all reduce
+ dist_ctx = get_dist_manager().current_context()
+ dp = dist_ctx.dp
+ world_size = dist_ctx.world_size
+ if dp == 1 and world_size > 1:
+ self._all_reduce = True
+ else:
+ self._all_reduce = False
+
+ def forward(self, hidden_states: torch.Tensor):
+ """forward."""
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
+ hidden_states = hidden_states.view(-1, hidden_dim)
+ router_logits = self.gate(hidden_states)
+ topk_weights, topk_ids = self.softmax_topk(router_logits)
+ out_states = self.experts(
+ hidden_states,
+ topk_weights,
+ topk_ids,
+ )
+
+ shared_states = self.shared_expert(hidden_states)
+ shared_states = self.shared_expert_gate(hidden_states).sigmoid() * shared_states
+
+ out_states += shared_states
+ out_states = out_states.reshape(batch_size, sequence_length, -1)
+
+ if self._all_reduce:
+ dist.all_reduce(out_states)
+ return out_states
+
+
+class Qwen3NextDecoderLayer(nn.Module):
+ """Decoder layer."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ layer_idx: int,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.layer_idx = layer_idx
+ quantization_config = getattr(config, 'quantization_config', None)
+
+ # build attention layer
+ self.layer_type = config.layer_types[layer_idx]
+ if self.layer_type == 'linear_attention':
+ self.linear_attn = Qwen3NextGatedDeltaNet(config, layer_idx, dtype=dtype, device=device)
+ elif self.layer_type == 'full_attention':
+ self.self_attn = Qwen3NextAttention(config, dtype=dtype, device=device)
+
+ # build MLP
+ if (layer_idx not in config.mlp_only_layers) and (config.num_experts
+ > 0) and ((layer_idx + 1) % config.decoder_sparse_step == 0):
+ self.mlp = Qwen3NextSparseMoeBlock(config, layer_idx=layer_idx, dtype=dtype, device=device)
+ else:
+ self.mlp = Qwen3NextMLP(config, intermediate_size=config.intermediate_size, dtype=dtype, device=device)
+
+ # build input layer norm
+ self.input_layernorm = RMSNorm(config.hidden_size,
+ config.rms_norm_eps,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device)
+
+ # build attention layer norm
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
+ past_key_value: Optional[List[torch.FloatTensor]],
+ residual: Optional[torch.Tensor],
+ attn_metadata: Any,
+ gated_delta_meta: GatedDeltaMeta,
+ ):
+
+ if residual is None:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ else:
+ hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+ # Self Attention
+ if self.layer_type == 'linear_attention':
+ hidden_states = self.linear_attn(
+ hidden_states=hidden_states,
+ past_key_value=past_key_value,
+ gated_delta_meta=gated_delta_meta,
+ )
+ elif self.layer_type == 'full_attention':
+ hidden_states = self.self_attn(
+ hidden_states=hidden_states,
+ rotary_pos_emb=rotary_pos_emb,
+ past_key_value=past_key_value,
+ attn_metadata=attn_metadata,
+ )
+
+ # Fully Connected
+ hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+ hidden_states = self.mlp(hidden_states)
+
+ outputs = (hidden_states, residual)
+ return outputs
+
+
+class Qwen3NextModel(nn.Module):
+ """Qwen3 next model."""
+
+ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
+ super().__init__()
+ self.config = config
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size,
+ config.hidden_size,
+ self.padding_idx,
+ dtype=dtype,
+ device=device)
+
+ # build all decode layers
+ # TODO: use full config.num_hidden_layers
+ self.layers = nn.ModuleList([
+ Qwen3NextDecoderLayer(config, layer_idx, dtype=dtype, device=device)
+ for layer_idx in range(self.config.num_hidden_layers)
+ ])
+
+ # build norm
+ self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)
+
+ # build rotary embedding
+ self.rotary_emb = build_rotary_embedding_from_config(config)
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor,
+ position_ids: torch.LongTensor,
+ past_key_values: List[torch.FloatTensor],
+ attn_metadata: Any,
+ state_ids: torch.Tensor,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ ):
+ """Rewrite of LlamaModel.forward."""
+
+ # token embedding
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ hidden_states = inputs_embeds
+
+ # rotary embedding
+ cos, sin = self.rotary_emb(hidden_states, position_ids)
+ cos, sin = cos[0], sin[0]
+ rotary_pos_emb = (cos, sin)
+
+ # make seq_idx
+ gated_delta_meta = GatedDeltaMeta(hidden_states.size(1), self.config.linear_conv_kernel_dim, state_ids,
+ attn_metadata)
+
+ # decoding
+ residual = None
+ for idx, decoder_layer in enumerate(self.layers):
+ hidden_states, residual = decoder_layer(
+ hidden_states,
+ rotary_pos_emb=rotary_pos_emb,
+ past_key_value=past_key_values[idx],
+ residual=residual,
+ attn_metadata=attn_metadata,
+ gated_delta_meta=gated_delta_meta,
+ )
+
+ # norm
+ hidden_states, _ = self.norm(hidden_states, residual)
+
+ return hidden_states
+
+ def get_input_embeddings(self):
+ """Get input embeddings."""
+ return self.embed_tokens
+
+
+class Qwen3NextForCausalLM(nn.Module, CudaGraphMixin):
+ """ModelForCausalLM."""
+
+ packed_modules_mapping = {
+ 'qkv_proj': [
+ 'q_proj',
+ 'k_proj',
+ 'v_proj',
+ ],
+ 'gate_up_proj': [
+ 'gate_proj',
+ 'up_proj',
+ ],
+ }
+
+ def __init__(self,
+ config: PretrainedConfig,
+ ctx_mgr: StepContextManager,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ self.ctx_mgr = ctx_mgr
+ # build model
+ self.model = Qwen3NextModel(config, dtype=dtype, device=device)
+ # build lm_head
+ self.lm_head = build_rowwise_linear(config.hidden_size,
+ config.vocab_size,
+ bias=False,
+ dtype=dtype,
+ device=device)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ past_key_values: List[List[torch.Tensor]],
+ attn_metadata: Any = None,
+ inputs_embeds: torch.Tensor = None,
+ state_ids: torch.Tensor = None,
+ **kwargs,
+ ):
+ """Model forward, return logits."""
+ hidden_states = self.model(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ attn_metadata=attn_metadata,
+ inputs_embeds=inputs_embeds,
+ state_ids=state_ids,
+ )
+ return hidden_states
+
+ def get_logits(self, hidden_states: torch.Tensor):
+ """Compute logits of the model output."""
+ return self.lm_head(hidden_states)
+
+ def get_input_embeddings(self):
+ """Get input embeddings."""
+ return self.model.get_input_embeddings()
+
+ def prepare_inputs_for_generation(
+ self,
+ past_key_values: List[List[torch.Tensor]],
+ inputs_embeds: Optional[torch.Tensor] = None,
+ context: StepContext = None,
+ ):
+ """Prepare input."""
+ # get input_ids, position_ids and attention metadatas
+ input_ids = context.input_ids
+ position_ids = context.position_ids
+ attn_metadata = context.attn_metadata
+
+ # make past_key_values
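+        # each state cache has shape (num_state_slots, num_linear_layers, ...);
+        # transpose to per-layer tensors and zip so every linear-attention layer gets
+        # its own (conv_state, recurrent_state) pair, then interleave with the paged
+        # KV caches following config.layer_types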
+ state_caches = list(cache.transpose(0, 1) for cache in context.state_caches)
+ state_caches = list(zip(state_caches[0], state_caches[1]))
+ past_key_values = list(past_key_values)
+ new_past_key_values = []
+ for layer_type in self.config.layer_types:
+ if layer_type == 'linear_attention':
+ new_past_key_values.append(state_caches.pop(0))
+ elif layer_type == 'full_attention':
+ new_past_key_values.append(past_key_values.pop(0))
+
+ # process vision embeddings
+ vision_embeddings = context.input_embeddings
+ vision_embedding_indexing = context.input_embedding_indexing
+ if vision_embeddings is not None and len(vision_embeddings) > 0:
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+ inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)
+
+ # inputs of forward
+ return dict(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ past_key_values=new_past_key_values,
+ attn_metadata=attn_metadata,
+ inputs_embeds=inputs_embeds,
+ state_ids=context.state_offsets,
+ )
+
+ def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
+ """Make cudagraph buffers from forward inputs."""
+ max_batchs = graph_meta.max_batchs
+ device = graph_meta.device
+
+ input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs)
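+        # padded batch entries keep state_ids == -1; GatedDeltaMeta marks them as
+        # invalid so their state reads/writes collapse onto a harmless slot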
+ state_ids = torch.full((max_batchs, ), -1, dtype=torch.long, device=device)
+ input_buffers['state_ids'] = state_ids
+
+ return input_buffers
+
+ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
+ """Fill cudagraph buffers from forward inputs."""
+ input_buffers = graph_meta.input_buffers
+
+ new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs)
+ state_ids = kwargs['state_ids']
+ input_buffers['state_ids'].fill_(-1)
+ input_buffers['state_ids'][:state_ids.size(0)].copy_(state_ids)
+ new_inputs['state_ids'] = input_buffers['state_ids']
+
+ return new_inputs
+
+ def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],
+ expert_params_mapping: List):
+ """Load weight experts."""
+ # load fused weights
+ for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ param = params_dict[name]
+ load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
+ break
+ else:
+ param = params_dict[name]
+ load_weight(param, loaded_weight)
+
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+ """Load weights."""
+
+ def __skip_layers(name):
+ """We might change the number of layers so we can debug the model
+ with less gpus."""
+ import re
+ if '.layers.' not in name:
+ return False
+ matches = re.findall(r'\.layers\.(\d+)\.', name)
+ layer_id = int(matches[0])
+ return layer_id >= self.config.num_hidden_layers
+
+ # modify from vllm
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ('.qkv_proj', '.q_proj', 'q'),
+ ('.qkv_proj', '.k_proj', 'k'),
+ ('.qkv_proj', '.v_proj', 'v'),
+ ('.gate_up_proj', '.gate_proj', 0),
+ ('.gate_up_proj', '.up_proj', 1),
+ ]
+
+ # expert map
+ num_experts = self.config.num_experts
+ expert_params_mapping = []
+ for exp_id in range(num_experts):
+ gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate')
+ up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up')
+ down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')
+ expert_params_mapping += [gate_param, up_param, down_param]
+
+ rms_norm_keys = ['model.norm', '.input_layernorm', '.post_attention_layernorm', '.q_norm', '.k_norm']
+
+ params_dict = dict(self.named_parameters())
+ for name, loaded_weight in weights:
+
+ if __skip_layers(name):
+ continue
+
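+ # skip weights of the mtp (multi-token prediction) module, which is not built here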
+ if 'mtp.' in name:
+ continue
+ if 'rotary_emb.inv_freq' in name:
+ continue
+ if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
+ continue
+ if self.config.tie_word_embeddings and 'lm_head.weight' in name:
+ continue
+
+ name = name.replace('.block_sparse_moe.', '.mlp.')
+ if '.experts' in name and '.shared_expert' not in name:
+ self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping=expert_params_mapping)
+ else:
+ for (param_name, weight_name, shard_id) in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ param = params_dict[name]
+ load_weight(param, loaded_weight, shard_id=shard_id)
+ break
+ else:
+ for rms_norm_key in rms_norm_keys:
+ if rms_norm_key in name and 'weight' in name:
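+ # norm weights are stored zero-centered; add 1 so they match the standard RMSNorm parameterization used here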
+ loaded_weight = loaded_weight + 1
+ break
+ param = params_dict[name]
+ load_weight(param, loaded_weight)
diff --git a/lmdeploy/pytorch/models/utils/cudagraph.py b/lmdeploy/pytorch/models/utils/cudagraph.py
index 065aef97d1..5137f8f132 100644
--- a/lmdeploy/pytorch/models/utils/cudagraph.py
+++ b/lmdeploy/pytorch/models/utils/cudagraph.py
@@ -82,6 +82,10 @@ def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, **kwargs) ->
# create buffer for cross_attn_metadata here
input_buffers['fill_seqlens'] = torch.zeros(max_batches, dtype=torch.int64, device=device)
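+ # one contiguous (2, max_batches + 1) buffer; row 0 is viewed as cu_seqlens_q, row 1 as cu_seqlens_k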
+ input_buffers['cu_seqlens'] = torch.zeros(2, max_batches + 1, dtype=torch.int32, device=device)
+ input_buffers['cu_seqlens_q'] = input_buffers['cu_seqlens'][0]
+ input_buffers['cu_seqlens_k'] = input_buffers['cu_seqlens'][1]
+
return input_buffers
def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, position_ids: Tensor,
@@ -107,7 +111,10 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, p
qkv = torch.stack((q_start_loc, q_seqlens, kv_seqlens))
input_buffers['qkv_lens'].zero_()
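+ # pre-fill q_seqlens with the padded per-slot query length before writing the live batch values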
+ input_buffers['q_seqlens'].fill_(graph_meta.max_tokens // graph_meta.max_batchs)
input_buffers['qkv_lens'][:, :batch_size] = qkv
+ input_buffers['cu_seqlens_q'][1:batch_size + 1] = input_buffers['q_seqlens'][:batch_size].cumsum(0)
+ input_buffers['cu_seqlens_k'][1:batch_size + 1] = input_buffers['kv_seqlens'][:batch_size].cumsum(0)
if inputs_embeds is not None:
emb_size = inputs_embeds.size(-1)
if 'inputs_embeds' not in input_buffers:
@@ -121,6 +128,8 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, p
attn_metadata.q_start_loc = input_buffers['q_start_loc']
attn_metadata.q_seqlens = input_buffers['q_seqlens']
attn_metadata.kv_seqlens = input_buffers['kv_seqlens']
+ attn_metadata.cu_seqlens_q = input_buffers['cu_seqlens_q']
+ attn_metadata.cu_seqlens_k = input_buffers['cu_seqlens_k']
if getattr(self.config, 'use_flash_mla', False) is True:
import flash_mla
tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata(attn_metadata.kv_seqlens.to(torch.int32),
diff --git a/lmdeploy/pytorch/paging/eviction_helper/base_eviction_helper.py b/lmdeploy/pytorch/paging/eviction_helper/base_eviction_helper.py
index 20ef018824..3799d60a42 100644
--- a/lmdeploy/pytorch/paging/eviction_helper/base_eviction_helper.py
+++ b/lmdeploy/pytorch/paging/eviction_helper/base_eviction_helper.py
@@ -14,6 +14,8 @@ def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.block_manager = scheduler.block_manager
self.block_trie = scheduler.block_trie
+ self.state_manager = scheduler.state_manager
+ self.cache_config = scheduler.cache_config
def need_swap_in(self, seq: SchedulerSequence):
"""Sequence need swap in."""
diff --git a/lmdeploy/pytorch/paging/eviction_helper/recompute_eviction_helper.py b/lmdeploy/pytorch/paging/eviction_helper/recompute_eviction_helper.py
index 7dac755666..be0d09a5f9 100644
--- a/lmdeploy/pytorch/paging/eviction_helper/recompute_eviction_helper.py
+++ b/lmdeploy/pytorch/paging/eviction_helper/recompute_eviction_helper.py
@@ -2,13 +2,23 @@
from typing import List
from ...messages import SchedulerSequence
+from ..scheduler import Scheduler
from .base_eviction_helper import BaseEvictionHelper
class RecomputeEvictionHelper(BaseEvictionHelper):
"""Recompute eviction."""
- def evict_for_seq(self, seq: SchedulerSequence, evictable_seqs: List[SchedulerSequence], prealloc_size: int):
+ def __init__(self, scheduler: Scheduler):
+ super().__init__(scheduler)
+
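+ # select the eviction path: block-only for pure attention models, block + state for models with state caches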
+ if len(self.cache_config.states_shapes) == 0:
+ self.evict_for_seq = self._evict_for_seq_default
+ else:
+ self.evict_for_seq = self._evict_for_ssm
+
+ def _evict_for_seq_default(self, seq: SchedulerSequence, evictable_seqs: List[SchedulerSequence],
+ prealloc_size: int):
"""Evict seqs."""
block_manager = self.block_manager
block_trie = self.block_trie
@@ -46,3 +56,51 @@ def evict_for_seq(self, seq: SchedulerSequence, evictable_seqs: List[SchedulerSe
success = True
return success
+
+ def _evict_for_ssm(self, seq: SchedulerSequence, evictable_seqs: List[SchedulerSequence], prealloc_size: int):
+ """Evict seqs."""
+ block_manager = self.block_manager
+ state_manager = self.state_manager
+ block_trie = self.block_trie
+ num_required_blocks = block_manager.num_required_blocks(seq, prealloc_size)
+ has_free_state = state_manager.get_num_free() > 0
+
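+ # an SSM sequence needs both a free state slot and enough free KV blocks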
+ if has_free_state and block_manager.get_num_free_gpu_blocks() >= num_required_blocks:
+ return True
+
+ success = False
+ while len(evictable_seqs) > 0:
+ evict_seq = evictable_seqs.pop(0)
+
+ # skip sequences that hold neither blocks nor a state slot
+ if evict_seq.num_blocks == 0 and evict_seq.logical_state < 0:
+ continue
+
+ # free sequence
+ block_manager.free(evict_seq)
+ evict_seq.set_step(0)
+ state_manager.free(evict_seq)
+ has_free_state = True
+ num_req = (num_required_blocks - block_manager.get_num_free_gpu_blocks())
+ if num_req <= 0:
+ success = True
+ break
+
+ # clear cached prefix
+ block_trie.evict(num_req)
+ num_req = (num_required_blocks - block_manager.get_num_free_gpu_blocks())
+ if num_req <= 0:
+ success = True
+ break
+
+ if not has_free_state:
+ return False
+
+ # also handle the case where evictable_seqs was empty to begin with
+ num_req = num_required_blocks - block_manager.get_num_free_gpu_blocks()
+ if num_req > 0:
+ block_trie.evict(num_req)
+ if num_required_blocks <= block_manager.get_num_free_gpu_blocks():
+ success = True
+
+ return success
diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py
index e19bd18141..838688829f 100644
--- a/lmdeploy/pytorch/paging/scheduler.py
+++ b/lmdeploy/pytorch/paging/scheduler.py
@@ -13,6 +13,7 @@
from ..messages import MessageStatus, SchedulerSequence, SchedulerSession, SequenceManager, SequenceMeta
from .block_manager import build_block_manager
from .block_trie import BlockTrie
+from .state_manager import StateManager
logger = get_logger('lmdeploy')
@@ -51,6 +52,8 @@ def __init__(self,
self.block_manager = build_block_manager(cache_config)
self.block_trie = BlockTrie(self.cache_config, self.block_manager)
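+ # models with state caches (SSM/hybrid layers) also need per-sequence state slots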
+ self.state_manager = StateManager(self.cache_config.num_state_caches)
+ self.is_ssm = len(self.cache_config.states_shapes) > 0
self.eviction_helper = self.build_eviction_helper(self.scheduler_config.eviction_type)
@@ -240,6 +243,8 @@ def _reorder_waiting():
# allocate session memory
self.block_manager.allocate(seq, prealloc_size)
+ if self.is_ssm:
+ self.state_manager.allocate(seq)
_to_running(seq)
seq.record_event(EventType.SCHEDULED)
@@ -335,6 +340,7 @@ def _remove_sequence(self, seq: SchedulerSequence):
seq (SchedulerSequence): sequence to remove
"""
self.block_manager.free(seq)
+ self.state_manager.free(seq)
seq.set_step(0)
seq.session.remove_sequence(seq)
diff --git a/lmdeploy/pytorch/paging/state_manager.py b/lmdeploy/pytorch/paging/state_manager.py
new file mode 100644
index 0000000000..6de882111b
--- /dev/null
+++ b/lmdeploy/pytorch/paging/state_manager.py
@@ -0,0 +1,60 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+
+from lmdeploy.pytorch.messages import SchedulerSequence
+
+
+class StateAllocator:
+ """State allocator."""
+
+ def __init__(self, num_states: int):
+ self.num_states = num_states
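+ # the trailing `_free_count` entries of `_free_states` hold the free slot ids; allocation pops from the front of that suffix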
+ self._free_states = np.arange(num_states, dtype=np.int64)
+ self._free_count = num_states
+
+ def allocate(self):
+ """allocate."""
+ if self.get_num_free() == 0:
+ raise RuntimeError('No free states.')
+ alloc_id = self._free_states[-self._free_count]
+ self._free_count -= 1
+ return alloc_id
+
+ def free(self, state_id: int):
+ """free."""
+ if self._free_count >= self.num_states:
+ raise RuntimeError('All states are free.')
+ self._free_count += 1
+ self._free_states[-self._free_count] = state_id
+
+ def get_num_free(self):
+ return self._free_count
+
+
+class StateManager:
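+ """Allocate and free per-sequence state cache slots."""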
+
+ def __init__(self, num_states: int):
+ if num_states is None:
+ num_states = 1
+ self.allocator = StateAllocator(num_states)
+
+ def is_allocated(self, seq: SchedulerSequence):
+ """Check if a sequence is allocated."""
+ return seq.logical_state >= 0
+
+ def allocate(self, seq: SchedulerSequence):
+ """Allocate states for a sequence."""
+ if self.is_allocated(seq):
+ return None
+ seq.logical_state = self.allocator.allocate()
+
+ def free(self, seq: SchedulerSequence):
+ """Free states for a sequence."""
+ if not self.is_allocated(seq):
+ return None
+ self.allocator.free(seq.logical_state)
+ seq.logical_state = -1
+
+ def get_num_free(self):
+ """Get num free."""
+ return self.allocator.get_num_free()
diff --git a/lmdeploy/pytorch/strategies/base/model_inputs.py b/lmdeploy/pytorch/strategies/base/model_inputs.py
index 795220bd02..e7180bd926 100644
--- a/lmdeploy/pytorch/strategies/base/model_inputs.py
+++ b/lmdeploy/pytorch/strategies/base/model_inputs.py
@@ -26,6 +26,7 @@ def make_dummy_inputs(batch_size: int,
block_offsets = torch.full((batch_size, 1), dummy_block_id, dtype=torch.long, device=device)
num_ignored_history = torch.zeros((batch_size, ), dtype=torch.long, device=device)
local_adapter_ids = torch.zeros((batch_size, ), dtype=torch.long, device=device)
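+ # -1 marks a dummy sequence with no allocated state slot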
+ state_offsets = torch.full((batch_size, ), -1, dtype=torch.long, device=device)
return ModelInputs(
input_ids=input_ids,
@@ -38,6 +39,7 @@ def make_dummy_inputs(batch_size: int,
max_kv_seqlen=max_kv_seqlen,
sum_kv_seqlen=num_tokens,
local_adapter_ids=local_adapter_ids,
+ state_offsets=state_offsets,
)