Commit b28ab1e

Author: xusenlin
Commit message: refactor
1 parent 49e3d83

14 files changed: +966 additions, -221 deletions

api/apapter/model.py

Lines changed: 26 additions & 4 deletions
@@ -14,6 +14,7 @@
     AutoTokenizer,
     AutoModelForCausalLM,
     BitsAndBytesConfig,
+    PreTrainedTokenizer,
 )
 from transformers.utils.versions import require_version

@@ -28,10 +29,15 @@ class BaseModelAdapter:
 
     model_names = []
 
-    def match(self, model_name):
+    def match(self, model_name) -> bool:
         return any(m in model_name for m in self.model_names) if self.model_names else True
 
-    def load_model(self, model_name_or_path: Optional[str] = None, adapter_model: Optional[str] = None, **kwargs):
+    def load_model(
+        self,
+        model_name_or_path: Optional[str] = None,
+        adapter_model: Optional[str] = None,
+        **kwargs
+    ):
         """ Load model through transformers. """
         model_name_or_path = self.default_model_name_or_path if model_name_or_path is None else model_name_or_path
         tokenizer_kwargs = {"trust_remote_code": True, "use_fast": False}
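
For reference, the match logic in this hunk is plain substring matching: an adapter claims a model when any entry in its model_names appears in the requested name, and an empty model_names list makes the adapter match everything (the catch-all base case). A minimal sketch of that behavior; DummyAdapter is hypothetical and only the match logic mirrors the diff:

class DummyAdapter:
    model_names = ["chatglm3"]

    def match(self, model_name) -> bool:
        # An empty model_names list means "match anything" (the base adapter's catch-all).
        return any(m in model_name for m in self.model_names) if self.model_names else True

adapter = DummyAdapter()
print(adapter.match("chatglm3-6b"))   # True: "chatglm3" is a substring of the name
print(adapter.match("baichuan-13b"))  # False: no listed name appears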
@@ -149,7 +155,9 @@ def load_lora_model(self, model, adapter_model, model_kwargs):
             torch_dtype=model_kwargs.get("torch_dtype", torch.float16),
         )
 
-    def load_adapter_model(self, model, tokenizer, adapter_model, is_chatglm, model_kwargs, **kwargs):
+    def load_adapter_model(
+        self, model, tokenizer, adapter_model, is_chatglm, model_kwargs, **kwargs
+    ):
         use_ptuning_v2 = kwargs.get("use_ptuning_v2", False)
         resize_embeddings = kwargs.get("resize_embeddings", False)
         if adapter_model and resize_embeddings and not is_chatglm:
@@ -176,7 +184,7 @@ def load_adapter_model(self, model, tokenizer, adapter_model, is_chatglm, model_
 
         return model
 
-    def post_tokenizer(self, tokenizer):
+    def post_tokenizer(self, tokenizer) -> PreTrainedTokenizer:
         return tokenizer
 
     @property
@@ -264,6 +272,20 @@ def default_model_name_or_path(self):
         return "THUDM/chatglm2-6b"
 
 
+class Chatglm3ModelAdapter(ChatglmModelAdapter):
+    """ https://github.com/THUDM/ChatGLM-6B """
+
+    model_names = ["chatglm3"]
+
+    @property
+    def tokenizer_kwargs(self):
+        return {"encode_special_tokens": True}
+
+    @property
+    def default_model_name_or_path(self):
+        return "THUDM/chatglm3-6b"
+
+
 class LlamaModelAdapter(BaseModelAdapter):
     """ https://github.com/project-baize/baize-chatbot """
 
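Because match uses substring tests, registration order matters for the new adapter: a name like "chatglm3-6b" contains both "chatglm3" and "chatglm", so Chatglm3ModelAdapter has to be consulted before its parent ChatglmModelAdapter (whose model_names presumably include "chatglm") to win the dispatch. A sketch of that ordering concern, building on the classes in the diff; get_model_adapter is a hypothetical helper, not necessarily the repo's actual registry:

# Hypothetical dispatch: walk adapters in registration order and return the
# first whose match() accepts the name. Only the substring semantics come
# from the diff above; the real registry in the repo may differ.
ADAPTERS = [Chatglm3ModelAdapter(), ChatglmModelAdapter(), BaseModelAdapter()]

def get_model_adapter(model_name: str):
    for adapter in ADAPTERS:
        if adapter.match(model_name):
            return adapter
    raise ValueError(f"no adapter matches {model_name!r}")

# "chatglm3-6b" also contains "chatglm", so the chatglm3 adapter must come
# first in the list or the generic chatglm adapter would claim it.
print(type(get_model_adapter("chatglm3-6b")).__name__)  # Chatglm3ModelAdapter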