Skip to content

Commit 8fe76b9

Browse files
committed
Specify Ascend NPU for inference.
1 parent 1cd4b74 commit 8fe76b9

File tree

3 files changed

+27
-0
lines changed

3 files changed

+27
-0
lines changed

fastchat/serve/cli.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
- Type "!!save <filename>" to save the conversation history to a json file.
1414
- Type "!!load <filename>" to load a conversation history from a json file.
1515
"""
16+
1617
import argparse
1718
import os
1819
import re
@@ -197,6 +198,14 @@ def main(args):
197198
)
198199
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
199200
os.environ["XPU_VISIBLE_DEVICES"] = args.gpus
201+
if len(args.gpus.split(",")) == 1:
202+
try:
203+
import torch_npu
204+
205+
torch.npu.set_device(int(args.gpus))
206+
print(f"NPU is available, now model is running on npu:{args.gpus}")
207+
except ModuleNotFoundError:
208+
pass
200209
if args.enable_exllama:
201210
exllama_config = ExllamaConfig(
202211
max_seq_len=args.exllama_max_seq_len,

fastchat/serve/model_worker.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
A model worker that executes the model.
33
"""
4+
45
import argparse
56
import base64
67
import gc
@@ -351,6 +352,14 @@ def create_model_worker():
351352
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
352353
)
353354
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
355+
if len(args.gpus.split(",")) == 1:
356+
try:
357+
import torch_npu
358+
359+
torch.npu.set_device(int(args.gpus))
360+
print(f"NPU is available, now model is running on npu:{args.gpus}")
361+
except ModuleNotFoundError:
362+
pass
354363

355364
gptq_config = GptqConfig(
356365
ckpt=args.gptq_ckpt or args.model_path,

fastchat/serve/multi_model_worker.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
We recommend using this with multiple Peft models (with `peft` in the name)
1212
where all Peft models are trained on the exact same base model.
1313
"""
14+
1415
import argparse
1516
import asyncio
1617
import dataclasses
@@ -206,6 +207,14 @@ def create_multi_model_worker():
206207
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
207208
)
208209
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
210+
if len(args.gpus.split(",")) == 1:
211+
try:
212+
import torch_npu
213+
214+
torch.npu.set_device(int(args.gpus))
215+
print(f"NPU is available, now model is running on npu:{args.gpus}")
216+
except ModuleNotFoundError:
217+
pass
209218

210219
gptq_config = GptqConfig(
211220
ckpt=args.gptq_ckpt or args.model_path,

0 commit comments

Comments (0)