2 changes: 2 additions & 0 deletions benchmark/profile_throughput.py
@@ -327,6 +327,7 @@ def parse_args():
tb_group._group_actions.append(dtype_act)

ArgumentHelper.dp(tb_group)
ArgumentHelper.attn_cp_size(tb_group)
ArgumentHelper.model_format(tb_group, default='hf')
ArgumentHelper.num_tokens_per_iter(tb_group)
ArgumentHelper.max_prefill_iters(tb_group)
@@ -344,6 +345,7 @@ def main():
max_batch_size=args.concurrency // args.dp,
tp=args.tp,
dp=args.dp,
attn_cp_size=args.attn_cp_size,
cache_max_entry_count=args.cache_max_entry_count,
cache_block_seq_len=args.cache_block_seq_len,
model_format=args.model_format,
24 changes: 24 additions & 0 deletions docs/en/advance/context_parallel.md
@@ -0,0 +1,24 @@
# Context Parallel

When the memory on a single GPU is insufficient to deploy a model, it is often deployed using tensor parallelism (TP), which generally requires `num_key_value_heads` to be divisible by `TP`. To deploy with `TP > num_key_value_heads`, the kv-heads have to be duplicated to meet the divisibility requirement. However, this has two disadvantages:

1. The amount of available kv_cache is halved, which reduces the maximum supported session length.
2. The maximum inference batch size is reduced, leading to lower throughput.

To address this issue, the TurboMind inference backend supports setting `attn_dp_size`, which avoids creating copies of the kv-heads but introduces data imbalance. To eliminate the imbalance, TurboMind supports sequence parallelism, which allows the kv_cache to be stored interleaved across different cp_ranks. See the example below:

```
cp_size=2, prompt_len=5, generation_len=4
kv_cache stored on cp_rank0: 0, 2, 4, 6, 8
kv_cache stored on cp_rank1: 1, 3, 5, 7
```
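
The mapping is a simple round-robin over the global token index: the remainder of dividing by the cp size selects the owning cp_rank, and the quotient gives the slot in that rank's local kv_cache. Below is a minimal Python sketch of this layout; `cp_shard` is only an illustration of the bookkeeping, not an LMDeploy API.

```python
def cp_shard(num_tokens: int, cp_size: int):
    """Round-robin assignment: token ti -> (owning cp_rank, slot in that rank's kv_cache)."""
    shards = {rank: [] for rank in range(cp_size)}
    for ti in range(num_tokens):
        slot, rank = divmod(ti, cp_size)  # quotient = local slot, remainder = owning rank
        shards[rank].append((ti, slot))
    return shards

# prompt_len=5, generation_len=4 -> token indices 0..8
for rank, entries in cp_shard(9, cp_size=2).items():
    print(f"cp_rank{rank}:", [ti for ti, _ in entries])
# cp_rank0: [0, 2, 4, 6, 8]
# cp_rank1: [1, 3, 5, 7]
```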

## Usage

Take Intern-S1 / Qwen3-235B-A22B as examples: their `num_key_value_heads` is 4. To deploy with `TP=8` while avoiding duplication of the kv_cache, you can deploy as follows:

```
lmdeploy serve api_server internlm/Intern-S1 --tp 8 --attn-cp-size 2

lmdeploy serve api_server Qwen/Qwen3-235B-A22B --tp 8 --attn-cp-size 2
```
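
The same setting can also be applied programmatically through the engine config. The following is a minimal sketch, assuming the standard `lmdeploy` pipeline API and the `attn_cp_size` field added to `TurbomindEngineConfig` in this change; the model path and prompt are placeholders.

```python
# Sketch: enable context parallel (attn_cp_size=2) together with TP=8.
# Assumes 8 visible GPUs; sharding the kv_cache across 2 cp_ranks avoids
# duplicating the model's 4 kv-heads.
from lmdeploy import pipeline, TurbomindEngineConfig

backend_config = TurbomindEngineConfig(tp=8, attn_cp_size=2)
pipe = pipeline('internlm/Intern-S1', backend_config=backend_config)
print(pipe(['Briefly explain context parallelism.']))
```

As with the CLI flag, `attn_cp_size` should divide `tp`.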
1 change: 1 addition & 0 deletions docs/en/index.rst
@@ -103,6 +103,7 @@ Documentation
advance/pytorch_multinodes.md
advance/pytorch_profiling.md
advance/metrics.md
advance/context_parallel.md

.. toctree::
:maxdepth: 1
23 changes: 23 additions & 0 deletions docs/zh_cn/advance/context_parallel.md
@@ -0,0 +1,23 @@
# Context Parallel

When the memory on a single GPU is insufficient to deploy a model, the model is usually deployed with `TP`, which generally requires `num_key_value_heads` to be divisible by `TP`. To deploy with `TP > num_key_value_heads`, copies of the kv-heads have to be created to satisfy the divisibility requirement. However, this has two disadvantages:

1. The amount of available kv_cache is halved, which reduces the maximum supported inference length.
2. The maximum inference batch size is reduced, lowering throughput.

To address this issue, the TurboMind inference backend supports setting `attn_dp_size`, which avoids creating copies of the kv-heads but introduces data imbalance. To eliminate the imbalance, TurboMind supports sequence parallelism, which allows the kv_cache to be stored interleaved across different cp_ranks, for example:
```
cp_size=2, prompt_len=5, generation_len=4
kv_cache stored on cp_rank0: 0, 2, 4, 6, 8
kv_cache stored on cp_rank1: 1, 3, 5, 7
```

## Usage

Take `Intern-S1` / `Qwen3-235B-A22B` as examples: their `num_key_value_heads` is 4. To deploy with `TP=8` while avoiding duplication of the kv_cache, you can deploy as follows:

```
lmdeploy serve api_server internlm/Intern-S1 --tp 8 --attn-cp-size 2

lmdeploy serve api_server Qwen/Qwen3-235B-A22B --tp 8 --attn-cp-size 2
```
1 change: 1 addition & 0 deletions docs/zh_cn/index.rst
@@ -104,6 +104,7 @@ LMDeploy 工具箱提供以下核心功能:
advance/pytorch_multinodes.md
advance/pytorch_profiling.md
advance/metrics.md
advance/context_parallel.md

.. toctree::
:maxdepth: 1
1 change: 1 addition & 0 deletions lmdeploy/cli/cli.py
@@ -76,6 +76,7 @@ def add_parser_chat():
ArgumentHelper.model_format(tb_group)
ArgumentHelper.rope_scaling_factor(tb_group)
ArgumentHelper.communicator(tb_group)
ArgumentHelper.attn_cp_size(tb_group)

@staticmethod
def add_parser_checkenv():
2 changes: 2 additions & 0 deletions lmdeploy/cli/serve.py
@@ -135,6 +135,7 @@ def add_parser_api_server():
tb_group._group_actions.append(model_format)
tb_group._group_actions.append(hf_overrides)
tb_group._group_actions.append(enable_metrics)
ArgumentHelper.attn_cp_size(tb_group)
ArgumentHelper.rope_scaling_factor(tb_group)
ArgumentHelper.num_tokens_per_iter(tb_group)
ArgumentHelper.max_prefill_iters(tb_group)
@@ -232,6 +233,7 @@ def api_server(args):
from lmdeploy.messages import TurbomindEngineConfig
backend_config = TurbomindEngineConfig(dtype=args.dtype,
tp=args.tp,
attn_cp_size=args.attn_cp_size,
max_batch_size=max_batch_size,
session_len=args.session_len,
model_format=args.model_format,
10 changes: 10 additions & 0 deletions lmdeploy/cli/utils.py
@@ -188,6 +188,16 @@ def ep(parser):
default=1,
help='expert parallelism. dp is required when pytorch engine is used.')

@staticmethod
def attn_cp_size(parser):
Collaborator: Can we use cp instead of attn_cp_size?

"""Add argument attn_cp_size to parser."""

return parser.add_argument(
'--attn-cp-size',
type=int,
default=1,
help='context parallelism size in attention for turbomind backend. Should divide tp.')

@staticmethod
def dp_rank(parser):
"""Add argument dp_rank to parser."""
1 change: 1 addition & 0 deletions lmdeploy/messages.py
@@ -237,6 +237,7 @@ class TurbomindEngineConfig:
dp: int = 1
device_num: int = None
attn_tp_size: int = None
attn_cp_size: int = None
Collaborator: How about we add cp like dp? attn_cp_size can be used internally

attn_dp_size: int = None
mlp_tp_size: int = None
mlp_dp_size: int = None
9 changes: 6 additions & 3 deletions lmdeploy/turbomind/turbomind.py
@@ -87,6 +87,7 @@ def complete_parallel_config(cfg: TurbomindEngineConfig):

def update_parallel_config(cfg: TurbomindEngineConfig):
if not complete_parallel_config(cfg):
attn_cp_size = cfg.attn_cp_size or 1
total = cfg.dp * cfg.tp
if not cfg.device_num:
count = torch.cuda.device_count()
@@ -100,11 +101,12 @@ def update_parallel_config(cfg: TurbomindEngineConfig):
inner_tp_size = cfg.tp // mlp_tp_size
cfg.outer_dp_size = cfg.dp // attn_dp_size
cfg.attn_dp_size = attn_dp_size
cfg.attn_tp_size = inner_tp_size
cfg.attn_tp_size = inner_tp_size // attn_cp_size
cfg.attn_cp_size = attn_cp_size
cfg.mlp_dp_size = 1
cfg.mlp_tp_size = mlp_tp_size * inner_tp_size
assert cfg.attn_dp_size * cfg.attn_tp_size == cfg.mlp_dp_size * cfg.mlp_tp_size
assert cfg.attn_dp_size * cfg.attn_tp_size * cfg.outer_dp_size == cfg.device_num
assert cfg.attn_dp_size * cfg.attn_tp_size * cfg.attn_cp_size == cfg.mlp_dp_size * cfg.mlp_tp_size
assert cfg.attn_dp_size * cfg.attn_tp_size * cfg.attn_cp_size * cfg.outer_dp_size == cfg.device_num
cfg.devices = cfg.devices or list(range(cfg.device_num))


@@ -272,6 +274,7 @@ def _from_hf(self, model_path: str, engine_config: TurbomindEngineConfig):

self._postprocess_config(tm_model.tm_config, engine_config)

print(yaml.safe_dump(self.config_dict))
model_comm = _tm.AbstractTransformerModel.create_llama_model(model_dir='',
config=yaml.safe_dump(self.config_dict),
weight_type=self.config.model_config.weight_type)
5 changes: 3 additions & 2 deletions src/turbomind/comm/nccl/nccl.cu
@@ -222,10 +222,11 @@ public:

int Split(int color, int key, int group) override
{
auto split_fn = TM_CHECK_NOTNULL(nccl_apis().ncclCommSplit);
// auto split_fn = TM_CHECK_NOTNULL(nccl_apis().ncclCommSplit);

ncclComm_t comm{};
NCCLCHECK(split_fn(groups_.at(group), color, key, &comm, nullptr));
// NCCLCHECK(split_fn(groups_.at(group), color, key, &comm, nullptr));
NCCLCHECK(ncclCommSplit(groups_.at(group), color, key, &comm, nullptr));

int index = groups_.size();
groups_.push_back(comm);
2 changes: 1 addition & 1 deletion src/turbomind/kernels/attention/CMakeLists.txt
@@ -45,7 +45,7 @@ set_property(TARGET attention PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_compile_options(attention PRIVATE -O3
$<$<COMPILE_LANGUAGE:CUDA>:-use_fast_math --expt-relaxed-constexpr>)

target_link_libraries(attention PRIVATE nvidia::cutlass::cutlass)

if (BUILD_TEST)
target_compile_options(attention PRIVATE
13 changes: 13 additions & 0 deletions src/turbomind/kernels/attention/attention_params.h
@@ -2,6 +2,7 @@

#pragma once

#include "cutlass/fast_math.h"
#include <cstdint>
#include <cuda_runtime.h>

@@ -23,6 +24,8 @@ struct BlockIteratorParams {
int block_len;
};

typedef void (*cp_post_fn)(void* context, int split_cnt);

/// TODO: Rename to attention::Param
template<typename T>
struct AttentionParams {
@@ -79,6 +82,16 @@ struct AttentionParams {
float* partial_L;
int* locks;

// context parallel
int cp_rank{0};
int cp_size{1};
cutlass::FastDivmod cp_divmod{1};
Collaborator: cutlass::FastDivmod has a copy of the divisor (casting it to int implicitly). Thus cp_size is not needed here. And cp_divmod may be renamed to cp_size for simplicity

int cp_q_offset{0}; // decode offset
float* cp_ML{nullptr}; // cp, q, h, 2
float* cp_k_ML{nullptr}; // q, h, k, 2
cp_post_fn cp_fn{nullptr};
void* cp_fn_ctx{nullptr};

int arch;
cudaStream_t stream;

9 changes: 7 additions & 2 deletions src/turbomind/kernels/attention/attention_template.h
@@ -45,7 +45,8 @@ void invokeAttention(const typename Kernel::ParamType& params)
return int2{sm_count, max_active_ctas};
}();

const int tile_count = cdiv(std::min(params.max_k_len, params.window_size), Kernel::CTA_S);
const int max_cp_k_len = (params.max_k_len + params.cp_size - 1) / params.cp_size;
const int tile_count = cdiv(std::min(max_cp_k_len, params.window_size), Kernel::CTA_S);
const int max_split_count = std::min(params.max_split_k, tile_count);

typename Kernel::CtaMap cta_map{
@@ -80,7 +81,11 @@ void invokeAttention(const typename Kernel::ParamType& params)
std::abort();
}

if (split_cnt > 1 && Kernel::need_separate_reduce(split_cnt)) {
if (params.cp_fn) {
int split_k = Kernel::need_separate_reduce(split_cnt) ? split_cnt : 1;
params.cp_fn(params.cp_fn_ctx, split_k);
}
else if (split_cnt > 1 && Kernel::need_separate_reduce(split_cnt)) {
attention::invokeReduce<Kernel::kHeadDim>(params.out,
params.partial_M,
params.partial_L,
71 changes: 60 additions & 11 deletions src/turbomind/kernels/attention/attention_universal.h
@@ -256,6 +256,9 @@ struct AttentionUniversal {
const int qi = offset.y / CTA_H;
const int ti = history_len;

int cp_quo, cp_rem;
params.cp_divmod(cp_quo, cp_rem, ti);

Array<T, 2> param_K[1];
Array<T, 2> param_V[1];

@@ -276,7 +279,10 @@
}

iterator.block_head_.with(
iterator.block_ptrs_, ti, [&](auto k_cache, auto v_cache, T* k_param, T* v_param) {
iterator.block_ptrs_, cp_quo, [&](auto k_cache, auto v_cache, T* k_param, T* v_param) {
if (cp_rem != params.cp_rank) {
return;
}
PRAGMA_UNROLL
for (int c = 0; c < ITER_C; ++c) {
const int di = offset.x + c * Map::kDeltaC;
@@ -371,11 +377,18 @@
const int context_len = params.cu_k_len[batch_idx + 1] - params.cu_k_len[batch_idx];
const int history_len = context_len - input_len;

const int last_K = history_len + min(query_idx + CTA_Q, input_len);
const int last_K_tile = (last_K - 1) / CTA_S + 1; // past-the-end index to past-the-end tile index conversion
auto get_cp_len = [&](int length) -> int {
int cp_quo, cp_rem;
params.cp_divmod(cp_quo, cp_rem, length);
return (cp_quo + (cp_rem > params.cp_rank ? 1 : 0));
};

const int last_K = history_len + min(query_idx + CTA_Q, input_len);
const int last_K_tile =
(get_cp_len(last_K) - 1) / CTA_S + 1; // past-the-end index to past-the-end tile index conversion

const int first_K = max(history_len + query_idx - (params.window_size - 1), 0);
const int first_K_tile = first_K / CTA_S;
const int first_K_tile = get_cp_len(first_K) / CTA_S;

const int tile_count = last_K_tile - first_K_tile;

@@ -417,7 +430,7 @@ struct AttentionUniversal {
const int offset_K = (first_K_tile + iter_end - 1) * CTA_S;

// This is for avoiding OOB access only
const int max_K = min(context_len, (first_K_tile + iter_end) * CTA_S);
const int max_K = min(get_cp_len(context_len), (first_K_tile + iter_end) * CTA_S);

int tile_iter = iter_end - iter_begin;

@@ -430,6 +443,15 @@
// -> x * CTA_S >= offset_Q + CTA_Q - offset_K + tile_iter * CTA_S - w
int mask_iter_front = cdiv(max(0, offset_Q + CTA_Q - offset_K + tile_iter * CTA_S - params.window_size), CTA_S);

if (params.cp_size > 1) {
mask_iter_back =
cdiv(max(0, params.cp_size * (offset_K + CTA_S) - offset_Q + params.cp_rank), params.cp_size * CTA_S);
mask_iter_front = cdiv(max(0,
offset_Q + CTA_Q - params.window_size - params.cp_rank
- params.cp_size * (offset_K - tile_iter * CTA_S)),
params.cp_size * CTA_S);
}

#if 0
if (threadIdx.x == 0) {
printf(
@@ -453,6 +475,7 @@
cache_iter.SetTile(first_K_tile + iter_end - 1);

Mainloop mainloop;
mainloop.SetCpInfo(params.cp_size, params.cp_rank);
mainloop(frag_Q,
cache_iter,
frag_O,
@@ -491,12 +514,12 @@
}
}

if (iter_begin == 0 && iter_end == tile_count) {
if (iter_begin == 0 && iter_end == tile_count && params.cp_size == 1) {
StoreO(frag_O, frag_L, qi_begin, qi_end, head_idx, params, storage);
}
else {
StorePartial(frag_O, frag_M, frag_L, qi_begin, qi_end, head_idx, split_idx, params, storage);
if (!separate_reduce) {
StorePartial(frag_O, frag_M, frag_L, split_cnt, qi_begin, qi_end, head_idx, split_idx, params, storage);
if (!separate_reduce && split_cnt > 1) {
Reduce(qi_begin, head_idx, split_idx, iter_end == tile_count, params, cta_map, smem_buf);
}
}
@@ -527,6 +550,9 @@
params.partial_M,
params.partial_L,
params.partial_O,
params.cp_ML,
params.cp_k_ML,
params.cp_q_offset,
qi_begin,
head_idx,
params.num_heads,
@@ -583,6 +609,7 @@ struct AttentionUniversal {
__device__ void StorePartial(FragO& frag_O,
FragM& frag_M,
FragL& frag_L,
int split_cnt,
int qi_begin,
int qi_end,
int head_idx,
@@ -598,15 +625,37 @@

Impl::StoreO<false>(frag_O, frag_L, storage, [&](int hi, int qi, int di, const auto& vec) {
if (qi_begin + qi < qi_end && check_h(hi)) {
Store(&params.partial_O[get_index(hi, qi) * kHeadDim + di], vec);
if (split_cnt > 1) { // decode
Store(&params.partial_O[get_index(hi, qi) * kHeadDim + di], vec);
}
if (params.cp_size > 1 && split_cnt == 1) {
const int index = ((qi_begin + qi) * params.num_heads + (head_idx + hi)) * kHeadDim + di;
Store(&params.out[index], cast<T>(vec));
}
}
});

Impl::ForeachML(frag_M, frag_L, [&](int hi, int qi, int ri, float M, float L) {
const int index = get_index(hi, qi);
if (qi_begin + qi < qi_end && ri == 0 && check_h(hi)) {
params.partial_M[index] = M;
params.partial_L[index] = L;
if (split_cnt > 1) { // decode
params.partial_M[index] = M;
params.partial_L[index] = L;
}

auto save_cp_stats = [&](int max_split_k, int split_idx, float* ml, float M, float L) {
const int q = qi_begin + qi - params.cp_q_offset;
const int index = (q * params.num_heads + (head_idx + hi)) * max_split_k + split_idx;
ml[index * 2] = M;
ml[index * 2 + 1] = L;
};

if (params.cp_size > 1) {
if (split_cnt == 1) {
save_cp_stats(1, 0, params.cp_ML, M, L);
}
save_cp_stats(params.max_split_k, split_idx, params.cp_k_ML, M, L);
}
}
});
}