Skip to content

Commit 83d36e2

Browse files
committed
Specify ASCEND NPU for inference.
1 parent 1cd4b74 commit 83d36e2

File tree

4 files changed

+34
-0
lines changed

4 files changed

+34
-0
lines changed

fastchat/model/model_adapter.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,13 @@ def add_model_args(parser):
512512
help="A single GPU like 1 or multiple GPUs like 0,2",
513513
)
514514
parser.add_argument("--num-gpus", type=int, default=1)
515+
parser.add_argument(
516+
"--npus",
517+
type=str,
518+
default=None,
519+
help="A single NPU like 1 or multiple NPUs like 0,2",
520+
)
521+
parser.add_argument("--num-npus", type=int, default=1)
515522
parser.add_argument(
516523
"--max-gpu-memory",
517524
type=str,

fastchat/serve/cli.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
- Type "!!save <filename>" to save the conversation history to a json file.
1414
- Type "!!load <filename>" to load a conversation history from a json file.
1515
"""
16+
1617
import argparse
1718
import os
1819
import re
@@ -197,6 +198,14 @@ def main(args):
197198
)
198199
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
199200
os.environ["XPU_VISIBLE_DEVICES"] = args.gpus
201+
if args.npus:
202+
if len(args.npus.split(",")) < args.num_npus:
203+
raise ValueError(
204+
f"Larger --num_npus ({args.num_npus}) than --npus {args.npus}!"
205+
)
206+
if len(args.npus.split(",")) == 1:
207+
import torch_npu
208+
torch.npu.set_device(int(args.npus))
200209
if args.enable_exllama:
201210
exllama_config = ExllamaConfig(
202211
max_seq_len=args.exllama_max_seq_len,

fastchat/serve/model_worker.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
A model worker that executes the model.
33
"""
4+
45
import argparse
56
import base64
67
import gc
@@ -351,6 +352,14 @@ def create_model_worker():
351352
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
352353
)
353354
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
355+
if args.npus:
356+
if len(args.npus.split(",")) < args.num_npus:
357+
raise ValueError(
358+
f"Larger --num_npus ({args.num_npus}) than --npus {args.npus}!"
359+
)
360+
if len(args.npus.split(",")) == 1:
361+
import torch_npu
362+
torch.npu.set_device(int(args.npus))
354363

355364
gptq_config = GptqConfig(
356365
ckpt=args.gptq_ckpt or args.model_path,

fastchat/serve/multi_model_worker.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
We recommend using this with multiple Peft models (with `peft` in the name)
1212
where all Peft models are trained on the exact same base model.
1313
"""
14+
1415
import argparse
1516
import asyncio
1617
import dataclasses
@@ -206,6 +207,14 @@ def create_multi_model_worker():
206207
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
207208
)
208209
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
210+
if args.npus:
211+
if len(args.npus.split(",")) < args.num_npus:
212+
raise ValueError(
213+
f"Larger --num_npus ({args.num_npus}) than --npus {args.npus}!"
214+
)
215+
if len(args.npus.split(",")) == 1:
216+
import torch_npu
217+
torch.npu.set_device(int(args.npus))
209218

210219
gptq_config = GptqConfig(
211220
ckpt=args.gptq_ckpt or args.model_path,

0 commit comments

Comments (0)