diff --git a/README.md b/README.md index d204fddff6..de70ceca39 100644 --- a/README.md +++ b/README.md @@ -130,6 +130,7 @@ LMDeploy is a toolkit for compressing, deploying, and serving LLM, developed by
  • Qwen2-MoE (57BA14B)
  • Qwen2.5 (0.5B - 32B)
  • Qwen3, Qwen3-MoE
  • + Qwen3-Next (80B)
  • Baichuan (7B)
  • Baichuan2 (7B-13B)
  • Code Llama (7B - 34B)
diff --git a/README_ja.md b/README_ja.md index 75d05390ad..4c7acc4792 100644 --- a/README_ja.md +++ b/README_ja.md @@ -117,6 +117,7 @@ LMDeploy TurboMindエンジンは卓越した推論能力を持ち、さまざ
  • Qwen2-MoE (57BA14B)
  • Qwen2.5 (0.5B - 32B)
  • Qwen3, Qwen3-MoE
  • + Qwen3-Next (80B)
  • Baichuan (7B)
  • Baichuan2 (7B-13B)
  • Code Llama (7B - 34B)
diff --git a/README_zh-CN.md b/README_zh-CN.md index f6f10a5b42..aa8f6183b4 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -131,6 +131,7 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力,在各种规模的模型
  • Qwen2-MoE (57BA14B)
  • Qwen2.5 (0.5B - 32B)
  • Qwen3, Qwen3-MoE
  • + Qwen3-Next (80B)
  • Baichuan (7B)
  • Baichuan2 (7B-13B)
  • Code Llama (7B - 34B)
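The three README hunks above and the supported-model tables that follow add a Qwen3-Next (80B) entry. Since all of the model code in this patch lives under `lmdeploy/pytorch`, a minimal usage sketch with the PyTorch backend might look as follows; the checkpoint id and `tp` value are assumptions (the diff does not name a checkpoint), and `_check_env_qwen3_next` later in the patch additionally requires the `causal_conv1d` and `fla` packages on CUDA.

```python
# Sketch, not part of the patch: running the newly listed Qwen3-Next model
# with LMDeploy's PyTorch engine. The checkpoint id and tp value are assumptions.
from lmdeploy import PytorchEngineConfig, pipeline

if __name__ == '__main__':
    pipe = pipeline(
        'Qwen/Qwen3-Next-80B-A3B-Instruct',        # assumed HF checkpoint
        backend_config=PytorchEngineConfig(tp=4),  # pick tp to fit the 80B weights
    )
    print(pipe(['Give a one-line summary of linear attention.']))
```

The serving equivalent would be the usual `lmdeploy serve api_server <model> --backend pytorch --tp 4` invocation.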
  • diff --git a/docs/en/supported_models/supported_models.md b/docs/en/supported_models/supported_models.md index aa28854d8a..6a9b66b73d 100644 --- a/docs/en/supported_models/supported_models.md +++ b/docs/en/supported_models/supported_models.md @@ -85,6 +85,7 @@ The following tables detail the models supported by LMDeploy's TurboMind engine | QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes | | Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes | | Qwen3 | 0.6B - 235B | LLM | Yes | Yes | Yes\* | - | Yes\* | +| QWen3-Next | 80B | LLM | Yes | No | No | No | No | | QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | Yes | | QWen2.5-VL | 3B - 72B | MLLM | Yes | No | No | No | No | | DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No | diff --git a/docs/zh_cn/supported_models/supported_models.md b/docs/zh_cn/supported_models/supported_models.md index 8e9e3fef20..7e9cfa648c 100644 --- a/docs/zh_cn/supported_models/supported_models.md +++ b/docs/zh_cn/supported_models/supported_models.md @@ -85,6 +85,7 @@ | QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes | | Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes | | Qwen3 | 0.6B - 235B | LLM | Yes | Yes | Yes\* | - | Yes | +| QWen3-Next | 80B | LLM | Yes | No | No | No | No | | QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | Yes | | QWen2.5-VL | 3B - 72B | MLLM | Yes | No | No | No | No | | DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No | diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index deb6c66bfd..88a4f07098 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -91,6 +91,16 @@ def __init__( self.pool = pool self._graph: torch.cuda.CUDAGraph = None + def make_output_buffers(self, output): + """Make output buffers.""" + output_buffers = dict(logits=output) + return output_buffers + + def slice_output(self, output_buffers: Dict[str, Any], inputs: Dict[str, Any]): + """Slice output.""" + num_tokens = inputs['input_ids'].size(-1) + return output_buffers['logits'][:, :num_tokens] + @record_function('capture_cudagraph') def capture(self, **kwargs): """Capture graph.""" @@ -102,7 +112,8 @@ def capture(self, **kwargs): current_stream = torch.cuda.current_stream() # warmup - self.model(**padded_kwargs) + warmup_output = self.model(**padded_kwargs) + warmup_buffers = self.make_output_buffers(warmup_output) self._graph = torch.cuda.CUDAGraph() # unsafe kernel call in other thread might invalid the capture @@ -110,21 +121,22 @@ def capture(self, **kwargs): with torch.cuda.graph(self._graph, pool=self.pool, stream=current_stream, capture_error_mode='thread_local'): output = self.model(**padded_kwargs) - output_buffers = dict(logits=output) + output_buffers = self.make_output_buffers(output) self.meta.output_buffers = output_buffers + output = self.slice_output(warmup_buffers, kwargs) return output @record_function('forward_cudagraph') def forward(self, **kwargs): """forward.""" - num_tokens = kwargs['input_ids'].size(-1) assert self._graph is not None self.model.fill_buffers_cudagraph(self.meta, **kwargs) context = self.ctx_mgr.current_context() self.model.update_context_cudagraph(self.meta, context) self._graph.replay() - output = self.meta.output_buffers['logits'][:, :num_tokens] + output_buffers = self.meta.output_buffers + output = self.slice_output(output_buffers, kwargs) return output def __del__(self): @@ -223,12 +235,14 @@ def __call__(self, **kwargs): pool=self.graph_pool_handle, 
model_config=self.model_config, device=self.device) - runner.capture(**kwargs) + output = runner.capture(**kwargs) self._runner_map[graph_key] = runner + # SSM would update the state in capture(warmup), replay the graph will leads unexpected state update. + return output else: runner = self._runner_map[graph_key] - output = runner.forward(**kwargs) - return output + output = runner.forward(**kwargs) + return output @record_function('prepare_inputs_for_generation') def prepare_inputs_for_generation( diff --git a/lmdeploy/pytorch/check_env/model.py b/lmdeploy/pytorch/check_env/model.py index d1d096dbfe..3bd7f8cea1 100644 --- a/lmdeploy/pytorch/check_env/model.py +++ b/lmdeploy/pytorch/check_env/model.py @@ -57,7 +57,13 @@ def check_dtype(self, config): if not is_bf16_supported(device_type): logger.warning('Device does not support bfloat16.') except Exception as e: - message = (f'Checking failed with error {e}', 'Please send issue to LMDeploy with error logs.') + message = (f'Checking failed with error {e}. Please send issue to LMDeploy with error logs.') + self.log_and_exit(e, 'Model', message=message) + + try: + model_config.check_env_func(device_type) + except Exception as e: + message = (f'Checking failed with error {e}.') self.log_and_exit(e, 'Model', message=message) def check(self): diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index ac3459e045..1a275cbf55 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import enum -from dataclasses import dataclass -from typing import Any, Dict, List, Literal +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Literal, Tuple import torch @@ -86,6 +86,8 @@ class CacheConfig: enable_prefix_caching: bool = False quant_policy: Literal[0, 4, 8] = 0 device_type: str = 'cuda' + num_state_caches: int = None + states_shapes: List[Tuple] = field(default_factory=list) # For PD Disaggregation role: EngineRole = EngineRole.Hybrid @@ -183,6 +185,10 @@ def override_hf_config(hf_config: Any, hf_overrides: Dict[str, Any]): _override_hf_config(hf_config, k, v) +def _default_check_env(device: str): + pass + + @dataclass class ModelConfig: """Config of model.""" @@ -208,6 +214,13 @@ class ModelConfig: dllm_mask_token: int = 0 dllm_block_length: int = None + # added for qwen3_next + # could used for any SSM model. + states_shapes: List[Tuple[Tuple[int], torch.dtype]] = field(default_factory=list) + + # check env for model-device combination + check_env_func: Callable = _default_check_env + def get_head_size(self): """Get head size.""" return self.head_dim diff --git a/lmdeploy/pytorch/configurations/default.py b/lmdeploy/pytorch/configurations/default.py index e30ae7c089..ce3aece9df 100644 --- a/lmdeploy/pytorch/configurations/default.py +++ b/lmdeploy/pytorch/configurations/default.py @@ -37,6 +37,8 @@ def build(cls, hf_config, model_path: str = None, **kwargs): eos_token_id=hf_config.eos_token_id, sliding_window=sliding_window, head_dim=head_dim, + k_head_dim=head_dim, + v_head_dim=head_dim, vocab_size=hf_config.vocab_size, llm_config=hf_config, ) diff --git a/lmdeploy/pytorch/configurations/qwen3_next.py b/lmdeploy/pytorch/configurations/qwen3_next.py new file mode 100644 index 0000000000..e5bf81e80d --- /dev/null +++ b/lmdeploy/pytorch/configurations/qwen3_next.py @@ -0,0 +1,58 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch + +from .builder import AutoModelConfigBuilder +from .default import DefaultModelConfigBuilder + + +def _check_env_qwen3_next(device: str): + """Check env for qwen3 next.""" + if device != 'cuda': + return + + # check cuda + try: + import causal_conv1d # noqa: F401 + except ImportError: + raise ImportError('Qwen3-Next cuda support requires https://github.com/Dao-AILab/causal-conv1d.') + + try: + import fla # noqa: F401 + except ImportError: + raise ImportError('Qwen3-Next cuda support requires https://github.com/fla-org/flash-linear-attention.') + + +class Qwen3NextModelConfigBuilder(AutoModelConfigBuilder): + + @classmethod + def condition(cls, hf_config): + """config.""" + return hf_config.model_type == 'qwen3_next' + + @classmethod + def build(cls, hf_config, model_path: str = None, tp: int = 1, **kwargs): + """build.""" + cfg = DefaultModelConfigBuilder.build(hf_config, model_path, tp=tp, **kwargs) + + # update num layers + num_layers = cfg.num_layers + num_full_layers = num_layers // hf_config.full_attention_interval + num_delta_layers = num_full_layers * (hf_config.full_attention_interval - 1) + cfg.num_layers = num_full_layers + + # set state shapes + head_k_dim = hf_config.linear_key_head_dim + head_v_dim = hf_config.linear_value_head_dim + num_v_heads = hf_config.linear_num_value_heads // tp + num_k_heads = hf_config.linear_num_key_heads // tp + key_dim = head_k_dim * num_k_heads + value_dim = head_v_dim * num_v_heads + conv_dim = key_dim * 2 + value_dim + conv_kernel_size = hf_config.linear_conv_kernel_dim + + conv_state_shape = (num_delta_layers, conv_dim, conv_kernel_size) + recurrent_state_shape = (num_delta_layers, num_v_heads, head_k_dim, head_v_dim) + dtype = torch.bfloat16 + cfg.states_shapes = [(conv_state_shape, dtype), (recurrent_state_shape, dtype)] + cfg.check_env_func = _check_env_qwen3_next + return cfg diff --git a/lmdeploy/pytorch/engine/cache_engine.py b/lmdeploy/pytorch/engine/cache_engine.py index d8ec198349..96c0cb63c3 100644 --- a/lmdeploy/pytorch/engine/cache_engine.py +++ b/lmdeploy/pytorch/engine/cache_engine.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. # modify from: https://github.com/vllm-project/vllm import json +import math +from dataclasses import dataclass from typing import Dict, List, Literal, Optional, Tuple import torch @@ -20,6 +22,23 @@ logger = get_logger('lmdeploy') +def round_up(x: int, alignment: int) -> int: + """Round up x to the nearest multiple of alignment.""" + return ((x + alignment - 1) // alignment) * alignment + + +@dataclass +class CacheDesc: + """Cache description.""" + shape: List[int] + dtype: torch.dtype + alignment: int = 256 + + def __post_init__(self): + self.size = math.prod(self.shape) * self.dtype.itemsize + self.aligned_size = round_up(self.size, self.alignment) + + class CacheEngine: """Host and Device memory maintainer. @@ -384,3 +403,77 @@ def get_assignment_batch(mr_key, block_ids, assignment_len, layer_stride, remote )) """ Metheds for PD Disaggregation End. 
""" + + +class StateCacheEngine: + """Cache engine for state cache.""" + + def __init__(self, cache_config: CacheConfig): + self.cache_config = cache_config + self.mem_pool, self._state_caches = self.allocate_caches(num_caches=cache_config.num_state_caches, + state_shapes=cache_config.states_shapes, + device='cuda') + + @staticmethod + def allocate_caches(num_caches: int, state_shapes: List[Tuple[Tuple[int], torch.dtype]], device: torch.device): + """Allocate cache implement.""" + + if len(state_shapes) == 0 or num_caches == 0: + return torch.empty((0, 0), dtype=torch.uint8, device=device), [] + + cache_descs = [CacheDesc(shape, dtype) for shape, dtype in state_shapes] + + # get mempool size + mem_pool_size = 0 + for desc in cache_descs: + mem_pool_size += desc.aligned_size + + # create pool + mem_pool = torch.zeros((num_caches, mem_pool_size), dtype=torch.uint8, device=device) + + # slice caches + caches = [] + remain_pool = mem_pool + for desc in cache_descs: + cache = remain_pool[:, :desc.size].view(desc.dtype).view((num_caches, *desc.shape)) + remain_pool = remain_pool[:, desc.aligned_size:] + caches.append(cache) + return mem_pool, caches + + @staticmethod + def get_cache_state_size(state_shapes: List[Tuple[Tuple[int], torch.dtype]]) -> int: + """Get the required cache size of the state cache. + + Args: + state_shapes (List[Tuple[Tuple[int], torch.dtype]]): The shapes and dtypes of the states. + + Return: + int: Required memory size in bytes. + """ + mem_pool, _ = StateCacheEngine.allocate_caches(num_caches=1, state_shapes=state_shapes, device='meta') + return mem_pool.numel() * mem_pool.element_size() + + @property + def state_caches(self): + """State caches.""" + return self._state_caches + + def init_caches(self, idx: torch.Tensor, mask: torch.Tensor): + """Initialize state caches. + + idx: indices of caches to be initialized. + mask: mask to indicate which idx to be initialized. 
+ """ + if idx is None: + return + + if len(self._state_caches) <= 0: + return + + num_caches = self.cache_config.num_state_caches + + # get mask of all caches so we can perform inplace mask fill + cache_masks = torch.zeros((num_caches, ), dtype=torch.bool, device=idx.device) + cache_masks.index_copy_(0, idx, mask) + reshaped_mask = cache_masks.view((-1, ) + (1, ) * (self.mem_pool.dim() - 1)) + self.mem_pool.masked_fill_(reshaped_mask, 0) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 26badf256e..2b65230c9d 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -807,6 +807,11 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool): vision_model_inputs = self._create_vision_model_inputs(messages, model_inputs) model_inputs.vision_inputs = vision_model_inputs + # ssm + if len(self.cache_config.states_shapes) > 0: + state_offsets = torch.tensor([msg.logical_state for msg in messages]) + model_inputs.state_offsets = state_offsets + return model_inputs def update_running_migration(self, running: SeqList, next_token_ids: np.ndarray, stopped: torch.Tensor, diff --git a/lmdeploy/pytorch/engine/executor/base.py b/lmdeploy/pytorch/engine/executor/base.py index 9e50843a80..ba8333520b 100644 --- a/lmdeploy/pytorch/engine/executor/base.py +++ b/lmdeploy/pytorch/engine/executor/base.py @@ -135,7 +135,7 @@ def _get_runtime_size(self, num_free_gpu_mem: int, cache_block_size: int, vocal_ # estimate runtime mem size runtime_cache_size = int((max_prefill_token_num + max_batches * 2) * vocal_size * 2) num_available = (num_free_gpu_mem - runtime_cache_size) * cache_max_entry_count - if int(num_available) // cache_block_size >= 16: + if cache_block_size == 0 or int(num_available) // cache_block_size >= 16: break max_prefill_token_num = max_prefill_token_num // 2 return runtime_cache_size, max_prefill_token_num @@ -153,16 +153,48 @@ def _adjust_block_size(self): f'Update `block_size={self.cache_config.block_size}` for large `head_dim={self.model_config.k_head_dim}`.' # noqa ) + def _get_state_cache_mem(self): + """Get state cache mem usage.""" + cache_config = self.cache_config + if len(cache_config.states_shapes) == 0: + return 0 + + from lmdeploy.pytorch.engine.cache_engine import StateCacheEngine + + num_state_caches = cache_config.num_state_caches + if num_state_caches is None: + # add more caches for eviction + # TODO: Share memory between state cache and pageable cache + num_state_caches = int(cache_config.max_batches + 8) + cache_config.num_state_caches = num_state_caches + + mems = StateCacheEngine.get_cache_state_size(cache_config.states_shapes) + mems *= num_state_caches + + if cache_config.enable_prefix_caching: + cache_config.enable_prefix_caching = False + logger.warning('Prefix caching has not been support for state space model.') + + return mems + def update_configs(self): """Update cache config.""" self._adjust_block_size() cache_config = self.cache_config model_config = self.model_config + cache_config.states_shapes = model_config.states_shapes + + # get free mems free_mems = self.gather_free_mem() free_mem = min(free_mems) logger.debug(f'minimal free gpu memory: {free_mem >> 20} mb') - vocal_size = self.model_config.vocab_size + # get state cache size + state_cache_mem = self._get_state_cache_mem() + free_mem = free_mem - state_cache_mem + assert free_mem > 0, 'No enough gpu memory for state cache. Please reduce max_batch_size.' 
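For reference, the sizing logic around this assertion works out as follows: `StateCacheEngine.get_cache_state_size` sums the 256-byte-aligned `CacheDesc` sizes of all state tensors for one slot, and `_get_state_cache_mem` multiplies that by `num_state_caches` (defaulting to `max_batches + 8`) before the remaining memory is carved into KV-cache blocks. A small arithmetic sketch, using illustrative shapes rather than values taken from a real Qwen3-Next config:

```python
# Hedged sketch of the state-cache budgeting in _get_state_cache_mem /
# StateCacheEngine.get_cache_state_size. The shapes below are illustrative
# assumptions, not numbers read from an actual Qwen3-Next config.
import math

import torch


def round_up(x: int, alignment: int = 256) -> int:
    """Round x up to the next multiple of alignment (mirrors cache_engine.round_up)."""
    return (x + alignment - 1) // alignment * alignment


def state_cache_bytes(states_shapes, num_state_caches: int) -> int:
    """Sum the aligned per-slot sizes (CacheDesc.aligned_size) and scale by slot count."""
    per_slot = sum(round_up(math.prod(shape) * dtype.itemsize) for shape, dtype in states_shapes)
    return per_slot * num_state_caches


# Assumed example: 36 linear-attention layers, conv_dim=8192, kernel=4,
# 32 value heads with 128-dim key/value heads, bf16 states, max_batches=256.
states_shapes = [
    ((36, 8192, 4), torch.bfloat16),       # conv states
    ((36, 32, 128, 128), torch.bfloat16),  # recurrent states
]
num_state_caches = 256 + 8  # max_batches + 8 extra slots kept for eviction
print(f'{state_cache_bytes(states_shapes, num_state_caches) / 2**30:.2f} GiB reserved')
```

Under these assumed shapes each slot costs roughly 38 MiB, so the reservation scales linearly with `max_batches`; that is what the assertion above guards, and why its message asks to reduce `max_batch_size` when it fires.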
+ + vocal_size = self.model_config.vocab_size tp = self.dist_config.attn_config.tp cache_block_size = CacheEngine.get_cache_block_size(cache_config.block_size, model_config, tp, cache_config.quant_policy) diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 5dd6b561cb..f1df38e43c 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -29,7 +29,7 @@ from ..strategies.base.model_agent import ExtraInputs, ExtraOutputs, StoppingCriteria from ..utils import get_gpu_memory from ..weight_loader.model_weight_loader import ModelWeightLoader, load_model_weights -from .cache_engine import CacheEngine +from .cache_engine import CacheEngine, StateCacheEngine from .guided_process import GuidedDecodingManager from .logits_process import FusedLogitsProcessor, SamplingInputs @@ -222,6 +222,7 @@ def model_forward( model: torch.nn.Module, inputs: ModelInputs, cache_engine: CacheEngine, + state_cache_engine: StateCacheEngine, stream: torch.cuda.Stream = None, ): """Perform model forward.""" @@ -233,6 +234,7 @@ def model_forward( inputs=inputs, model_config=cache_engine.model_config, kv_caches=cache_engine.gpu_cache, + state_caches=state_cache_engine.state_caches, kv_quant_policy=cache_engine.cache_config.quant_policy, ) with ctx_mgr.context(context): @@ -349,6 +351,7 @@ def __init__(self, self.patched_model = None self.cache_engine = None + self.state_cache_engine = None self.profiler: AgentProfiler = None try: self.guided_decoding_manager = GuidedDecodingManager(self.tokenizer, model_config.vocab_size) @@ -394,7 +397,10 @@ def get_free_mem(self): def warmup(self): """warmup.""" - # TODO: disable for now, do not remove the comments. + from lmdeploy.pytorch.envs import skip_warmup + if skip_warmup: + return + with self.all_context(): max_batches = self.cache_config.max_batches num_tokens = max_batches @@ -689,6 +695,10 @@ async def __prepare_dp(): logger.debug(f' rank[{rank}]: all inputs are dummy, skip forward.') return + if not is_decoding: + # init state cache for first time prefill + # I don't know if this is necessary... 
+ self.state_cache_engine.init_caches(inputs.state_offsets, inputs.history_lengths == 0) cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map) for idx in range(loop_count): # inference @@ -968,12 +978,14 @@ def build_cache_engine(self): tp_rank=self.tp_rank, world_size=tp, cache_stream=self.cache_stream) + self.state_cache_engine = StateCacheEngine(self.cache_config) def _forward_impl(self, inputs: ModelInputs): output = model_forward( self.patched_model, inputs, self.cache_engine, + state_cache_engine=self.state_cache_engine, stream=self.stream, ) return output diff --git a/lmdeploy/pytorch/envs.py b/lmdeploy/pytorch/envs.py index d5d08362d9..f36aeced3e 100644 --- a/lmdeploy/pytorch/envs.py +++ b/lmdeploy/pytorch/envs.py @@ -131,6 +131,9 @@ def _patched_get_env( os.getenv('DG_JIT_DEBUG', '0') os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', '0') + # model agent + skip_warmup = env_to_bool('LMD_SKIP_WARMUP', False) + def get_all_envs(): """Get all environment variables.""" diff --git a/lmdeploy/pytorch/kernels/cuda/pagedattention.py b/lmdeploy/pytorch/kernels/cuda/pagedattention.py index a4941c36af..0c61e77d4a 100644 --- a/lmdeploy/pytorch/kernels/cuda/pagedattention.py +++ b/lmdeploy/pytorch/kernels/cuda/pagedattention.py @@ -506,19 +506,27 @@ def _kernel_meta_sm8x(BLOCK_DMODEL: int, BLOCK_H: int): def _kernel_meta_sm9x(BLOCK_DMODEL: int, BLOCK_H: int): """Kernel meta default.""" - return _kernel_meta_default(BLOCK_DMODEL, BLOCK_H) + num_warps = 4 + if BLOCK_DMODEL * BLOCK_H > 4096: + num_stages = 2 + else: + num_stages = 3 + return num_warps, num_stages -def _get_split_k(device_idx: int, head_grid: int, batch_size: int): +def _get_split_k(device_idx: int, head_grid: int, batch_size: int, num_warps: int): """Get split k.""" props = get_device_props(device_idx) num_sm = props['multi_processor_count'] # estimated occupancy 12.5% warps_per_sm = props['warps_per_sm'] // 8 + cta_per_sm = triton.cdiv(warps_per_sm, num_warps) + cta_per_device = num_sm * cta_per_sm - SPLIT_K = triton.cdiv(num_sm * warps_per_sm // head_grid, triton.next_power_of_2(batch_size)) + SPLIT_K = triton.cdiv(cta_per_device // head_grid, triton.next_power_of_2(batch_size)) SPLIT_K = 1 << (SPLIT_K.bit_length() - 1) - SPLIT_K = max(min(SPLIT_K, 64), 4) + max_split = 1 << (num_sm.bit_length() - 1) + SPLIT_K = max(min(SPLIT_K, max_split), 4) return SPLIT_K @@ -616,7 +624,14 @@ def _get_block_d(Lk): TILES_PER_GROUP = triton.cdiv(HEADS_PER_REQ, BLOCK_H) grid_1 = TILES_PER_GROUP * num_kv_heads - SPLIT_K = _get_split_k(q.device.index, grid_1, batch) + if _nv_cap[0] < 8: + num_warps, num_stages = _kernel_meta_default(BLOCK_DMODEL, BLOCK_H) + elif _nv_cap[0] < 9: + num_warps, num_stages = _kernel_meta_sm8x(BLOCK_DMODEL, BLOCK_H) + else: + num_warps, num_stages = _kernel_meta_sm9x(BLOCK_DMODEL, BLOCK_H) + + SPLIT_K = _get_split_k(q.device.index, grid_1, batch, num_warps) if quant_policy != 4: acc = q.new_empty(num_tokens, head, SPLIT_K, Lv + 2, dtype=torch.float32) @@ -629,13 +644,6 @@ def _get_block_d(Lk): batch, ) - if _nv_cap[0] < 8: - num_warps, num_stages = _kernel_meta_default(BLOCK_DMODEL, BLOCK_H) - elif _nv_cap[0] < 9: - num_warps, num_stages = _kernel_meta_sm8x(BLOCK_DMODEL, BLOCK_H) - else: - num_warps, num_stages = _kernel_meta_sm9x(BLOCK_DMODEL, BLOCK_H) - if quant_policy > 0: _fwd_grouped_split_quant_kernel[grid](q, k, diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index 101ed62546..ec5e6098b8 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py 
@@ -483,6 +483,7 @@ class SchedulerSequence: num_new_tokens: int = 0 sampling_param: SamplingParam = field(default_factory=SamplingParam) logical_blocks: LogicalTokenBlocks = field(default_factory=LogicalTokenBlocks) + logical_state: int = -1 adapter_name: str = None arrive_time: float = 0.0 output_start_pos: int = 0 diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index 13e35fd1ae..7b3a09de3d 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -144,6 +144,7 @@ class ModelInputs: model_metas: List[Dict[str, Any]] = None dp_meta: 'DPMeta' = None enable_microbatch: bool = False + state_offsets: torch.LongTensor = None def step(self, input_ids: torch.LongTensor, step_seqlens: torch.Tensor = None): """Update input ids.""" @@ -257,6 +258,7 @@ def __make_next_vision_inputs(flatten_mms: List, start: int): model_metas=self.model_metas, cross_length=cross_length, history_cross_length=history_cross_length, + state_offsets=self.state_offsets, ) ret.append(inp) history_cross_length = cross_length @@ -322,6 +324,10 @@ class StepContext: dp_meta: DPMeta = None enable_microbatch: bool = False + # states for ssm + state_caches: List = None + state_offsets: torch.LongTensor = None + _outputs: Dict = field(default_factory=dict) @classmethod @@ -330,6 +336,7 @@ def new( inputs: ModelInputs, model_config: ModelConfig, kv_caches: List = None, + state_caches: List = None, kv_quant_policy: Literal[0, 4, 8] = 0, ): """Build step context. @@ -389,6 +396,8 @@ def new( cross_kv_seqlens=cross_kv_seqlens, dp_meta=inputs.dp_meta, enable_microbatch=inputs.enable_microbatch, + state_caches=state_caches, + state_offsets=inputs.state_offsets, ) ret = get_backend().update_step_context(ret) @@ -454,6 +463,7 @@ def build_context( inputs: ModelInputs, model_config: ModelConfig, kv_caches: List = None, + state_caches: List = None, kv_quant_policy: Literal[0, 4, 8] = 0, ): """Build context.""" @@ -461,6 +471,7 @@ def build_context( inputs, model_config, kv_caches, + state_caches, kv_quant_policy, ) diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 498e2c6554..181418adf7 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -225,6 +225,11 @@ 'GptOssForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gpt_oss.GptOssForCausalLM', }) +# qwen3 next model +MODULE_MAP.update({ + 'Qwen3NextForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_next.Qwen3NextForCausalLM', +}) + # SDAR MODULE_MAP.update({ 'SDARForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.sdar.SDARForCausalLM', diff --git a/lmdeploy/pytorch/models/qwen3_next.py b/lmdeploy/pytorch/models/qwen3_next.py new file mode 100644 index 0000000000..128afe72b4 --- /dev/null +++ b/lmdeploy/pytorch/models/qwen3_next.py @@ -0,0 +1,1085 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+ +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn +from transformers.configuration_utils import PretrainedConfig + +import lmdeploy.pytorch.distributed as dist +from lmdeploy.pytorch.distributed import get_dist_manager, get_tp_world_rank +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config +from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_merged_colwise_linear, build_o_proj, build_qkv_proj, + build_rowwise_linear) +from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe +from lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader, load_weight + +from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin + + +class GatedDeltaMeta: + + def __init__(self, num_tokens: int, conv_kernel_size: int, state_ids: torch.Tensor, attn_metadata: Any): + self.num_tokens = num_tokens + self.is_decoding = attn_metadata.is_decoding + self.cu_seqlens = attn_metadata.cu_seqlens_q + device = self.cu_seqlens.device + + # get seq_idx (1, num_tokens) + seqlens = attn_metadata.q_seqlens + batch_size = seqlens.numel() + batch_idx = torch.arange(0, batch_size, dtype=torch.int32, device=device) + self.seq_idx = torch.repeat_interleave(batch_idx, seqlens, output_size=num_tokens)[None] + + # conv_idx + range_idx = torch.arange(-conv_kernel_size, 0, device=device) + self.conv_idx = self.cu_seqlens[1:, None] + range_idx[None] + self.conv_idx = self.conv_idx.clamp_min(0) + + # state_ids, fill invalid state with state_ids[0] + self.valid_state = state_ids >= 0 + self.state_ids = torch.where(self.valid_state, state_ids, state_ids[0]) + self.state_ids = self.state_ids.clamp(0) + + +class CausalConv1dFunc: + + def __init__(self, activation: str = 'silu'): + try: + import causal_conv1d + self.causal_conv1d_fn = causal_conv1d.causal_conv1d_fn + self.causal_conv1d_update = causal_conv1d.causal_conv1d_update + except Exception: + raise RuntimeError( + 'causal_conv1d is not installed, please refer to https://github.com/Dao-AILab/causal-conv1d') + self.activation = activation + + def conv1d_func(self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, conv_state: torch.Tensor, + gated_delta_meta: GatedDeltaMeta): + """ + x: (b, seqlen, dim) + seqlen: (b) + out: (b, seqlen, dim) + conv_state: (b, dim, kernel_size) + """ + seq_idx = gated_delta_meta.seq_idx + conv_idx = gated_delta_meta.conv_idx + + assert x.dim() == 3 + x = x.transpose(-2, -1) + if weight.dim() == 3: + assert weight.size(1) == 1 + weight = weight[:, 0] + + # fill conv state + batch_size = conv_state.size(0) + conv_idx = conv_idx[:, None].expand(-1, x.size(1), -1) + torch.gather(x.expand(batch_size, -1, -1), -1, conv_idx, out=conv_state) + + out = self.causal_conv1d_fn( + x, + weight, + bias, + seq_idx, + return_final_states=False, + activation=self.activation, + ) + + out = out.transpose(-2, -1) + + # store conv_state + return out, conv_state + + def conv1d_update(self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, conv_state: torch.Tensor): + if weight.dim() == 3: + assert weight.size(1) == 1 + weight = weight[:, 0] + out = self.causal_conv1d_update(x[0], conv_state, weight, bias, activation=self.activation) + return out[None], conv_state + + def __call__( + self, + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + conv_state: torch.Tensor, + 
gated_delta_meta: GatedDeltaMeta, + ): + if gated_delta_meta.is_decoding: + return self.conv1d_update(x, weight, bias, conv_state) + return self.conv1d_func(x, weight, bias, conv_state, gated_delta_meta=gated_delta_meta) + + +class GatedDelta: + + def __init__(self, use_qk_l2norm_in_kernel: bool = True): + try: + from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule + self.chunk_gated_delta_rule = chunk_gated_delta_rule + self.fused_recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule + except Exception: + raise RuntimeError( + 'fla is not installed, please refer to https://github.com/fla-org/flash-linear-attention') + self.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + + def __call__( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + recurrent_state: torch.Tensor, + gated_delta_meta: GatedDeltaMeta, + ): + """call.""" + is_decoding = gated_delta_meta.is_decoding + cu_seqlens = gated_delta_meta.cu_seqlens + + if not is_decoding: + core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=True, + use_qk_l2norm_in_kernel=self.use_qk_l2norm_in_kernel, + cu_seqlens=cu_seqlens, + ) + else: + # qkvgb (1, seqlen, ...) -> (seqlen, 1, ...) + core_attn_out, last_recurrent_state = self.fused_recurrent_gated_delta_rule( + query[0, :, None], + key[0, :, None], + value[0, :, None], + g=g[0, :, None], + beta=beta[0, :, None], + initial_state=recurrent_state, + output_final_state=True, + use_qk_l2norm_in_kernel=self.use_qk_l2norm_in_kernel, + ) + # out (seqlen, 1, ...) -> (1, seqlen, ...) + core_attn_out = core_attn_out[None, :, 0] + return core_attn_out, last_recurrent_state + + +def build_rmsnorm_gated(hidden_size: int, eps=1e-6, **kwargs): + from fla.modules import FusedRMSNormGated + return FusedRMSNormGated(hidden_size, eps=eps, **kwargs) + + +class CausalConv1d(nn.Module): + """Causal conv1d wrapper.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int]], + groups: int = 1, + bias: bool = True, + split=None, + device=None, + dtype=None, + ): + super().__init__() + tp, rank = get_tp_world_rank() + self.tp = tp + self.rank = rank + in_channels = in_channels // tp + out_channels = out_channels // tp + groups = groups // tp + assert len(split) == 3 + self.split = split + + weight, bias = self.make_weight( + in_channels, + out_channels, + kernel_size=kernel_size, + groups=groups, + bias=bias, + device=device, + dtype=dtype, + ) + + self.register_weight(weight, bias) + self.causal_conv1d_func = CausalConv1dFunc(activation='silu') + + @staticmethod + def make_weight( + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int]], + groups: int = 1, + bias: bool = True, + device=None, + dtype=None, + ): + weight_shape = (out_channels, in_channels // groups, + kernel_size if isinstance(kernel_size, int) else kernel_size[0]) + bias_shape = (out_channels, ) if bias else None + + weight = torch.empty(weight_shape, device=device, dtype=dtype) + if bias_shape is not None: + bias = torch.empty(bias_shape, device=device, dtype=dtype) + else: + bias = None + return weight, bias + + def register_weight(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None): + self.register_parameter('weight', nn.Parameter(weight)) + self.weight.weight_loader = self.weight_loader + if bias is not None: + 
self.register_parameter('bias', nn.Parameter(bias)) + self.bias.weight_loader = self.weight_loader + else: + self.register_parameter('bias', None) + + def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor): + """Weight loader.""" + q, k, v = loaded_weight.split(self.split, dim=0) + q = q.chunk(self.tp, dim=0)[self.rank] + k = k.chunk(self.tp, dim=0)[self.rank] + v = v.chunk(self.tp, dim=0)[self.rank] + loaded_weight = torch.cat([q, k, v], dim=0) + default_weight_loader(param, loaded_weight) + + def forward(self, x: torch.Tensor, conv_state: torch.Tensor, gated_delta_meta: GatedDeltaMeta): + """forward.""" + return self.causal_conv1d_func(x, self.weight, self.bias, conv_state, gated_delta_meta=gated_delta_meta) + + +class Qwen3NextGatedDeltaNet(nn.Module): + """Gated deltanet.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.hidden_size = config.hidden_size + self.num_v_heads = config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + self.kv_ratio = self.num_v_heads // self.num_k_heads + + self.conv_kernel_size = config.linear_conv_kernel_dim + self.layer_idx = layer_idx + self.activation = config.hidden_act + self.layer_norm_epsilon = config.rms_norm_eps + + # QKV + self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv1d = CausalConv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=False, + kernel_size=self.conv_kernel_size, + groups=self.conv_dim, + split=[self.key_dim, self.key_dim, self.value_dim], + dtype=dtype, + device=device, + ) + + # projection of the input hidden states + projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 + projection_size_ba = self.num_v_heads * 2 + self.in_proj_qkvz = build_colwise_linear(self.hidden_size, + projection_size_qkvz, + bias=False, + dtype=dtype, + device=device, + is_tp=True) + # dirty patch to qkvz + self.in_proj_qkvz.weight.weight_loader = self.weight_loader_qkvz + self.in_proj_ba = build_colwise_linear(self.hidden_size, + projection_size_ba, + bias=False, + dtype=dtype, + device=device, + is_tp=True) + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.make_params(self.num_v_heads, device=device) + self.A_log_exp = None + + self.norm = build_rmsnorm_gated(self.head_v_dim, + eps=self.layer_norm_epsilon, + activation=self.activation, + dtype=dtype, + device=device) + self.out_proj = build_o_proj(self.value_dim, + self.hidden_size, + bias=False, + dtype=dtype, + device=device, + is_tp=True) + + self.gated_delta = GatedDelta() + + def get_A_log_exp(self): + if self.A_log_exp is None: + self.A_log_exp = -self.A_log.float().exp() + + return self.A_log_exp + + def make_params(self, num_v_heads: int, device: torch.device): + tp, _ = get_tp_world_rank() + num_v_heads = num_v_heads // tp + A = torch.empty(num_v_heads, device=device).uniform_(0, 16) + dt_bias = torch.empty(num_v_heads, device=device).uniform_(0, 1) + + self.register_parameter('A_log', nn.Parameter(torch.log(A))) + self.register_parameter('dt_bias', nn.Parameter(dt_bias)) + self.A_log.weight_loader = self.weight_loader_a_dt + self.dt_bias.weight_loader = self.weight_loader_a_dt + + def weight_loader_qkvz(self, param: 
torch.nn.Parameter, loaded_weight: torch.Tensor): + """Weight loader qkvz.""" + tp, rank = get_tp_world_rank() + split_arg_list_qkvz = [ + self.head_k_dim, + self.head_k_dim, + (self.kv_ratio * self.head_v_dim), + (self.kv_ratio * self.head_v_dim), + ] + sum_split = sum(split_arg_list_qkvz) + loaded_weight = loaded_weight.unflatten(0, (-1, sum_split)) + q, k, v, z = loaded_weight.split(split_arg_list_qkvz, dim=1) + q = q.chunk(tp, dim=0)[rank] + k = k.chunk(tp, dim=0)[rank] + v = v.chunk(tp, dim=0)[rank] + z = z.chunk(tp, dim=0)[rank] + + loaded_weight = torch.cat([q, k, v, z], dim=1) + loaded_weight = loaded_weight.flatten(0, 1) + default_weight_loader(param, loaded_weight) + + def weight_loader_a_dt(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor): + """Weight loader.""" + tp, rank = get_tp_world_rank() + loaded_weight = loaded_weight.chunk(tp, dim=0)[rank] + default_weight_loader(param, loaded_weight) + + def fix_query_key_value_ordering(self, mixed_qkvz: torch.Tensor, mixed_ba: torch.Tensor): + """Derives `query`, `key` and `value` tensors from `mixed_qkvz` and + `mixed_ba`.""" + new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( + -1, + 2 * self.head_k_dim + 2 * self.head_v_dim * self.kv_ratio, + ) + new_tensor_shape_ba = mixed_ba.size()[:-1] + (-1, 2 * self.kv_ratio) + + mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) + mixed_ba = mixed_ba.view(*new_tensor_shape_ba) + split_arg_list_qkvz = [ + self.head_k_dim, + self.head_k_dim, + (self.kv_ratio * self.head_v_dim), + (self.kv_ratio * self.head_v_dim), + ] + split_arg_list_ba = [self.kv_ratio, self.kv_ratio] + query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=-1) + b, a = torch.split(mixed_ba, split_arg_list_ba, dim=-1) + # [..., ng, np/ng * hn] -> [..., np, hn] + value = value.reshape(*value.shape[:-2], -1, self.head_v_dim) + z = z.reshape(*z.shape[:-2], -1, self.head_v_dim) + b = b.reshape(*b.shape[:-2], -1) + a = a.reshape(*a.shape[:-2], -1) + return query, key, value, z, b, a + + def _load_state(self, past_key_value: Tuple[torch.Tensor, torch.Tensor], gated_delta_meta: GatedDeltaMeta): + """Load states from cache.""" + state_ids = gated_delta_meta.state_ids + conv_cache, recurrent_cache = past_key_value[:2] + + return conv_cache.index_select(0, state_ids), recurrent_cache.index_select(0, state_ids) + + def _store_state(self, conv_state: torch.Tensor, recurrent_state: torch.Tensor, + past_key_value: Tuple[torch.Tensor, torch.Tensor], gated_delta_meta: GatedDeltaMeta): + """Store states to cache.""" + conv_cache, recurrent_cache = past_key_value[:2] + state_ids = gated_delta_meta.state_ids + valid_state = gated_delta_meta.valid_state + + # fill invalid state with state[0] + conv_dim = conv_state.dim() + recurrent_dim = recurrent_state.dim() + conv_state = torch.where(valid_state.view(-1, *[1] * (conv_dim - 1)), conv_state, conv_state[:1]) + recurrent_state = torch.where(valid_state.view(-1, *[1] * (recurrent_dim - 1)), recurrent_state, + recurrent_state[:1]) + + conv_cache = conv_cache.index_copy_(0, state_ids, conv_state) + recurrent_cache = recurrent_cache.index_copy_(0, state_ids, recurrent_state.to(recurrent_cache.dtype)) + + def forward( + self, + hidden_states: torch.Tensor, + past_key_value: Tuple[torch.Tensor, torch.Tensor], + gated_delta_meta: GatedDeltaMeta, + ): + """forward.""" + + # load states + conv_state, recurrent_state = self._load_state(past_key_value, gated_delta_meta) + + # inputs proj + projected_states_qkvz = self.in_proj_qkvz(hidden_states) + projected_states_ba = 
self.in_proj_ba(hidden_states) + query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba) + query, key, value = (x.reshape(*x.shape[:-2], -1) for x in (query, key, value)) + + mixed_qkv = torch.cat((query, key, value), dim=-1) + mixed_qkv, conv_state = self.conv1d(mixed_qkv, conv_state, gated_delta_meta=gated_delta_meta) + + tp = (self.key_dim * 2 + self.value_dim) // mixed_qkv.size(-1) + query, key, value = torch.split( + mixed_qkv, + [ + self.key_dim // tp, + self.key_dim // tp, + self.value_dim // tp, + ], + dim=-1, + ) + query = query.unflatten(-1, (-1, self.head_k_dim)) + key = key.unflatten(-1, (-1, self.head_k_dim)) + value = value.unflatten(-1, (-1, self.head_v_dim)) + + beta = b.sigmoid() + # If the model is loaded in fp16, without the .float() here, A might be -inf + g = self.get_A_log_exp() * F.softplus(a.float() + self.dt_bias) + if self.kv_ratio > 1: + query = query.repeat_interleave(self.kv_ratio, dim=-2) + key = key.repeat_interleave(self.kv_ratio, dim=-2) + + core_attn_out, recurrent_state = self.gated_delta( + query, + key, + value, + g=g, + beta=beta, + recurrent_state=recurrent_state, + gated_delta_meta=gated_delta_meta, + ) + + # store states + self._store_state(conv_state, recurrent_state, past_key_value, gated_delta_meta) + + z_shape_og = z.shape + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(z_shape_og) + core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1) + + output = self.out_proj(core_attn_out) + return output + + +class Qwen3NextAttention(nn.Module): + """Rewrite module of Qwen3MoeAttention.""" + + def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + num_heads = config.num_attention_heads + num_key_value_heads = config.num_key_value_heads + hidden_size = config.hidden_size + head_dim = getattr(config, 'head_dim', hidden_size // num_heads) + self.head_dim = head_dim + num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1) + + # packed qkv + # Qwen3 uses 'config.attention_bias = False' for q/k/o projections + self.qkv_proj = build_qkv_proj( + hidden_size, + num_q_heads=num_heads * 2, + num_kv_heads=num_key_value_heads, + head_size=head_dim, + bias=config.attention_bias, + quant_config=quantization_config, + num_replicate_kv_heads=num_replicate_kv_heads, + dtype=dtype, + device=device, + ) + + # rotary embedding + self.apply_rotary_pos_emb = ApplyRotaryEmb() + + # attention + self.attn_fwd = Attention( + num_heads, + head_dim, + num_kv_heads=num_key_value_heads, + v_head_size=head_dim, + ) + + # o_proj + self.o_proj = build_o_proj(num_heads * head_dim, + hidden_size, + bias=config.attention_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + # q, k norm + self.q_norm = RMSNorm(head_dim, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + self.k_norm = RMSNorm(head_dim, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + def forward( + self, + hidden_states: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attn_metadata: Any = None, + ): + """Rewrite of 
LlamaAttention.forward.""" + # qkv proj + qkv_states = self.qkv_proj(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states) + query_states, gate = query_states.view(*query_states.shape[:-2], -1, 2 * self.head_dim).chunk(2, dim=-1) + + # apply q, k norm + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + # apply rotary embedding + cos, sin = rotary_pos_emb + query_states, key_states = self.apply_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + inplace=True, + ) + + # attention + attn_output = self.attn_fwd( + query_states, + key_states, + value_states, + past_key_value[0], + past_key_value[1], + attn_metadata, + k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2], + v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3], + inplace=True, + ) + attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) + gate = gate.reshape(*hidden_states.shape[:-1], -1) + attn_output = attn_output * gate.sigmoid() + + # o proj + attn_output = self.o_proj(attn_output) + return attn_output + + +class Qwen3NextMLP(nn.Module): + """mlp.""" + + def __init__(self, + config: PretrainedConfig, + intermediate_size: int = None, + dtype: torch.dtype = None, + device: torch.device = None, + is_tp: bool = True, + all_reduce: bool = True): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + if intermediate_size is None: + intermediate_size = config.intermediate_size + # gate up + self.gate_up_proj = build_merged_colwise_linear( + config.hidden_size, + [intermediate_size, intermediate_size], + bias=False, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=is_tp, + ) + + # silu and mul + self.act_fn = SiluAndMul(inplace=True) + + # down + self.down_proj = build_rowwise_linear(intermediate_size, + config.hidden_size, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=is_tp, + all_reduce=all_reduce) + + def forward(self, x): + """forward.""" + gate_up = self.gate_up_proj(x) + act = self.act_fn(gate_up) + return self.down_proj(act) + + +class Qwen3NextSparseMoeBlock(nn.Module): + """Moe block.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + # TODO: zhouxinyu, determine modules_to_not_convert from config file + quantization_config = getattr(config, 'quantization_config', None) + self.layer_idx = layer_idx + self.hidden_dim = config.hidden_size + self.ffn_dim = config.moe_intermediate_size + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + self.renormalize = self.norm_topk_prob + + self.gate = build_rowwise_linear( + self.hidden_dim, + self.num_experts, + bias=False, + dtype=dtype, + device=device, + is_tp=False, + ) + + self.softmax_topk = SoftmaxTopK(self.top_k) + + self.experts = build_fused_moe( + self.hidden_dim, + self.ffn_dim, + self.num_experts, + top_k=self.top_k, + renormalize=self.renormalize, + dtype=dtype, + device=device, + quant_config=quantization_config, + all_reduce=False, + layer_idx=layer_idx, + ) + + self.shared_expert = Qwen3NextMLP( + config=config, + intermediate_size=config.shared_expert_intermediate_size, + dtype=dtype, + device=device, + is_tp=True, + all_reduce=False, + ) + self.shared_expert_gate = 
torch.nn.Linear(config.hidden_size, 1, bias=False, device=device, dtype=dtype) + + # get all reduce + dist_ctx = get_dist_manager().current_context() + dp = dist_ctx.dp + world_size = dist_ctx.world_size + if dp == 1 and world_size > 1: + self._all_reduce = True + else: + self._all_reduce = False + + def forward(self, hidden_states: torch.Tensor): + """forward.""" + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states) + topk_weights, topk_ids = self.softmax_topk(router_logits) + out_states = self.experts( + hidden_states, + topk_weights, + topk_ids, + ) + + shared_states = self.shared_expert(hidden_states) + shared_states = self.shared_expert_gate(hidden_states).sigmoid() * shared_states + + out_states += shared_states + out_states = out_states.reshape(batch_size, sequence_length, -1) + + if self._all_reduce: + dist.all_reduce(out_states) + return out_states + + +class Qwen3NextDecoderLayer(nn.Module): + """Decoder layer.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_idx = layer_idx + quantization_config = getattr(config, 'quantization_config', None) + + # build attention layer + self.layer_type = config.layer_types[layer_idx] + if self.layer_type == 'linear_attention': + self.linear_attn = Qwen3NextGatedDeltaNet(config, layer_idx, dtype=dtype, device=device) + elif self.layer_type == 'full_attention': + self.self_attn = Qwen3NextAttention(config, dtype=dtype, device=device) + + # build MLP + if (layer_idx not in config.mlp_only_layers) and (config.num_experts + > 0) and ((layer_idx + 1) % config.decoder_sparse_step == 0): + self.mlp = Qwen3NextSparseMoeBlock(config, layer_idx=layer_idx, dtype=dtype, device=device) + else: + self.mlp = Qwen3NextMLP(config, intermediate_size=config.intermediate_size, dtype=dtype, device=device) + + # build input layer norm + self.input_layernorm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build attention layer norm + self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device) + + def forward( + self, + hidden_states: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[List[torch.FloatTensor]], + residual: Optional[torch.Tensor], + attn_metadata: Any, + gated_delta_meta: GatedDeltaMeta, + ): + + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + # Self Attention + if self.layer_type == 'linear_attention': + hidden_states = self.linear_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + gated_delta_meta=gated_delta_meta, + ) + elif self.layer_type == 'full_attention': + hidden_states = self.self_attn( + hidden_states=hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + outputs = (hidden_states, residual) + return outputs + + +class Qwen3NextModel(nn.Module): + """Qwen3 next model.""" + + def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None): + 
super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=dtype, + device=device) + + # build all decode layers + # TODO: use full config.num_hidden_layers + self.layers = nn.ModuleList([ + Qwen3NextDecoderLayer(config, layer_idx, dtype=dtype, device=device) + for layer_idx in range(self.config.num_hidden_layers) + ]) + + # build norm + self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device) + + # build rotary embedding + self.rotary_emb = build_rotary_embedding_from_config(config) + + def forward( + self, + input_ids: torch.LongTensor, + position_ids: torch.LongTensor, + past_key_values: List[torch.FloatTensor], + attn_metadata: Any, + state_ids: torch.Tensor, + inputs_embeds: Optional[torch.FloatTensor] = None, + ): + """Rewrite of LlamaModel.forward.""" + + # token embedding + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + # rotary embedding + cos, sin = self.rotary_emb(hidden_states, position_ids) + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) + + # make seq_idx + gated_delta_meta = GatedDeltaMeta(hidden_states.size(1), self.config.linear_conv_kernel_dim, state_ids, + attn_metadata) + + # decoding + residual = None + for idx, decoder_layer in enumerate(self.layers): + hidden_states, residual = decoder_layer( + hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_values[idx], + residual=residual, + attn_metadata=attn_metadata, + gated_delta_meta=gated_delta_meta, + ) + + # norm + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + def get_input_embeddings(self): + """Get input embeddings.""" + return self.embed_tokens + + +class Qwen3NextForCausalLM(nn.Module, CudaGraphMixin): + """ModelForCausalLM.""" + + packed_modules_mapping = { + 'qkv_proj': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + 'gate_up_proj': [ + 'gate_proj', + 'up_proj', + ], + } + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + # build model + self.model = Qwen3NextModel(config, dtype=dtype, device=device) + # build lm_head + self.lm_head = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + state_ids: torch.Tensor = None, + **kwargs, + ): + """Model forward, return logits.""" + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + state_ids=state_ids, + ) + return hidden_states + + def get_logits(self, hidden_states: torch.Tensor): + """Compute logits of the model output.""" + return self.lm_head(hidden_states) + + def get_input_embeddings(self): + """Get input embeddings.""" + return self.model.get_input_embeddings() + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """Prepare input.""" + # get input_ids, position_ids and attention 
metadatas
+        input_ids = context.input_ids
+        position_ids = context.position_ids
+        attn_metadata = context.attn_metadata
+
+        # make past_key_values
+        state_caches = list(cache.transpose(0, 1) for cache in context.state_caches)
+        state_caches = list(zip(state_caches[0], state_caches[1]))
+        past_key_values = list(past_key_values)
+        new_past_key_values = []
+        for layer_type in self.config.layer_types:
+            if layer_type == 'linear_attention':
+                new_past_key_values.append(state_caches.pop(0))
+            elif layer_type == 'full_attention':
+                new_past_key_values.append(past_key_values.pop(0))
+
+        # process vision embeddings
+        vision_embeddings = context.input_embeddings
+        vision_embedding_indexing = context.input_embedding_indexing
+        if vision_embeddings is not None and len(vision_embeddings) > 0:
+            if inputs_embeds is None:
+                inputs_embeds = self.get_input_embeddings()(input_ids)
+            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)
+
+        # inputs of forward
+        return dict(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            past_key_values=new_past_key_values,
+            attn_metadata=attn_metadata,
+            inputs_embeds=inputs_embeds,
+            state_ids=context.state_offsets,
+        )
+
+    def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
+        """Make cudagraph buffers from forward inputs."""
+        max_batchs = graph_meta.max_batchs
+        device = graph_meta.device
+
+        input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs)
+        state_ids = torch.full((max_batchs, ), -1, dtype=torch.long, device=device)
+        input_buffers['state_ids'] = state_ids
+
+        return input_buffers
+
+    def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
+        """Fill cudagraph buffers from forward inputs."""
+        input_buffers = graph_meta.input_buffers
+
+        new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs)
+        state_ids = kwargs['state_ids']
+        input_buffers['state_ids'].fill_(-1)
+        input_buffers['state_ids'][:state_ids.size(0)].copy_(state_ids)
+        new_inputs['state_ids'] = input_buffers['state_ids']
+
+        return new_inputs
+
+    def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],
+                             expert_params_mapping: List):
+        """Load weight experts."""
+        # load fused weights
+        for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:
+            if weight_name not in name:
+                continue
+            name = name.replace(weight_name, param_name)
+            param = params_dict[name]
+            load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
+            break
+        else:
+            param = params_dict[name]
+            load_weight(param, loaded_weight)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        """Load weights."""
+
+        def __skip_layers(name):
+            """We might change the number of layers so we can debug the model
+            with less gpus."""
+            import re
+            if '.layers.' not in name:
+                return False
+            matches = re.findall(r'\.layers\.(\d+)\.', name)
+            layer_id = int(matches[0])
+            return layer_id >= self.config.num_hidden_layers
+
+        # modify from vllm
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ('.qkv_proj', '.q_proj', 'q'),
+            ('.qkv_proj', '.k_proj', 'k'),
+            ('.qkv_proj', '.v_proj', 'v'),
+            ('.gate_up_proj', '.gate_proj', 0),
+            ('.gate_up_proj', '.up_proj', 1),
+        ]
+
+        # expert map
+        num_experts = self.config.num_experts
+        expert_params_mapping = []
+        for exp_id in range(num_experts):
+            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate')
+            up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up')
+            down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')
+            expert_params_mapping += [gate_param, up_param, down_param]
+
+        rms_norm_keys = ['model.norm', '.input_layernorm', '.post_attention_layernorm', '.q_norm', '.k_norm']
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+
+            if __skip_layers(name):
+                continue
+
+            if 'mtp.' in name:
+                continue
+            if 'rotary_emb.inv_freq' in name:
+                continue
+            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
+                continue
+            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
+                continue
+
+            name = name.replace('.block_sparse_moe.', '.mlp.')
+            if '.experts' in name and '.shared_expert' not in name:
+                self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping=expert_params_mapping)
+            else:
+                for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    load_weight(param, loaded_weight, shard_id=shard_id)
+                    break
+                else:
+                    for rms_norm_key in rms_norm_keys:
+                        if rms_norm_key in name and 'weight' in name:
+                            loaded_weight = loaded_weight + 1
+                            break
+                    param = params_dict[name]
+                    load_weight(param, loaded_weight)
diff --git a/lmdeploy/pytorch/models/utils/cudagraph.py b/lmdeploy/pytorch/models/utils/cudagraph.py
index 065aef97d1..5137f8f132 100644
--- a/lmdeploy/pytorch/models/utils/cudagraph.py
+++ b/lmdeploy/pytorch/models/utils/cudagraph.py
@@ -82,6 +82,10 @@ def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, **kwargs) ->
         # create buffer for cross_attn_metadata here
         input_buffers['fill_seqlens'] = torch.zeros(max_batches, dtype=torch.int64, device=device)
 
+        input_buffers['cu_seqlens'] = torch.zeros(2, max_batches + 1, dtype=torch.int32, device=device)
+        input_buffers['cu_seqlens_q'] = input_buffers['cu_seqlens'][0]
+        input_buffers['cu_seqlens_k'] = input_buffers['cu_seqlens'][1]
+
         return input_buffers
 
     def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, position_ids: Tensor,
@@ -107,7 +111,10 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, p
 
         qkv = torch.stack((q_start_loc, q_seqlens, kv_seqlens))
         input_buffers['qkv_lens'].zero_()
+        input_buffers['q_seqlens'].fill_(graph_meta.max_tokens // graph_meta.max_batchs)
         input_buffers['qkv_lens'][:, :batch_size] = qkv
+        input_buffers['cu_seqlens_q'][1:batch_size + 1] = input_buffers['q_seqlens'][:batch_size].cumsum(0)
+        input_buffers['cu_seqlens_k'][1:batch_size + 1] = input_buffers['kv_seqlens'][:batch_size].cumsum(0)
         if inputs_embeds is not None:
             emb_size = inputs_embeds.size(-1)
             if 'inputs_embeds' not in input_buffers:
@@ -121,6 +128,8 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, p
         attn_metadata.q_start_loc = input_buffers['q_start_loc']
         attn_metadata.q_seqlens = input_buffers['q_seqlens']
         attn_metadata.kv_seqlens = input_buffers['kv_seqlens']
+        attn_metadata.cu_seqlens_q = input_buffers['cu_seqlens_q']
+        attn_metadata.cu_seqlens_k = input_buffers['cu_seqlens_k']
         if getattr(self.config, 'use_flash_mla', False) is True:
             import flash_mla
             tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata(attn_metadata.kv_seqlens.to(torch.int32),
diff --git a/lmdeploy/pytorch/paging/eviction_helper/base_eviction_helper.py b/lmdeploy/pytorch/paging/eviction_helper/base_eviction_helper.py
index 20ef018824..3799d60a42 100644
--- a/lmdeploy/pytorch/paging/eviction_helper/base_eviction_helper.py
+++ b/lmdeploy/pytorch/paging/eviction_helper/base_eviction_helper.py
@@ -14,6 +14,8 @@ def __init__(self, scheduler: Scheduler):
         self.scheduler = scheduler
         self.block_manager = scheduler.block_manager
         self.block_trie = scheduler.block_trie
+        self.state_manager = scheduler.state_manager
+        self.cache_config = scheduler.cache_config
 
     def need_swap_in(self, seq: SchedulerSequence):
         """Sequence need swap in."""
diff --git a/lmdeploy/pytorch/paging/eviction_helper/recompute_eviction_helper.py b/lmdeploy/pytorch/paging/eviction_helper/recompute_eviction_helper.py
index 7dac755666..be0d09a5f9 100644
--- a/lmdeploy/pytorch/paging/eviction_helper/recompute_eviction_helper.py
+++ b/lmdeploy/pytorch/paging/eviction_helper/recompute_eviction_helper.py
@@ -2,13 +2,23 @@
 from typing import List
 
 from ...messages import SchedulerSequence
+from ..scheduler import Scheduler
 from .base_eviction_helper import BaseEvictionHelper
 
 
 class RecomputeEvictionHelper(BaseEvictionHelper):
     """Recompute eviction."""
 
-    def evict_for_seq(self, seq: SchedulerSequence, evictable_seqs: List[SchedulerSequence], prealloc_size: int):
+    def __init__(self, scheduler: Scheduler):
+        super().__init__(scheduler)
+
+        if len(self.cache_config.states_shapes) == 0:
+            self.evict_for_seq = self._evict_for_seq_default
+        else:
+            self.evict_for_seq = self._evict_for_ssm
+
+    def _evict_for_seq_default(self, seq: SchedulerSequence, evictable_seqs: List[SchedulerSequence],
+                               prealloc_size: int):
         """Evict seqs."""
         block_manager = self.block_manager
         block_trie = self.block_trie
@@ -46,3 +56,51 @@ def evict_for_seq(self, seq: SchedulerSequence, evictable_seqs: List[SchedulerSe
                 success = True
 
         return success
+
+    def _evict_for_ssm(self, seq: SchedulerSequence, evictable_seqs: List[SchedulerSequence], prealloc_size: int):
+        """Evict seqs."""
+        block_manager = self.block_manager
+        state_manager = self.state_manager
+        block_trie = self.block_trie
+        num_required_blocks = block_manager.num_required_blocks(seq, prealloc_size)
+        has_free_state = state_manager.get_num_free() > 0
+
+        if has_free_state and block_manager.get_num_free_gpu_blocks() >= num_required_blocks:
+            return True
+
+        success = False
+        while len(evictable_seqs) > 0:
+            evict_seq = evictable_seqs.pop(0)
+
+            # skip sequence with no blocks
+            if evict_seq.num_blocks == 0 and evict_seq.logical_state < 0:
+                continue
+
+            # free sequence
+            block_manager.free(evict_seq)
+            evict_seq.set_step(0)
+            state_manager.free(evict_seq)
+            has_free_state = True
+            num_req = (num_required_blocks - block_manager.get_num_free_gpu_blocks())
+            if num_req <= 0:
+                success = True
+                break
+
+            # clear cached prefix
+            block_trie.evict(num_req)
+            num_req = (num_required_blocks - block_manager.get_num_free_gpu_blocks())
+            if num_req <= 0:
+                success = True
+                break
+
+        if not has_free_state:
+            return False
+
+        # for empty evictable_seqs case
+        num_req = num_required_blocks - block_manager.get_num_free_gpu_blocks()
+        if num_req > 0:
+            block_trie.evict(num_req)
+            if num_required_blocks <= block_manager.get_num_free_gpu_blocks():
+                success = True
+
+        return success
diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py
index e19bd18141..838688829f 100644
--- a/lmdeploy/pytorch/paging/scheduler.py
+++ b/lmdeploy/pytorch/paging/scheduler.py
@@ -13,6 +13,7 @@
 from ..messages import MessageStatus, SchedulerSequence, SchedulerSession, SequenceManager, SequenceMeta
 from .block_manager import build_block_manager
 from .block_trie import BlockTrie
+from .state_manager import StateManager
 
 logger = get_logger('lmdeploy')
 
@@ -51,6 +52,8 @@ def __init__(self,
         self.block_manager = build_block_manager(cache_config)
         self.block_trie = BlockTrie(self.cache_config, self.block_manager)
+        self.state_manager = StateManager(self.cache_config.num_state_caches)
+        self.is_ssm = len(self.cache_config.states_shapes) > 0
+
         self.eviction_helper = self.build_eviction_helper(self.scheduler_config.eviction_type)
 
@@ -240,6 +243,8 @@ def _reorder_waiting():
             # allocate session memory
             self.block_manager.allocate(seq, prealloc_size)
+            if self.is_ssm:
+                self.state_manager.allocate(seq)
             _to_running(seq)
 
             seq.record_event(EventType.SCHEDULED)
@@ -335,6 +340,7 @@ def _remove_sequence(self, seq: SchedulerSequence):
             seq (SchedulerSequence): sequence to remove
         """
         self.block_manager.free(seq)
+        self.state_manager.free(seq)
         seq.set_step(0)
         seq.session.remove_sequence(seq)
diff --git a/lmdeploy/pytorch/paging/state_manager.py b/lmdeploy/pytorch/paging/state_manager.py
new file mode 100644
index 0000000000..6de882111b
--- /dev/null
+++ b/lmdeploy/pytorch/paging/state_manager.py
@@ -0,0 +1,60 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+
+from lmdeploy.pytorch.messages import SchedulerSequence
+
+
+class StateAllocator:
+    """State allocator."""
+
+    def __init__(self, num_states: int):
+        self.num_states = num_states
+        self._free_states = np.arange(num_states, dtype=np.int64)
+        self._free_count = num_states
+
+    def allocate(self):
+        """allocate."""
+        if self.get_num_free() == 0:
+            raise RuntimeError('No free states.')
+        alloc_id = self._free_states[-self._free_count]
+        self._free_count -= 1
+        return alloc_id
+
+    def free(self, state_id: int):
+        """free."""
+        if self._free_count >= self.num_states:
+            raise RuntimeError('All states are free.')
+        self._free_count += 1
+        self._free_states[-self._free_count] = state_id
+
+    def get_num_free(self):
+        return self._free_count
+
+
+class StateManager:
+
+    def __init__(self, num_states: int):
+        if num_states is None:
+            num_states = 1
+        self.allocator = StateAllocator(num_states)
+
+    def is_allocated(self, seq: SchedulerSequence):
+        """Check if a sequence is allocated."""
+        return seq.logical_state >= 0
+
+    def allocate(self, seq: SchedulerSequence):
+        """Allocate states for a sequence."""
+        if self.is_allocated(seq):
+            return None
+        seq.logical_state = self.allocator.allocate()
+
+    def free(self, seq: SchedulerSequence):
+        """Free states for a sequence."""
+        if not self.is_allocated(seq):
+            return None
+        self.allocator.free(seq.logical_state)
+        seq.logical_state = -1
+
+    def get_num_free(self):
+        """Get num free."""
+        return self.allocator.get_num_free()
diff --git a/lmdeploy/pytorch/strategies/base/model_inputs.py b/lmdeploy/pytorch/strategies/base/model_inputs.py
index 795220bd02..e7180bd926 100644
--- a/lmdeploy/pytorch/strategies/base/model_inputs.py
+++ b/lmdeploy/pytorch/strategies/base/model_inputs.py
@@ -26,6 +26,7 @@ def make_dummy_inputs(batch_size: int,
     block_offsets = torch.full((batch_size, 1), dummy_block_id, dtype=torch.long, device=device)
     num_ignored_history = torch.zeros((batch_size, ), dtype=torch.long, device=device)
     local_adapter_ids = torch.zeros((batch_size, ), dtype=torch.long, device=device)
+    state_offsets = torch.full((batch_size, ), -1, dtype=torch.long, device=device)
 
     return ModelInputs(
         input_ids=input_ids,
@@ -38,6 +39,7 @@ def make_dummy_inputs(batch_size: int,
         max_kv_seqlen=max_kv_seqlen,
         sum_kv_seqlen=num_tokens,
         local_adapter_ids=local_adapter_ids,
+        state_offsets=state_offsets,
     )
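
For readers skimming the patch, the sketch below (not part of the diff) walks through the slot-recycling idea behind the new `StateManager`/`StateAllocator` in `lmdeploy/pytorch/paging/state_manager.py`: each running sequence gets an integer slot into the recurrent-state cache used by Qwen3-Next's linear-attention layers, and the slot is returned to the pool when the sequence is evicted or removed. `ToyStateAllocator` and `DummySeq` are illustrative names invented for this example; the real code stores the slot on `SchedulerSequence.logical_state` and passes it to the model as `state_offsets`.

```python
# Standalone illustration of the state-slot bookkeeping; assumes nothing
# beyond numpy. DummySeq stands in for SchedulerSequence.
import numpy as np


class ToyStateAllocator:
    """Hand out integer state slots from a fixed pool, recycled LIFO."""

    def __init__(self, num_states: int):
        self.num_states = num_states
        self._free_states = np.arange(num_states, dtype=np.int64)
        self._free_count = num_states

    def allocate(self) -> int:
        # take the next free slot from the front of the free list
        if self._free_count == 0:
            raise RuntimeError('No free states.')
        slot = int(self._free_states[-self._free_count])
        self._free_count -= 1
        return slot

    def free(self, state_id: int):
        # push the slot back so the next allocation can reuse it
        self._free_count += 1
        self._free_states[-self._free_count] = state_id


class DummySeq:
    """Minimal stand-in for SchedulerSequence; -1 means no slot assigned."""

    def __init__(self):
        self.logical_state = -1


if __name__ == '__main__':
    allocator = ToyStateAllocator(num_states=2)
    seq_a, seq_b = DummySeq(), DummySeq()
    seq_a.logical_state = allocator.allocate()  # slot 0
    seq_b.logical_state = allocator.allocate()  # slot 1
    # evicting seq_a returns its slot, so a new sequence can reuse it
    allocator.free(seq_a.logical_state)
    seq_a.logical_state = -1
    seq_c = DummySeq()
    seq_c.logical_state = allocator.allocate()  # reuses slot 0
    print(seq_b.logical_state, seq_c.logical_state)  # -> 1 0
```

The pool size in the patch comes from `cache_config.num_state_caches`, which is why the eviction helper has to free both KV blocks and the state slot before a new SSM sequence can be scheduled.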