14
14
AutoTokenizer ,
15
15
AutoModelForCausalLM ,
16
16
BitsAndBytesConfig ,
17
+ PreTrainedTokenizer ,
17
18
)
18
19
from transformers .utils .versions import require_version
19
20
@@ -28,10 +29,15 @@ class BaseModelAdapter:
28
29
29
30
model_names = []
30
31
31
- def match (self , model_name ):
32
+ def match (self , model_name ) -> bool :
32
33
return any (m in model_name for m in self .model_names ) if self .model_names else True
33
34
34
- def load_model (self , model_name_or_path : Optional [str ] = None , adapter_model : Optional [str ] = None , ** kwargs ):
35
+ def load_model (
36
+ self ,
37
+ model_name_or_path : Optional [str ] = None ,
38
+ adapter_model : Optional [str ] = None ,
39
+ ** kwargs
40
+ ):
35
41
""" Load model through transformers. """
36
42
model_name_or_path = self .default_model_name_or_path if model_name_or_path is None else model_name_or_path
37
43
tokenizer_kwargs = {"trust_remote_code" : True , "use_fast" : False }
@@ -149,7 +155,9 @@ def load_lora_model(self, model, adapter_model, model_kwargs):
149
155
torch_dtype = model_kwargs .get ("torch_dtype" , torch .float16 ),
150
156
)
151
157
152
- def load_adapter_model (self , model , tokenizer , adapter_model , is_chatglm , model_kwargs , ** kwargs ):
158
+ def load_adapter_model (
159
+ self , model , tokenizer , adapter_model , is_chatglm , model_kwargs , ** kwargs
160
+ ):
153
161
use_ptuning_v2 = kwargs .get ("use_ptuning_v2" , False )
154
162
resize_embeddings = kwargs .get ("resize_embeddings" , False )
155
163
if adapter_model and resize_embeddings and not is_chatglm :
@@ -176,7 +184,7 @@ def load_adapter_model(self, model, tokenizer, adapter_model, is_chatglm, model_
176
184
177
185
return model
178
186
179
- def post_tokenizer (self , tokenizer ):
187
+ def post_tokenizer (self , tokenizer ) -> PreTrainedTokenizer :
180
188
return tokenizer
181
189
182
190
@property
@@ -264,6 +272,20 @@ def default_model_name_or_path(self):
264
272
return "THUDM/chatglm2-6b"
265
273
266
274
275
+ class Chatglm3ModelAdapter (ChatglmModelAdapter ):
276
+ """ https://github.com/THUDM/ChatGLM-6B """
277
+
278
+ model_names = ["chatglm3" ]
279
+
280
+ @property
281
+ def tokenizer_kwargs (self ):
282
+ return {"encode_special_tokens" : True }
283
+
284
+ @property
285
+ def default_model_name_or_path (self ):
286
+ return "THUDM/chatglm3-6b"
287
+
288
+
267
289
class LlamaModelAdapter (BaseModelAdapter ):
268
290
""" https://github.com/project-baize/baize-chatbot """
269
291
0 commit comments