
Commit 25065ca

Merge pull request #156 from stochasticai/dev
Release 0.1.0
2 parents 3531155 + 5d97147 commit 25065ca

File tree: 18 files changed (+1100, -327 lines)

README.md

Lines changed: 3 additions & 2 deletions

```diff
@@ -33,9 +33,9 @@ With `xturing` you can,
 
 <br>
 
-## 🌟 INT4 fine-tuning with LLaMA LoRA
+## 🌟 INT4 fine-tuning and generation with LLaMA LoRA
 
-We are excited to announce the latest enhancement to our `xTuring` library: INT4 fine-tuning demo. With this update, you can fine-tune LLMs like LLaMA with LoRA architecture in INT4 precision with less than `6 GB` of VRAM. This breakthrough significantly reduces memory requirements and accelerates the fine-tuning process, allowing you to achieve state-of-the-art performance with less computational resources.
+We are excited to announce the latest enhancement to our `xTuring` library: INT4 fine-tuning and generation integration. With this update, you can fine-tune LLMs like LLaMA with the LoRA architecture in INT4 precision using less than `6 GB` of VRAM. This breakthrough significantly reduces memory requirements and accelerates the fine-tuning process, allowing you to achieve state-of-the-art performance with fewer computational resources.
 
 More information about INT4 fine-tuning and benchmarks can be found in the [INT4 README](examples/int4_finetuning/README.md).
 
@@ -146,6 +146,7 @@ model = BaseModel.load("x/distilgpt2_lora_finetuned_alpaca")
 - [x] OpenAI, Cohere and AI21 Studio model APIs for dataset generation
 - [x] Added fine-tuned checkpoints for some models to the hub
 - [x] INT4 LLaMA LoRA fine-tuning demo
+- [x] INT4 LLaMA LoRA fine-tuning with INT4 generation
 - [ ] Evaluation of LLM models
 - [ ] Support for Stable Diffusion
 
```

examples/int4_finetuning/LLaMA_lora_int4.ipynb

Lines changed: 45 additions & 84 deletions

```diff
@@ -31,119 +31,80 @@
    },
    "outputs": [],
    "source": [
-    "!pip install xturing --upgrade"
+    "!pip install xturing --upgrade\n",
+    "!pip install xturing[int4] --upgrade"
    ]
   },
   {
-   "cell_type": "code",
-   "execution_count": 7,
+   "cell_type": "markdown",
    "metadata": {
     "collapsed": false
    },
-   "outputs": [],
    "source": [
-    "import subprocess\n",
-    "from pathlib import Path\n",
-    "\n",
-    "def pull_docker_image(image):\n",
-    "    cmd = [\"docker\", \"pull\", image]\n",
-    "    subprocess.run(cmd, check=True)\n",
-    "\n",
-    "\n",
-    "def run_docker_container(image, port_mapping, env_vars=None, gpus=None, volumes=None):\n",
-    "    cmd = [\"docker\", \"container\", \"run\"]\n",
-    "\n",
-    "    if env_vars is None:\n",
-    "        env_vars = {}\n",
-    "\n",
-    "    if volumes is None:\n",
-    "        volumes = {}\n",
-    "\n",
-    "    if gpus is not None:\n",
-    "        cmd.extend([\"--gpus\", gpus])\n",
-    "\n",
-    "    for key, value in env_vars.items():\n",
-    "        cmd.extend([\"-e\", f\"{key}={value}\"])\n",
-    "\n",
-    "    for local_path, container_path in volumes.items():\n",
-    "        cmd.extend([\"-v\", f\"{str(Path(local_path).resolve())}:{container_path}\"])\n",
-    "\n",
-    "    cmd.extend([\"-p\", port_mapping, image])\n",
-    "\n",
-    "    subprocess.run(cmd)"
+    "## 2. Load model and dataset"
    ]
   },
   {
-   "cell_type": "markdown",
+   "cell_type": "code",
+   "execution_count": null,
    "metadata": {
-    "collapsed": false
+    "collapsed": false,
+    "pycharm": {
+     "is_executing": true
+    }
    },
+   "outputs": [],
    "source": [
-    "## 2. Load and run docker image"
+    "from xturing.datasets.instruction_dataset import InstructionDataset\n",
+    "from xturing.models import BaseModel\n",
+    "\n",
+    "instruction_dataset = InstructionDataset(\"../llama/alpaca_data\")\n",
+    "# Initializes the model\n",
+    "model = BaseModel.create(\"llama_lora_int4\")"
    ]
   },
   {
    "cell_type": "markdown",
+   "source": [
+    "## 3. Start the finetuning"
+   ],
    "metadata": {
     "collapsed": false
-   },
-   "source": [
-    "1. Install Docker on your machine if you haven't already. You can follow the [official Docker documentation](https://docs.docker.com/engine/install/) for installation instructions.\n",
-    "2. Install NVIDIA Container Toolkit\n",
-    "   ```bash\n",
-    "   sudo apt-get install -y nvidia-docker2\n",
-    "   ```\n",
-    "3. Run the Docker daemon\n",
-    "   ```bash\n",
-    "   sudo systemctl start docker\n",
-    "   ```\n"
-   ]
+   }
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": false,
-    "pycharm": {
-     "is_executing": true
-    }
-   },
    "outputs": [],
    "source": [
-    "image = \"public.ecr.aws/t8g5g2q5/xturing:int4_finetuning\"\n",
-    "port_mapping = \"5000:5000\"\n",
-    "env_vars = {\n",
-    "    \"WANDB_MODE\": \"dryrun\",\n",
-    "    \"MICRO_BATCH_SIZE\": \"1\",  # change this to increase your micro batch size\n",
-    "}\n",
-    "# if you want to log results to wandb, set the following env var\n",
-    "# env_vars = {\n",
-    "#     \"WANDB_API_KEY\": \"<your_wandb_api_key>\",\n",
-    "#     \"WANDB_PROJECT\": \"your_project_name\",\n",
-    "#     \"WANDB_ENTITY\": \"your_entity_name\",\n",
-    "#     # Add more environment variables as needed\n",
-    "# }\n",
-    "volumes = {\n",
-    "    # \"<where to save model>\": \"/model\",\n",
-    "    \"../llama/alpaca_data\": \"/data\",  # change this to your data path if you want\n",
-    "}\n",
-    "gpus = \"all\"\n",
-    "\n",
-    "pull_docker_image(image)\n",
-    "\n",
-    "run_docker_container(image, port_mapping, env_vars, gpus, volumes)"
-   ]
+    "# Finetuned the model\n",
+    "model.finetune(dataset=instruction_dataset)"
+   ],
+   "metadata": {
+    "collapsed": false
+   }
   },
   {
    "cell_type": "markdown",
-   "metadata": {},
    "source": [
-    "## Alternately, you can run the example using CLI command:\n",
-    "\n",
-    "```bash\n",
-    "docker run -p 5000:5000 --gpus all -e WANDB_MODE=dryrun -e MICRO_BATCH_SIZE=1 -v /absolute/path/to/alpaca/data:/data public.ecr.aws/t8g5g2q5/xturing:int4_finetuning\n",
-    "```"
-   ]
+    "## 4. Generate an output text with the fine-tuned model"
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "# Once the model has been finetuned, you can start doing inferences\n",
+    "output = model.generate(texts=[\"Why LLM models are becoming so important?\"])\n",
+    "print(\"Generated output by the model: {}\".format(output))"
+   ],
+   "metadata": {
+    "collapsed": false
+   }
   }
  ],
 "metadata": {
```
pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "xturing"
3-
version = "0.0.10"
3+
version = "0.1.0"
44
description = "Fine-tuning, evaluation and data generation for LLMs"
55

66
authors = [
@@ -64,6 +64,10 @@ dependencies = [
6464
[project.scripts]
6565
xturing = "xturing.cli:xturing"
6666

67+
[project.optional-dependencies]
68+
int4 = [
69+
"torch >= 2.0"
70+
]
6771

6872
[project.urls]
6973
homepage = "https://xturing.stochastic.ai/"

src/xturing/__about__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -1 +1 @@
-__version__ = "0.0.10"
+__version__ = "0.1.0"
```

src/xturing/config/finetuning_config.yaml

Lines changed: 7 additions & 0 deletions

```diff
@@ -33,6 +33,13 @@ llama_lora_int8:
   batch_size: 8
   max_length: 256
 
+llama_lora_int4:
+  learning_rate: 1e-4
+  weight_decay: 0.01
+  num_train_epochs: 3
+  batch_size: 8
+  max_length: 256
+
 gptj:
   learning_rate: 5e-5
   weight_decay: 0.01
```

src/xturing/config/generation_config.yaml

Lines changed: 7 additions & 0 deletions

```diff
@@ -27,6 +27,13 @@ llama_lora_int8:
   max_new_tokens: 256
   do_sample: false
 
+# Contrastive search
+llama_lora_int4:
+  penalty_alpha: 0.6
+  top_k: 4
+  max_new_tokens: 256
+  do_sample: false
+
 # Contrastive search
 gptj:
   penalty_alpha: 0.6
```
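The `penalty_alpha`/`top_k` pair enables contrastive-search decoding in the underlying Hugging Face `generate` call. As a hedged illustration of what these settings correspond to when used directly with the standard `transformers` API (the model name below is a placeholder; this is not xTuring's internal code path):

```python
# Sketch: the llama_lora_int4 generation settings expressed as a plain
# Hugging Face contrastive-search call. Model name is illustrative only.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
model = AutoModelForCausalLM.from_pretrained("decapoda-research/llama-7b-hf")

inputs = tokenizer("Why LLM models are becoming so important?", return_tensors="pt")
outputs = model.generate(
    **inputs,
    penalty_alpha=0.6,   # degeneration penalty used by contrastive search
    top_k=4,             # candidate tokens considered at each step
    max_new_tokens=256,
    do_sample=False,     # contrastive search is deterministic
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```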

src/xturing/engines/__init__.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -25,6 +25,7 @@
     LLamaInt8Engine,
     LlamaLoraEngine,
     LlamaLoraInt8Engine,
+    LlamaLoraInt4Engine,
 )
 from .opt_engine import OPTEngine, OPTInt8Engine, OPTLoraEngine, OPTLoraInt8Engine
 
@@ -42,6 +43,7 @@
 BaseEngine.add_to_registry(LlamaLoraEngine.config_name, LlamaLoraEngine)
 BaseEngine.add_to_registry(LLamaInt8Engine.config_name, LLamaInt8Engine)
 BaseEngine.add_to_registry(LlamaLoraInt8Engine.config_name, LlamaLoraInt8Engine)
+BaseEngine.add_to_registry(LlamaLoraInt4Engine.config_name, LlamaLoraInt4Engine)
 BaseEngine.add_to_registry(GalacticaEngine.config_name, GalacticaEngine)
 BaseEngine.add_to_registry(GalacticaInt8Engine.config_name, GalacticaInt8Engine)
 BaseEngine.add_to_registry(GalacticaLoraEngine.config_name, GalacticaLoraEngine)
```
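For context, `add_to_registry` follows a simple name-to-class registry pattern, so registering `LlamaLoraInt4Engine` under its `config_name` is what lets the new INT4 model resolve to the new engine. A minimal sketch of the idea (simplified; not xTuring's exact implementation):

```python
# Simplified registry sketch -- illustrative, not xTuring's actual BaseEngine.
class BaseEngine:
    registry = {}

    @classmethod
    def add_to_registry(cls, config_name, engine_cls):
        # Map a config name (e.g. "llama_lora_int4_engine") to its engine class
        cls.registry[config_name] = engine_cls

    @classmethod
    def create(cls, config_name, *args, **kwargs):
        # Look up the registered class and instantiate it
        return cls.registry[config_name](*args, **kwargs)
```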

src/xturing/engines/causal.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -146,7 +146,7 @@ def __init__(
 
         lora_config = LoraConfig(
             r=8,
-            lora_alpha=32,
+            lora_alpha=16,
             target_modules=target_modules,
             lora_dropout=0.05,
             bias="none",
```
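Since a LoRA update is scaled by `lora_alpha / r`, lowering `lora_alpha` from 32 to 16 with `r=8` halves the effective adapter scaling from 4 to 2. A hedged sketch of the resulting configuration using the `peft` library (arguments mirror the diff; the `target_modules` shown here is the pair used by the LLaMA engines elsewhere in this commit):

```python
from peft import LoraConfig

# Effective LoRA scaling is lora_alpha / r: 16 / 8 = 2 (previously 32 / 8 = 4)
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],  # attention projections targeted by the LLaMA engines
    lora_dropout=0.05,
    bias="none",
)
```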

src/xturing/engines/llama_engine.py

Lines changed: 85 additions & 1 deletion

```diff
@@ -1,13 +1,16 @@
 import os
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union
+import transformers
 
 import torch
+from torch import nn
 
 from xturing.engines.causal import CausalEngine, CausalLoraEngine
 from xturing.engines.llama_utils import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
 from xturing.engines.lora_engine import prepare_model_for_int8_training
-
+from xturing.engines.quant_utils import make_quant, autotune_warmup
+from xturing.utils.hub import ModelHub
 
 class LLamaEngine(CausalEngine):
     config_name: str = "llama_engine"
@@ -98,3 +101,84 @@ def __init__(self, weights_path: Optional[Union[str, Path]] = None):
             load_8bit=True,
             target_modules=["q_proj", "v_proj"],
         )
+
+def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
+    if type(module) in layers:
+        return {name: module}
+    res = {}
+    for name1, child in module.named_children():
+        res.update(find_layers(
+            child, layers=layers, name=name + '.' + name1 if name != '' else name1
+        ))
+    return res
+
+class LlamaLoraInt4Engine(CausalLoraEngine):
+    config_name: str = "llama_lora_int4_engine"
+
+    def __init__(self, weights_path: Optional[Union[str, Path]] = None):
+        model_name = "decapoda-research/llama-7b-hf"
+
+        if weights_path is None:
+            weights_path = ModelHub().load("x/llama_lora_int4")
+
+        config = LlamaConfig.from_pretrained(model_name)
+
+        saved_kaiming_uniform_ = torch.nn.init.kaiming_uniform_
+        saved_uniform_ = torch.nn.init.uniform_
+        saved_normal_ = torch.nn.init.normal_
+
+        def noop(*args, **kwargs):
+            pass
+
+        torch.nn.init.kaiming_uniform_ = noop
+        torch.nn.init.uniform_ = noop
+        torch.nn.init.normal_ = noop
+
+        torch.set_default_dtype(torch.half)
+        transformers.modeling_utils._init_weights = False
+        torch.set_default_dtype(torch.half)
+        model = LlamaForCausalLM(config)
+        torch.set_default_dtype(torch.float)
+        model = model.eval()
+
+        layers = find_layers(model)
+
+        for name in ['lm_head']:
+            if name in layers:
+                del layers[name]
+
+        wbits = 4
+        groupsize = 128
+        warmup_autotune = True
+
+        make_quant(model, layers, wbits, groupsize)
+
+
+        model.load_state_dict(torch.load(weights_path / Path("pytorch_model.bin")), strict=False)
+
+        if warmup_autotune:
+            autotune_warmup(model)
+
+        model.seqlen = 2048
+
+        model.gptq = True
+
+        model.gradient_checkpointing_enable()
+        model.enable_input_require_grads()
+
+        tokenizer = LlamaTokenizer.from_pretrained(model_name, add_bos_token=False)
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+
+        super().__init__(
+            model=model,
+            tokenizer=tokenizer,
+            target_modules=[
+                "q_proj",
+                "v_proj",
+            ]
+        )
+
+        torch.nn.init.kaiming_uniform_ = saved_kaiming_uniform_
+        torch.nn.init.uniform_ = saved_uniform_
+        torch.nn.init.normal_ = saved_normal_
```
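Two implementation details worth noting: the temporary no-op overrides of `torch.nn.init` skip random weight initialization, since the real INT4 weights are loaded from `pytorch_model.bin` immediately afterwards, and `find_layers` recursively collects every `nn.Linear`/`nn.Conv2d` module keyed by its dotted path, which is why `lm_head` can be dropped by name before `make_quant` is applied. A small self-contained illustration of `find_layers` on a toy model (assuming the helper is importable from the module above):

```python
import torch.nn as nn

from xturing.engines.llama_engine import find_layers  # helper added in this commit

# Toy model: only the Linear modules are returned, keyed by dotted module path
toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
print(find_layers(toy))
# -> {'0': Linear(in_features=4, out_features=8, ...), '2': Linear(in_features=8, out_features=2, ...)}
```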
