
Commit 5b45f34

Author: xusenlin (committed)
Improve code and add support for InternLM2
1 parent 389aed2 commit 5b45f34

27 files changed (+591, -686 lines)

.env.example

Lines changed: 4 additions & 4 deletions
@@ -14,14 +14,14 @@ STREAM_INTERVERL=2
 PROMPT_NAME=
 
 # device related
-DEVICE=cuda
-DEVICE_MAP=
+DEVICE=
+
+# "auto", "cuda:0", "cuda:1", ...
+DEVICE_MAP=auto
 GPUS=
 NUM_GPUs=1
 DTYPE=half
 
-# patch related
-PATCH_TYPE=
 
 # api related
 API_PREFIX=/v1
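
For orientation, the new DEVICE_MAP setting mirrors the `device_map` argument of Hugging Face `from_pretrained`. Below is a minimal sketch of how such values could be forwarded to `transformers`; the variable names and the default model are illustrative only, not taken from this repository's settings module:

```python
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical reading of .env-style values like those shown above.
model_path = os.environ.get("MODEL_PATH", "internlm/internlm2-chat-20b")  # illustrative default
device_map = os.environ.get("DEVICE_MAP") or None  # "auto", "cuda:0", "cuda:1", ...
dtype = torch.float16 if os.environ.get("DTYPE", "half") == "half" else None

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map=device_map,  # "auto" lets accelerate spread layers over the visible GPUs
    torch_dtype=dtype,
    trust_remote_code=True,
)
```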

README.md

Lines changed: 4 additions & 0 deletions
@@ -20,6 +20,9 @@
 
 ## 📢 News
 
++ 【2024.01.19】 Added [InternLM2](https://github.com/InternLM/InternLM) model support; [launch instructions](https://github.com/xusenlinzy/api-for-open-llm/blob/master/docs/SCRIPT.md#internlm2)
+
+
 + 【2023.12.21】 Added [TGI](https://github.com/huggingface/text-generation-inference) generation API forwarding and [TEI](https://github.com/huggingface/text-embeddings-inference) embedding API forwarding
 
 

@@ -113,6 +116,7 @@
 | [qwen-7b-chat](https://github.com/QwenLM/Qwen-7B) | Qwen | 7B | en, zh | [Qwen/Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat) |
 | [baichuan-13b-chat](https://github.com/baichuan-inc/Baichuan-13B) | Baichuan | 13B | en, zh | [baichuan-inc/Baichuan-13B-Chat](https://huggingface.co/baichuan-inc/Baichuan-13B-Chat) |
 | [InternLM](https://github.com/InternLM/InternLM) | InternLM | 7B | en, zh | [internlm/internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b) |
+| [InternLM2](https://github.com/InternLM/InternLM) | InternLM2 | 20B | en, zh | [internlm/internlm2-chat-20b](https://huggingface.co/internlm/internlm2-chat-20b) |
 | [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | GLM | 6/130B | en, zh | [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b) |
 | [baichuan-7b](https://github.com/baichuan-inc/baichuan-7B) | Baichuan | 7B | en, zh | [baichuan-inc/baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B) |
 | [Guanaco](https://github.com/artidoro/qlora/tree/main) | LLaMA | 7/33/65B | en | [timdettmers/guanaco-33b-merged](https://huggingface.co/timdettmers/guanaco-33b-merged) |

api/adapter/loader.py

Lines changed: 113 additions & 0 deletions
from __future__ import annotations

from typing import (
    TYPE_CHECKING,
    Optional,
    Tuple,
    Any,
)

from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
)

from .patcher import (
    patch_config,
    patch_tokenizer,
    patch_model,
)

if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer


def _load_model_and_tokenizer(
    model_name_or_path: str,
    use_fast_tokenizer: Optional[bool] = False,
    dtype: Optional[str] = None,
    device_map: Optional[Any] = None,
    load_in_8bit: Optional[bool] = False,
    load_in_4bit: Optional[bool] = False,
    rope_scaling: Optional[str] = None,
    flash_attn: Optional[bool] = False,
) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]:
    r"""
    Loads a pretrained model and tokenizer for inference.
    """
    config_kwargs = {"trust_remote_code": True}

    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        use_fast=use_fast_tokenizer,
        trust_remote_code=True,
    )
    patch_tokenizer(tokenizer)

    config = AutoConfig.from_pretrained(model_name_or_path, **config_kwargs)
    patch_config(
        config,
        config_kwargs,
        dtype,
        rope_scaling=rope_scaling,
        flash_attn=flash_attn,
        load_in_4bit=load_in_4bit,
        load_in_8bit=load_in_8bit,
    )

    if device_map:
        config_kwargs["device_map"] = device_map

    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        config=config,
        low_cpu_mem_usage=True,
        **config_kwargs
    )

    patch_model(model)
    model.eval()

    return model, tokenizer


def load_model_and_tokenizer(
    model_name: str,
    model_name_or_path: str,
    use_fast_tokenizer: Optional[bool] = False,
    dtype: Optional[str] = None,
    device_map: Optional[Any] = None,
    load_in_8bit: Optional[bool] = False,
    load_in_4bit: Optional[bool] = False,
    rope_scaling: Optional[str] = None,
    flash_attn: Optional[bool] = False,
    **kwargs,
) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]:
    try:
        model, tokenizer = _load_model_and_tokenizer(
            model_name_or_path,
            use_fast_tokenizer,
            dtype,
            device_map,
            load_in_8bit,
            load_in_4bit,
            rope_scaling,
            flash_attn,
        )
    except Exception:  # fall back to the legacy, adapter-based loader
        from .model import load_model_and_tokenizer_old

        model, tokenizer = load_model_and_tokenizer_old(
            model_name,
            model_name_or_path,
            dtype=dtype,
            load_in_8bit=load_in_8bit,
            load_in_4bit=load_in_4bit,
            device_map=device_map,
            **kwargs,
        )

    return model, tokenizer
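
A brief usage sketch of the loader above. The argument values are illustrative, and the import path assumes the package is importable as `api.adapter.loader` from the repository root:

```python
# Sketch only: load a chat model through the new loader and run one generation.
from api.adapter.loader import load_model_and_tokenizer

model, tokenizer = load_model_and_tokenizer(
    model_name="internlm2",                         # only consulted by the legacy fallback path
    model_name_or_path="internlm/internlm2-chat-20b",
    dtype="bfloat16",                               # mapped to torch.bfloat16 by patch_config
    device_map="auto",
)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```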

api/adapter/model.py

Lines changed: 3 additions & 1 deletion
@@ -1,3 +1,5 @@
+""" this file is outdated but is still used as a fallback """
+
 import os
 import sys
 from typing import List, Optional, Any, Dict, Tuple
@@ -277,7 +279,7 @@ def get_model_adapter(model_name: str) -> BaseModelAdapter:
     raise ValueError(f"No valid model adapter for {model_name}")
 
 
-def load_model(
+def load_model_and_tokenizer_old(
     model_name: str,
     model_name_or_path: Optional[str] = None,
    adapter_model: Optional[str] = None,

api/adapter/patcher.py

Lines changed: 182 additions & 0 deletions
""" from https://github.com/hiyouga/LLaMA-Factory/blob/main/src/llmtuner/model/patcher.py """
from __future__ import annotations

import importlib.metadata
import importlib.util
import os
from types import MethodType
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    Optional,
)

import torch
from loguru import logger
from transformers import (
    PreTrainedModel,
    PreTrainedTokenizerBase,
    BitsAndBytesConfig,
)
from transformers.utils import (
    is_torch_bf16_gpu_available,
    is_torch_cuda_available,
    is_torch_npu_available
)
from transformers.utils.versions import require_version

if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedTokenizer


_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
try:
    _is_bf16_available = is_torch_bf16_gpu_available()
except Exception:
    _is_bf16_available = False


def is_package_available(name: str) -> bool:
    return importlib.util.find_spec(name) is not None


def get_package_version(name: str) -> str:
    try:
        return importlib.metadata.version(name)
    except Exception:
        return "0.0.0"


def is_flash_attn2_available():
    return is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")


def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
    r"""
    Infers the optimal dtype according to the model_dtype and device compatibility.
    """
    if _is_bf16_available and model_dtype == torch.bfloat16:
        return torch.bfloat16
    elif _is_fp16_available:
        return torch.float16
    else:
        return torch.float32


def _configure_rope(config: "PretrainedConfig", rope_scaling: str = None) -> None:
    if not hasattr(config, "rope_scaling"):
        logger.warning("Current model does not support RoPE scaling.")
        return

    scaling_factor = 2.0
    setattr(config, "rope_scaling", {"type": rope_scaling, "factor": scaling_factor})
    logger.info(f"Using {rope_scaling} scaling strategy and setting scaling factor to {scaling_factor}.")


def _configure_flashattn(config_kwargs: Dict[str, Any]) -> None:
    if not is_flash_attn2_available():
        logger.warning("FlashAttention-2 is not installed.")
        return

    config_kwargs["use_flash_attention_2"] = True
    logger.info("Using FlashAttention-2 for faster inference.")


def _configure_quantization(
    config_kwargs: Dict[str, Any],
    load_in_8bits: bool = False,
    load_in_4bits: bool = False,
) -> None:
    if load_in_8bits:
        require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
        config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
        logger.info("Quantizing model to 8 bit.")

    elif load_in_4bits:
        require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
        config_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=config_kwargs.get("torch_dtype", torch.float16),
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
        logger.info("Quantizing model to 4 bit.")

    if load_in_8bits or load_in_4bits:
        config_kwargs["device_map"] = {"": get_current_device()}
    else:
        config_kwargs["device_map"] = get_current_device()


def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
    if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
        tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)

    if tokenizer.eos_token_id is None:
        tokenizer.eos_token = "<|endoftext|>"
        logger.info(f"Add eos token: {tokenizer.eos_token}")

    if tokenizer.pad_token_id is None:
        if tokenizer.unk_token_id is not None:
            tokenizer.pad_token = tokenizer.unk_token
        else:
            tokenizer.pad_token = tokenizer.eos_token
        logger.info(f"Add pad token: {tokenizer.pad_token}")


def patch_config(
    config: "PretrainedConfig",
    config_kwargs: Dict[str, Any],
    compute_dtype: Optional[str] = None,
    **kwargs,
):
    if compute_dtype is None:  # priority: bf16 > fp16 > fp32
        compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
    else:
        _DTYPE_MAP = {
            "half": torch.float16,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
            "float32": torch.float32,
        }
        compute_dtype = _DTYPE_MAP.get(compute_dtype, torch.float16)

    config_kwargs["torch_dtype"] = compute_dtype

    if getattr(config, "model_type", None) == "qwen":
        for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
            setattr(config, dtype_name, compute_dtype == dtype)

    rope_scaling = kwargs.get("rope_scaling", None)
    if rope_scaling is not None:
        _configure_rope(config, rope_scaling)

    if kwargs.get("flash_attn", False):
        _configure_flashattn(config_kwargs)

    _configure_quantization(
        config_kwargs,
        kwargs.get("load_in_8bit", False),
        kwargs.get("load_in_4bit", False),
    )


def patch_model(model: "PreTrainedModel") -> None:
    if "GenerationMixin" not in str(model.generate.__func__):
        model.generate = MethodType(PreTrainedModel.generate, model)


def get_current_device() -> torch.device:
    r"""
    Gets the current available device.
    """
    if is_torch_npu_available():
        device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
    elif is_torch_cuda_available():
        device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
    else:
        device = "cpu"

    return torch.device(device)
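
To make the interplay of these helpers concrete, here is a hedged sketch of what `patch_config` leaves in `config_kwargs` for a 4-bit load. It assumes `bitsandbytes` is installed, that the package is importable as `api.adapter.patcher`, and uses an example model name; the exact device entry depends on the local hardware:

```python
from transformers import AutoConfig

from api.adapter.patcher import patch_config  # assumes this repository's package layout

config_kwargs = {"trust_remote_code": True}
config = AutoConfig.from_pretrained("internlm/internlm2-chat-20b", trust_remote_code=True)

patch_config(config, config_kwargs, "half", load_in_4bit=True)

# Roughly expected afterwards:
#   config_kwargs["torch_dtype"]         -> torch.float16 (from the "half" alias)
#   config_kwargs["quantization_config"] -> BitsAndBytesConfig(load_in_4bit=True, nf4, double quant)
#   config_kwargs["device_map"]          -> {"": <current cuda/npu/cpu device>}
print(config_kwargs)
```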
