
Commit e8c8e7a

PytorchEngine multi-node support v2 (#3147)
* better dist context
* can not exit
* multinode support
* better exception
* refactor
* fix local rank
* replace group
* fix dist
* remove useless code
* remove finish flag
* refactor engine and model agent
* uni executor
* wip
* tp
* fix
* less async
* circle buf
* event per block
* fast mp
* fix error handler
* remove safe wait
* context in model agent
* fix on stop
* check before init
* fix tp close
* ray ver0
* fix close
* fix remote code
* optimize ray
* better checker and logger
* pack tensor
* auto check dist
* fix mp gloo
* add timer tools
* better scheduler
* fix mp hang
* fix mp
* fix chat
* less output
* merge main
* optimize ray get output
* remove nsight runtime env
* dag
* optimize mp & lint
* optimize mp
* add base workerwrapper
* fix gather, update flags
* better return mask
* add choice
* enable mp,ray with worldsize=1
* fix mp exit
* fix mp vlm
* chat exit
* add docs
* lint
* doc
* dp check
* fix blocked fp8 moe
* remove mask
* fix chat stopwords
* refactor chat
1 parent 83eed6e commit e8c8e7a


48 files changed: +2789 −1132 lines changed

benchmark/profile_throughput.py

Lines changed: 6 additions & 0 deletions
@@ -190,6 +190,11 @@ def parse_args():
     parser.add_argument('--use-uvloop', action='store_true')
     parser.add_argument('--csv', type=str, help='Where to save the result.', default='./profile_throughput.csv')
     parser.add_argument('--seed', type=int, default=0, help='Seed used in sampling prompts from dataset')
+    parser.add_argument('--distributed-executor-backend',
+                        type=str,
+                        default=None,
+                        choices=['uni', 'mp', 'ray'],
+                        help='backend of the distributed executor')
     # other args
     ArgumentHelper.top_p(parser)
     ArgumentHelper.temperature(parser)
@@ -256,6 +261,7 @@ def main():
         enable_prefix_caching=args.enable_prefix_caching,
         quant_policy=args.quant_policy,
         dtype=args.dtype,
+        distributed_executor_backend=args.distributed_executor_backend,
     )
 
     if args.use_uvloop:

docs/en/advance/pytorch_multinodes.md

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
# PyTorchEngine Multi-Node Deployment Guide

To support larger-scale model deployment requirements, PyTorchEngine provides multi-node deployment support. Below are the detailed steps for deploying a `tp=16` model across two 8-GPU nodes.

## 1. Create Docker Containers (Optional)

To ensure consistency across the cluster environment, it is recommended to use Docker to set up the cluster. Create a container on each node as follows:

```bash
docker run -it \
    --network host \
    -v $MODEL_PATH:$CONTAINER_MODEL_PATH \
    openmmlab/lmdeploy:latest
```

> \[!IMPORTANT\]
> Ensure that the model is placed in the same directory on all node containers.

## 2. Set Up the Cluster Using Ray

### 2.1 Start the Head Node

Select one node as the **head node** and run the following command in its container:

```bash
ray start --head --port=$DRIVER_PORT
```

### 2.2 Join the Cluster

On the other nodes, run the following command in their containers to join the cluster created by the head node:

```bash
ray start --address=$DRIVER_NODE_ADDR:$DRIVER_PORT
```

Run `ray status` on the head node to verify that all nodes have successfully joined the cluster.

> \[!IMPORTANT\]
> Ensure that `DRIVER_NODE_ADDR` is the address of the head node and `DRIVER_PORT` matches the port number used during the head node initialization.
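
You can also verify the cluster from Python before starting the engine. The following is a minimal sketch (assuming Ray is importable inside the container, which it should be since `ray start` works there) that attaches to the running cluster and prints the aggregated resources; for two 8-GPU nodes you would expect a GPU count of 16:

```python
# Minimal sketch: attach to the already-running Ray cluster and inspect it.
import ray

ray.init(address='auto')          # connect to the cluster started with `ray start`
print(ray.cluster_resources())    # e.g. {'GPU': 16.0, 'CPU': ..., ...} for two 8-GPU nodes
```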

## 3. Use LMDeploy Interfaces

In the head node's container, you can use all functionalities of PyTorchEngine as usual.

### 3.1 Start the Server

```bash
lmdeploy serve api_server \
    $CONTAINER_MODEL_PATH \
    --backend pytorch \
    --tp 16
```
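
Once the server is up, any OpenAI-compatible client can query it. Below is a minimal sketch using the `openai` Python package; the head-node address is a placeholder and the port assumes the api_server default (23333), so adjust both if you pass a different server port:

```python
# Minimal sketch: query the api_server through its OpenAI-compatible endpoint.
from openai import OpenAI

# <DRIVER_NODE_ADDR> is the head node address; 23333 is the assumed default port.
client = OpenAI(base_url='http://<DRIVER_NODE_ADDR>:23333/v1', api_key='dummy')
model_name = client.models.list().data[0].id   # name the server registered for the model
response = client.chat.completions.create(
    model=model_name,
    messages=[{'role': 'user', 'content': 'Hakuna Matata'}],
)
print(response.choices[0].message.content)
```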

### 3.2 Use the Pipeline

```python
from lmdeploy import pipeline, PytorchEngineConfig

if __name__ == '__main__':
    model_path = '/path/to/model'
    backend_config = PytorchEngineConfig(tp=16)
    with pipeline(model_path, backend_config=backend_config) as pipe:
        outputs = pipe('Hakuna Matata')
```

> \[!NOTE\]
> PyTorchEngine will automatically choose the appropriate launch method (single-node/multi-node) based on the `tp` parameter and the number of devices available in the cluster. If you want to enforce the use of the Ray cluster, you can configure `distributed_executor_backend='ray'` in `PytorchEngineConfig` or use the environment variable `LMDEPLOY_EXECUTOR_BACKEND=ray`.
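
For reference, a minimal sketch of both ways to force the Ray executor (they are alternatives, not both required; the model path is a placeholder, and the environment variable presumably needs to be set before the engine is created):

```python
import os

from lmdeploy import pipeline, PytorchEngineConfig

# Option 1: environment variable (set before building the pipeline).
os.environ['LMDEPLOY_EXECUTOR_BACKEND'] = 'ray'

# Option 2: explicit config field.
backend_config = PytorchEngineConfig(tp=16, distributed_executor_backend='ray')

if __name__ == '__main__':
    with pipeline('/path/to/model', backend_config=backend_config) as pipe:
        print(pipe('Hakuna Matata'))
```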
______________________________________________________________________

By following the steps above, you can successfully deploy PyTorchEngine in a multi-node environment and leverage the Ray cluster for distributed computing.

> \[!WARNING\]
> To achieve better performance, we recommend that users configure a higher-quality network environment (such as [InfiniBand](https://en.wikipedia.org/wiki/InfiniBand)) to improve engine efficiency.

docs/en/index.rst

Lines changed: 1 addition & 0 deletions
@@ -101,6 +101,7 @@ Documentation
    advance/chat_template.md
    advance/debug_turbomind.md
    advance/structed_output.md
+   advance/pytorch_multinodes.md
 
 .. toctree::
    :maxdepth: 1
docs/zh_cn/advance/pytorch_multinodes.md

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
# PyTorchEngine Multi-Node Deployment Guide

To support larger-scale model deployment requirements, PyTorchEngine provides multi-node deployment support. Below are the detailed steps for deploying a tp=16 model across two 8-GPU nodes.

## 1. Create Docker Containers (Optional)

To ensure consistency across the cluster environment, it is recommended to use Docker to set up the cluster. Create a container on each node:

```bash
docker run -it \
    --network host \
    -v $MODEL_PATH:$CONTAINER_MODEL_PATH \
    openmmlab/lmdeploy:latest
```

> \[!IMPORTANT\]
> Make sure the model is placed in the same directory in every node's container.

## 2. Set Up the Cluster Using Ray

### 2.1 Start the Head Node

Select one of the nodes as the **head node** and run the following command in its container:

```bash
ray start --head --port=$DRIVER_PORT
```

### 2.2 Join the Cluster

In the containers of the other nodes, use the following command to join the cluster created by the head node:

```bash
ray start --address=$DRIVER_NODE_ADDR:$DRIVER_PORT
```

Afterwards, run `ray status` on the head node to check the cluster state and make sure all nodes have successfully joined the cluster.

> \[!IMPORTANT\]
> Make sure `DRIVER_NODE_ADDR` is the address of the head node and that `DRIVER_PORT` matches the port used when initializing the head node.

## 3. Use LMDeploy Interfaces

In the head node's container, you can use all functionalities of PyTorchEngine as usual.

### 3.1 Start the API Server

```bash
lmdeploy serve api_server \
    $CONTAINER_MODEL_PATH \
    --backend pytorch \
    --tp 16
```

### 3.2 Use the pipeline Interface

```python
from lmdeploy import pipeline, PytorchEngineConfig

if __name__ == '__main__':
    model_path = '/path/to/model'
    backend_config = PytorchEngineConfig(tp=16)
    with pipeline(model_path, backend_config=backend_config) as pipe:
        outputs = pipe('Hakuna Matata')
```

> \[!NOTE\]
> PytorchEngine automatically chooses the appropriate launch mode (single-node/multi-node) based on the tp value and the number of devices in the cluster. To force the use of the Ray cluster, set `distributed_executor_backend='ray'` in `PytorchEngineConfig` or use the environment variable `LMDEPLOY_EXECUTOR_BACKEND=ray`.

By following the steps above, you can successfully deploy PyTorchEngine in a multi-node environment and leverage the Ray cluster for distributed computing.

> \[!WARNING\]
> To achieve better performance, we recommend that users configure a higher-quality network environment (for example, [InfiniBand](https://en.wikipedia.org/wiki/InfiniBand)) to improve engine efficiency.

docs/zh_cn/advance/pytorch_multithread.md

Lines changed: 1 addition & 24 deletions
@@ -1,29 +1,6 @@
 # PyTorchEngine Multi-Threaded Inference
 
-Since [PR2907](https://github.com/InternLM/lmdeploy/pull/2907), we have removed the thread_safe mode of PytorchEngine so that the engine can run more efficiently. We encourage users to use the **service interface** or **coroutines** wherever possible to achieve high concurrency, for example:
-
-```python
-import asyncio
-from lmdeploy import pipeline, PytorchEngineConfig
-
-event_loop = asyncio.new_event_loop()
-asyncio.set_event_loop(event_loop)
-
-model_path = 'Llama-3.2-1B-Instruct'
-pipe = pipeline(model_path, backend_config=PytorchEngineConfig())
-
-async def _gather_output():
-    tasks = [
-        pipe.async_batch_infer('Hakuna Matata'),
-        pipe.async_batch_infer('giraffes are heartless creatures'),
-    ]
-    return await asyncio.gather(*tasks)
-
-output = asyncio.run(_gather_output())
-print(output[0].text)
-print(output[1].text)
-```
-
+Since [PR2907](https://github.com/InternLM/lmdeploy/pull/2907), we have removed the thread_safe mode of PytorchEngine so that the engine can run more efficiently. We encourage users to use the **service interface** or **coroutines** wherever possible to achieve high concurrency;
 if you really do need multi-threaded inference, you can write a simple wrapper to achieve a similar effect.
 
 ```python

docs/zh_cn/index.rst

Lines changed: 1 addition & 0 deletions
@@ -102,6 +102,7 @@ LMDeploy toolbox provides the following core features:
    advance/chat_template.md
    advance/debug_turbomind.md
    advance/structed_output.md
+   advance/pytorch_multinodes.md
 
 .. toctree::
    :maxdepth: 1

lmdeploy/messages.py

Lines changed: 3 additions & 0 deletions
@@ -278,6 +278,8 @@ class PytorchEngineConfig:
             If unspecified, will use the default version.
         quant_policy (int): default to 0. When k/v is quantized into 4 or 8
             bit, set it to 4 or 8, respectively
+        distributed_executor_backend (str): backend of the distributed
+            executor, options: ['uni', 'mp', 'ray']
     """
     dtype: str = 'auto'
     tp: int = 1
@@ -298,6 +300,7 @@ class PytorchEngineConfig:
     download_dir: str = None
     revision: str = None
     quant_policy: Literal[0, 4, 8] = 0
+    distributed_executor_backend: str = None
 
     def __post_init__(self):
         """Check input validation."""

lmdeploy/pytorch/backends/base.py

Lines changed: 10 additions & 0 deletions
@@ -90,3 +90,13 @@ def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_
         """build graph runner."""
         from .graph_runner import GraphRunner
         return GraphRunner(model, model_config, cache_config, backend_config, device)
+
+    @staticmethod
+    def device_count():
+        """get num available devices."""
+        return None
+
+    @staticmethod
+    def support_ray():
+        """support ray."""
+        return False
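
The base class returns conservative defaults, so a device-specific backend is expected to override both hooks. A hypothetical CUDA-flavoured override (illustrative only; the class name and whether the real CUDA backend implements it exactly this way are assumptions) might look like:

```python
import torch


class CudaDeviceHooks:
    """Illustrative override of the new base-class hooks for CUDA devices."""

    @staticmethod
    def device_count():
        """Number of CUDA devices visible on this node."""
        return torch.cuda.device_count()

    @staticmethod
    def support_ray():
        """CUDA workers can be scheduled through a Ray cluster."""
        return True
```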

lmdeploy/pytorch/backends/cuda/awq_modules.py

Lines changed: 2 additions & 1 deletion
@@ -2,7 +2,8 @@
 from typing import Optional
 
 import torch
-from torch import distributed as dist
+
+import lmdeploy.pytorch.distributed as dist
 
 from ..awq_modules import LinearW4A16Builder, LinearW4A16Impl
 
lmdeploy/pytorch/backends/cuda/blockedf8_modules.py

Lines changed: 1 addition & 2 deletions
@@ -1,10 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-
 from typing import Optional
 
 import torch
-import torch.distributed as dist
 
+import lmdeploy.pytorch.distributed as dist
 from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import blocked_gemm_fp8, quant_fp8
 
 from ..blockedf8_modules import LinearBlockedF8Builder, LinearBlockedF8Impl
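
Both of the import changes above swap `torch.distributed` for `lmdeploy.pytorch.distributed` under the same `dist` alias. A reasonable reading (not verified against the rest of this commit) is that the new module exposes a `torch.distributed`-compatible surface bound to the engine's own process groups, so existing call sites can keep working unchanged, along the lines of:

```python
# Sketch of the unchanged call-site pattern; assumes lmdeploy.pytorch.distributed
# mirrors the torch.distributed functions used here.
import torch

import lmdeploy.pytorch.distributed as dist  # previously: from torch import distributed as dist


def reduce_partial_output(out: torch.Tensor) -> torch.Tensor:
    """Sum tensor-parallel partial results across ranks, as TP linear layers typically do."""
    if dist.is_initialized():   # hypothetical mirror of torch.distributed.is_initialized
        dist.all_reduce(out)    # hypothetical mirror of torch.distributed.all_reduce
    return out
```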
