
Commit fbeea1a

Merge pull request #254 from stochasticai/dev

release 0.1.8

2 parents: 4c71825 + 8021076

File tree

9 files changed: +192 -45 lines changed


pyproject.toml

Lines changed: 7 additions & 7 deletions

@@ -1,6 +1,6 @@
 [project]
 name = "xturing"
-version = "0.1.7"
+version = "0.1.8"
 description = "Fine-tuning, evaluation and data generation for LLMs"
 
 authors = [
@@ -43,12 +43,12 @@ keywords = [
 dependencies = [
     "torch >= 1.9.0",
     "pytorch-lightning",
-    "transformers==4.28.1",
-    "datasets",
-    "evaluate",
-    "bitsandbytes==0.37.2",
+    "transformers==4.31.0",
+    "datasets==2.14.5",
+    "evaluate==0.4.0",
+    "bitsandbytes==0.41.1",
     "sentencepiece",
-    "deepspeed",
+    "deepspeed==0.9.5",
     "gradio",
     "click",
     "wget",
@@ -58,7 +58,7 @@ dependencies = [
     "openai >= 0.27.0",
     "pydantic >= 1.10.0",
     "rouge-score >= 0.1.2",
-    "accelerate",
+    "accelerate==0.22.0",
     "wandb",
 ]
 
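
This release pins transformers, datasets, evaluate, bitsandbytes, deepspeed, and accelerate to exact versions. A quick sanity check after upgrading (a minimal sketch; assumes the pinned packages are installed in the current environment):

    # Confirm the pinned dependencies resolved to the expected versions.
    import importlib.metadata as md

    pins = {
        "transformers": "4.31.0",
        "datasets": "2.14.5",
        "evaluate": "0.4.0",
        "bitsandbytes": "0.41.1",
        "deepspeed": "0.9.5",
        "accelerate": "0.22.0",
    }
    for pkg, want in pins.items():
        got = md.version(pkg)
        assert got == want, f"{pkg}: expected {want}, got {got}"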

src/xturing/__about__.py

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-__version__ = "0.1.7"
+__version__ = "0.1.8"
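
The version string lives in __about__.py and can be read directly to confirm the upgrade:

    # Read the package version straight from the module changed above.
    from xturing.__about__ import __version__

    assert __version__ == "0.1.8"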

src/xturing/config/finetuning_config.yaml

Lines changed: 48 additions & 1 deletion

@@ -32,6 +32,13 @@ bloom_lora_int8:
   batch_size: 8
   max_length: 256
 
+bloom_int8:
+  learning_rate: 1e-4
+  weight_decay: 0.01
+  num_train_epochs: 3
+  batch_size: 8
+  max_length: 256
+
 cerebras:
   learning_rate: 5e-5
   weight_decay: 0.01
@@ -50,6 +57,13 @@ cerebras_lora_int8:
   batch_size: 8
   max_length: 256
 
+cerebras_int8:
+  learning_rate: 1e-4
+  weight_decay: 0.01
+  num_train_epochs: 3
+  batch_size: 8
+  max_length: 256
+
 distilgpt2:
   learning_rate: 1e-3
   weight_decay: 0.01
@@ -115,6 +129,13 @@ galactica_lora_int8:
   batch_size: 8
   max_length: 256
 
+galactica_int8:
+  learning_rate: 1e-4
+  weight_decay: 0.01
+  num_train_epochs: 3
+  batch_size: 8
+  max_length: 256
+
 generic:
   learning_rate: 1e-4
   weight_decay: 0.01
@@ -169,6 +190,13 @@ gptj_lora_int8:
   batch_size: 8
   max_length: 256
 
+gptj_int8:
+  learning_rate: 1e-4
+  weight_decay: 0.01
+  num_train_epochs: 3
+  batch_size: 8
+  max_length: 256
+
 gpt2:
   learning_rate: 1e-3
   weight_decay: 0.01
@@ -187,13 +215,18 @@ gpt2_lora_int8:
   num_train_epochs: 3
   batch_size: 16
 
+gpt2_int8:
+  learning_rate: 3e-3
+  weight_decay: 0.01
+  num_train_epochs: 3
+  batch_size: 16
+
 llama:
   learning_rate: 5e-5
   weight_decay: 0.01
   num_train_epochs: 3
   optimizer_name: cpu_adam
 
-
 llama_lora:
   learning_rate: 1e-4
   weight_decay: 0.01
@@ -207,6 +240,13 @@ llama_lora_int8:
   batch_size: 8
   max_length: 256
 
+llama_int8:
+  learning_rate: 1e-4
+  weight_decay: 0.01
+  num_train_epochs: 3
+  batch_size: 8
+  max_length: 256
+
 llama_lora_kbit:
   learning_rate: 3e-4
   num_train_epochs: 3
@@ -275,3 +315,10 @@ opt_lora_int8:
   num_train_epochs: 3
   batch_size: 8
   max_length: 256
+
+opt_int8:
+  learning_rate: 1e-4
+  weight_decay: 0.01
+  num_train_epochs: 3
+  batch_size: 8
+  max_length: 256
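
Each new *_int8 block supplies the default fine-tuning hyperparameters for the matching int8 model key. A minimal sketch of how such a block is consumed through the public xturing API (the dataset path is a placeholder):

    # Fine-tune an int8 model; "gpt2_int8" picks up the block added
    # above (learning_rate 3e-3, num_train_epochs 3, batch_size 16).
    from xturing.datasets import InstructionDataset
    from xturing.models import BaseModel

    dataset = InstructionDataset("./alpaca_data")  # placeholder dataset path
    model = BaseModel.create("gpt2_int8")
    model.finetune(dataset=dataset)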

src/xturing/config/generation_config.yaml

Lines changed: 37 additions & 0 deletions

@@ -25,6 +25,11 @@ bloom_lora_int8:
   max_new_tokens: 256
   do_sample: false
 
+# Greedy search
+bloom_int8:
+  max_new_tokens: 256
+  do_sample: false
+
 # Contrastive search
 cerebras:
   penalty_alpha: 0.6
@@ -44,6 +49,11 @@ cerebras_lora_int8:
   max_new_tokens: 256
   do_sample: false
 
+# Greedy search
+cerebras_int8:
+  max_new_tokens: 256
+  do_sample: false
+
 # Top-p sampling
 distilgpt2:
   do_sample: true
@@ -102,6 +112,11 @@ galactica_lora_int8:
   max_new_tokens: 256
   do_sample: false
 
+# Greedy search
+galactica_int8:
+  max_new_tokens: 256
+  do_sample: false
+
 # Greedy search
 generic:
   max_new_tokens: 256
@@ -146,6 +161,11 @@ gptj_lora_int8:
   max_new_tokens: 256
   do_sample: false
 
+# Greedy search
+gptj_int8:
+  max_new_tokens: 256
+  do_sample: false
+
 # Top-p sampling
 gpt2:
   do_sample: true
@@ -167,6 +187,13 @@ gpt2_lora_int8:
   top_p: 0.92
   max_new_tokens: 256
 
+# Top-p sampling
+gpt2_int8:
+  do_sample: true
+  top_k: 0
+  top_p: 0.92
+  max_new_tokens: 256
+
 # Contrastive search
 llama:
   penalty_alpha: 0.6
@@ -186,6 +213,11 @@ llama_lora_int8:
   max_new_tokens: 256
   do_sample: false
 
+# Greedy search
+llama_int8:
+  max_new_tokens: 256
+  do_sample: false
+
 # Greedy search
 llama_lora_kbit:
   max_new_tokens: 256
@@ -238,3 +270,8 @@ opt_lora:
 opt_lora_int8:
   max_new_tokens: 256
   do_sample: false
+
+# Greedy search
+opt_int8:
+  max_new_tokens: 256
+  do_sample: false
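
These blocks set the decoding defaults per model key: greedy search (do_sample: false) for most int8 variants, top-p sampling for gpt2_int8. A sketch of generation with one of the new keys, using the standard xturing API:

    # Generate with an int8 model; the "llama_int8" block above means
    # greedy decoding capped at 256 new tokens, no sampling.
    from xturing.models import BaseModel

    model = BaseModel.create("llama_int8")
    outputs = model.generate(texts=["What is int8 quantization?"])
    print(outputs)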

src/xturing/engines/__init__.py

Lines changed: 11 additions & 1 deletion

@@ -44,7 +44,13 @@
     GPTJLoraEngine,
     GPTJLoraInt8Engine,
 )
-from xturing.engines.llama2_engine import LLama2Engine
+from xturing.engines.llama2_engine import (
+    LLama2Engine,
+    LLama2Int8Engine,
+    LLama2LoraEngine,
+    LLama2LoraInt8Engine,
+    LLama2LoraKbitEngine,
+)
 from xturing.engines.llama_engine import (
     LLamaEngine,
     LLamaInt8Engine,
@@ -97,6 +103,10 @@
 BaseEngine.add_to_registry(LlamaLoraInt8Engine.config_name, LlamaLoraInt8Engine)
 BaseEngine.add_to_registry(LlamaLoraKbitEngine.config_name, LlamaLoraKbitEngine)
 BaseEngine.add_to_registry(LLama2Engine.config_name, LLama2Engine)
+BaseEngine.add_to_registry(LLama2Int8Engine.config_name, LLama2Int8Engine)
+BaseEngine.add_to_registry(LLama2LoraEngine.config_name, LLama2LoraEngine)
+BaseEngine.add_to_registry(LLama2LoraInt8Engine.config_name, LLama2LoraInt8Engine)
+BaseEngine.add_to_registry(LLama2LoraKbitEngine.config_name, LLama2LoraKbitEngine)
 BaseEngine.add_to_registry(OPTEngine.config_name, OPTEngine)
 BaseEngine.add_to_registry(OPTInt8Engine.config_name, OPTInt8Engine)
 BaseEngine.add_to_registry(OPTLoraEngine.config_name, OPTLoraEngine)
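
Each add_to_registry call keys an engine class by its config_name string, which is how the model keys in the YAML configs above resolve to engine classes. A minimal sketch of that registry pattern (illustrative only, not xturing's actual implementation):

    # Illustrative registry: map a config_name string to a class,
    # then instantiate by key. BaseEngine.add_to_registry follows
    # this general pattern.
    class Registry:
        _classes: dict = {}

        @classmethod
        def add_to_registry(cls, name, klass):
            cls._classes[name] = klass

        @classmethod
        def create(cls, name, *args, **kwargs):
            return cls._classes[name](*args, **kwargs)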

src/xturing/engines/generic_engine.py

Lines changed: 1 addition & 2 deletions

@@ -64,7 +64,7 @@ def __init__(
 
 
 class GenericLoraKbitEngine(CausalLoraKbitEngine):
-    config_name: str = "generic+lora_kbit_engine"
+    config_name: str = "generic_lora_kbit_engine"
 
     def __init__(
         self,
@@ -75,7 +75,6 @@ def __init__(
         super().__init__(
             model_name=model_name,
             weights_path=weights_path,
-            load_4bit=True,
             target_modules=target_modules,
         )
 
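
The rename fixes a mismatch: config names elsewhere use underscores, so the "+" in "generic+lora_kbit_engine" made the key unresolvable, and dropping the explicit load_4bit=True presumably leaves quantized loading to CausalLoraKbitEngine itself. With the corrected key, the engine is reachable through its model-level wrapper; a sketch, assuming the GenericLoraKbitModel wrapper that pairs with this engine (the model id is a placeholder):

    # Load an arbitrary Hugging Face causal LM with LoRA + k-bit quantization.
    from xturing.models import GenericLoraKbitModel

    model = GenericLoraKbitModel("facebook/opt-125m")  # placeholder HF model id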

src/xturing/models/__init__.py

Lines changed: 11 additions & 1 deletion

@@ -36,7 +36,13 @@
     LlamaLoraInt8,
     LlamaLoraKbit,
 )
-from xturing.models.llama2 import Llama2
+from xturing.models.llama2 import (
+    Llama2,
+    Llama2Int8,
+    Llama2Lora,
+    Llama2LoraInt8,
+    Llama2LoraKbit,
+)
 from xturing.models.opt import OPT, OPTInt8, OPTLora, OPTLoraInt8
 from xturing.models.stable_diffusion import StableDiffusion
 
@@ -78,6 +84,10 @@
 BaseModel.add_to_registry(LlamaLoraInt8.config_name, LlamaLoraInt8)
 BaseModel.add_to_registry(LlamaLoraKbit.config_name, LlamaLoraKbit)
 BaseModel.add_to_registry(Llama2.config_name, Llama2)
+BaseModel.add_to_registry(Llama2Int8.config_name, Llama2Int8)
+BaseModel.add_to_registry(Llama2Lora.config_name, Llama2Lora)
+BaseModel.add_to_registry(Llama2LoraInt8.config_name, Llama2LoraInt8)
+BaseModel.add_to_registry(Llama2LoraKbit.config_name, Llama2LoraKbit)
 BaseModel.add_to_registry(OPT.config_name, OPT)
 BaseModel.add_to_registry(OPTInt8.config_name, OPTInt8)
 BaseModel.add_to_registry(OPTLora.config_name, OPTLora)
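
With the four Llama 2 variants registered, they become creatable by key like the existing LLaMA ones; a sketch, assuming the config names follow the established llama2_* convention:

    # Instantiate a newly registered Llama 2 variant by its key
    # (key assumed to follow the "llama2_*" naming convention).
    from xturing.models import BaseModel

    model = BaseModel.create("llama2_lora_int8")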
