# support context parallel #3951
Changes from all commits

@@ -0,0 +1,24 @@

# Context Parallel

When the memory on a single GPU is insufficient to deploy a model, it is often deployed with tensor parallelism (TP), which generally requires `num_key_value_heads` to be divisible by `TP`. To deploy with `TP > num_key_value_heads`, the kv-heads must be duplicated to meet the divisibility requirement. However, this has two disadvantages:

1. The amount of available kv_cache is halved, which reduces the maximum supported session length.
2. The maximum inference batch size is reduced, leading to lower throughput.

To address this issue, the TurboMind inference backend supports setting `attn_dp_size`, which avoids creating copies of the kv-heads but introduces data imbalance. To eliminate the imbalance, TurboMind supports sequence parallelism, which allows the kv_cache to be stored interleaved on different cp_ranks. See the example below:

```
cp_size=2, prompt_len=5, generation_len=4
kv_cache stored on cp_rank0: 0, 2, 4, 6, 8
kv_cache stored on cp_rank1: 1, 3, 5, 7
```

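To make the interleaving concrete, here is a minimal sketch of the round-robin layout (illustrative only; the helper name `interleave_kv_cache` is made up and this is not TurboMind's implementation):

```python
# Minimal sketch of the round-robin kv_cache layout described above.
def interleave_kv_cache(total_len: int, cp_size: int) -> dict[int, list[int]]:
    """Assign each token position to the cp_rank that stores its kv_cache."""
    layout = {rank: [] for rank in range(cp_size)}
    for pos in range(total_len):
        layout[pos % cp_size].append(pos)
    return layout

# prompt_len=5 plus generation_len=4 gives 9 token positions in total.
print(interleave_kv_cache(5 + 4, cp_size=2))
# {0: [0, 2, 4, 6, 8], 1: [1, 3, 5, 7]}
```
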
## Usage

Taking Intern-S1 / Qwen3-235B-A22B as examples, whose `num_key_value_heads` is 4: to deploy with `TP=8` while avoiding duplication of the kv_cache, you can deploy as follows:

```
lmdeploy serve api_server internlm/Intern-S1 --tp 8 --attn-cp-size 2

lmdeploy serve api_server Qwen/Qwen3-235B-A22B --tp 8 --attn-cp-size 2
```
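
Presumably the same deployment can also be configured through the Python API, since this PR adds an `attn_cp_size` field to `TurbomindEngineConfig` (see the config diff below). A hedged sketch, treating the new field's usage as an assumption:

```python
from lmdeploy import pipeline, TurbomindEngineConfig

# `tp` is an existing TurbomindEngineConfig field; `attn_cp_size` is the
# field added by this PR, so treat this usage as an assumption.
backend_config = TurbomindEngineConfig(tp=8, attn_cp_size=2)
pipe = pipeline('internlm/Intern-S1', backend_config=backend_config)
print(pipe(['Introduce context parallelism in one sentence.']))
```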

@@ -0,0 +1,23 @@

# Sequence Parallelism

When the memory of a single GPU is insufficient to deploy a model, it is usually deployed with `TP`, which generally requires `num_key_value_heads` to be divisible by `TP`. To deploy with `TP > num_key_value_heads`, copies of the kv-heads must be created to satisfy the divisibility requirement. However, this has two drawbacks:

1. The amount of available kv_cache is halved, which reduces the maximum inference length of a request.
2. The maximum inference batch size is lowered, reducing throughput.

To solve this problem, the TurboMind inference backend supports setting `attn_dp_size`, which avoids creating copies of the kv-heads, but this introduces data imbalance. To eliminate the imbalance, TurboMind supports sequence parallelism, which allows the kv_cache to be stored interleaved on different cp_ranks, for example:

```
cp_size=2, prompt_len=5, generation_len=4
kv_cache stored on cp_rank0: 0, 2, 4, 6, 8
kv_cache stored on cp_rank1: 1, 3, 5, 7
```

## Usage

Taking `Intern-S1` / `Qwen3-235B-A22B` as examples, their `num_key_value_heads` is 4. To deploy with `TP=8` while avoiding duplication of the kv_cache, you can deploy as follows:

```
lmdeploy serve api_server internlm/Intern-S1 --tp 8 --attn-cp-size 2

lmdeploy serve api_server Qwen/Qwen3-235B-A22B --tp 8 --attn-cp-size 2
```

@@ -237,6 +237,7 @@ class TurbomindEngineConfig:
```python
    dp: int = 1
    device_num: int = None
    attn_tp_size: int = None
    attn_cp_size: int = None
    attn_dp_size: int = None
    mlp_tp_size: int = None
    mlp_dp_size: int = None
```

Review comment on `attn_cp_size: int = None`: How about we add
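
Reading the documentation and the config fields together, `attn_cp_size` presumably splits the attention workload so that the effective attention TP size becomes `tp / attn_cp_size`, which is what lets `TP=8` serve a model with 4 kv-heads without duplication. A sketch of that arithmetic only (an assumption, not validation code from this PR):

```python
# Presumed relationship between the sizes; not code from this PR.
def effective_attn_tp(tp: int, attn_cp_size: int) -> int:
    assert tp % attn_cp_size == 0, "tp should be divisible by attn_cp_size"
    return tp // attn_cp_size

num_key_value_heads = 4                  # Intern-S1 / Qwen3-235B-A22B
attn_tp = effective_attn_tp(tp=8, attn_cp_size=2)
# Each attention TP rank now maps to a whole kv-head, so no duplication is needed.
assert num_key_value_heads % attn_tp == 0
print(attn_tp)  # 4
```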

@@ -2,6 +2,7 @@
```cpp
#pragma once

#include "cutlass/fast_math.h"
#include <cstdint>
#include <cuda_runtime.h>
```

@@ -23,6 +24,8 @@ struct BlockIteratorParams {
```cpp
    int block_len;
};

typedef void (*cp_post_fn)(void* context, int split_cnt);

/// TODO: Rename to attention::Param
template<typename T>
struct AttentionParams {
```

@@ -79,6 +82,16 @@ struct AttentionParams {
```cpp
    float* partial_L;
    int*   locks;

    // context parallel
    int                 cp_rank{0};
    int                 cp_size{1};
    cutlass::FastDivmod cp_divmod{1};

    int        cp_q_offset{0};    // decode offset
    float*     cp_ML{nullptr};    // cp, q, h, 2
    float*     cp_k_ML{nullptr};  // q, h, k, 2
    cp_post_fn cp_fn{nullptr};
    void*      cp_fn_ctx{nullptr};

    int          arch;
    cudaStream_t stream;
```
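
The `cp_ML` / `cp_k_ML` buffers suggest that each rank keeps per-query max/sum (M, L) statistics so that partial attention results computed over its interleaved slice of the kv_cache can be merged afterwards, which is also the kind of step a post-processing hook like `cp_post_fn` could drive. A hypothetical numpy sketch of such a merge; the (M, L) interpretation and every name below are assumptions, not code from this PR:

```python
import numpy as np

def partial_attention(q, k, v):
    """Attention over one rank's slice of the kv_cache, plus its (M, L) stats."""
    s = q @ k.T / np.sqrt(q.shape[-1])   # attention scores for this slice
    m = s.max(axis=-1, keepdims=True)    # per-query row max (M)
    p = np.exp(s - m)
    l = p.sum(axis=-1, keepdims=True)    # per-query row sum (L)
    return (p @ v) / l, m, l

def merge_cp(outputs, ms, ls):
    """Combine per-rank partial outputs using their (M, L) statistics."""
    m_all = np.maximum.reduce(list(ms))                    # global row max
    scales = [np.exp(m - m_all) * l for m, l in zip(ms, ls)]
    l_all = sum(scales)                                    # global row sum
    return sum(o * s for o, s in zip(outputs, scales)) / l_all

rng = np.random.default_rng(0)
q = rng.standard_normal((3, 8))   # 3 queries, head_dim 8
k = rng.standard_normal((9, 8))   # 9 cached token positions
v = rng.standard_normal((9, 8))

# Interleave the cached tokens across two cp_ranks, as in the docs' example.
parts = [partial_attention(q, k[r::2], v[r::2]) for r in range(2)]
merged = merge_cp(*zip(*parts))

reference, _, _ = partial_attention(q, k, v)   # single-rank reference
assert np.allclose(merged, reference)
```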
Review comment: Can we use `cp` instead of `attn_cp_size`?