-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
174 lines (142 loc) · 6.65 KB
/
train.py
File metadata and controls
174 lines (142 loc) · 6.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
"""
JPLoRA 训练脚本
用 LoRA 微调 Qwen2.5,让模型成为明略可信 AI 专家
LoRA 核心原理:
原始权重矩阵 W 冻结,训练两个低秩矩阵 A 和 B
实际更新量 = A @ B,rank << 原始维度
好处:参数量大幅减少(<1%),训练快,不破坏原始能力
"""
import json
import os
import sys
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainerCallback, TrainerState, TrainerControl
from peft import LoraConfig, get_peft_model, TaskType
from trl import SFTTrainer, SFTConfig
from datasets import Dataset
# ─── Configuration ───────────────────────────────────────────────────────────
MODEL_ID = "models/Qwen2.5-7B-Instruct"
OUTPUT_DIR = "output/lora-7b"
DATA_PATH = "data/train.json"
# LoRA hyper-parameters
LORA_RANK = 32       # rank of the low-rank update matrices; higher = more capacity, more params
LORA_ALPHA = 64      # scaling factor (conventionally 2 × rank)
LORA_DROPOUT = 0.05
# Training hyper-parameters
NUM_EPOCHS = 5
BATCH_SIZE = 1       # NOTE(review): original comment said "0.5B model, batch=1" but MODEL_ID is the 7B checkpoint — batch=1 presumably kept for memory; confirm
GRAD_ACCUM = 8       # effective batch size = BATCH_SIZE × GRAD_ACCUM = 8
LR = 2e-4
MAX_SEQ_LEN = 512    # the QA samples are short, so 512 tokens is enough
# ─── Dashboard logging callback ──────────────────────────────────────────────
class JsonlLogCallback(TrainerCallback):
    """Append training metrics to output/train_log.jsonl every `logging_steps` steps.

    A dashboard process reads this JSONL file to render live curves.
    """

    LOG_FILE = "output/train_log.jsonl"

    def on_train_begin(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        # Truncate any previous run's log so the new run starts from a clean file.
        os.makedirs("output", exist_ok=True)
        with open(self.LOG_FILE, "w"):
            pass

    def on_log(self, args, state: TrainerState, control: TrainerControl, logs=None, **kwargs):
        if logs is None:
            return
        entry = {"step": state.global_step, "epoch": round(state.epoch or 0, 3)}
        entry["loss"] = logs.get("loss")
        entry["lr"] = logs.get("learning_rate")
        # trl 0.29+ reports token-level accuracy automatically.
        entry["accuracy"] = logs.get("mean_token_accuracy")
        with open(self.LOG_FILE, "a", encoding="utf-8") as fp:
            fp.write(json.dumps(entry, ensure_ascii=False) + "\n")
# ─── Device selection ────────────────────────────────────────────────────────
# Prefer Apple-silicon GPU (MPS) when available; otherwise fall back to CPU.
if torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
print(f"使用设备:{DEVICE}")
# ─── Data loading ────────────────────────────────────────────────────────────
def load_data(path: str) -> Dataset:
    """Read a JSON array of QA records from *path* and wrap it in a `datasets.Dataset`."""
    with open(path, encoding="utf-8") as fp:
        records = json.load(fp)
    print(f"加载训练数据:{len(records)} 条")
    dataset = Dataset.from_list(records)
    return dataset
def format_sample(sample: dict, tokenizer) -> str:
    """Render one QA record through the Qwen2.5 chat template.

    Output shape: <|im_start|>system\n...<|im_end|>\n<|im_start|>user\n
    question<|im_end|>\n<|im_start|>assistant\nanswer<|im_end|>
    """
    system_msg = "你是明略科技的可信 AI 专家,熟悉明略的使命、愿景、产品和可信 AI 理念。请用准确、专业、清晰的中文回答问题。"
    conversation = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": sample["instruction"]},
        {"role": "assistant", "content": sample["output"]},
    ]
    # tokenize=False keeps the result as plain text; the trainer tokenizes later.
    return tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=False
    )
# ─── Main training logic ─────────────────────────────────────────────────────
def _load_model_and_tokenizer():
    """Load the base model and tokenizer from MODEL_ID.

    Loads in float32 (MPS has no reliable fp16/bf16 path) and enables input
    gradients so PEFT adapters can backpropagate through frozen embeddings.
    """
    print(f"\n加载模型:{MODEL_ID}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Qwen tokenizers may ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,  # MPS requires float32 here
        trust_remote_code=True,
    )
    model.enable_input_require_grads()
    n_params = sum(p.numel() for p in model.parameters())
    print(f"原始模型参数量:{n_params / 1e6:.1f}M")
    return model, tokenizer


def _apply_lora(model):
    """Wrap *model* with LoRA adapters on Qwen2.5's attention and MLP projections."""
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=LORA_RANK,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        # Qwen2.5 attention + MLP projection layers
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        bias="none",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    return model


def _build_dataset(tokenizer) -> Dataset:
    """Load raw QA records and render each through the chat template into a "text" column."""
    raw_dataset = load_data(DATA_PATH)
    formatted_texts = [format_sample(s, tokenizer) for s in raw_dataset]
    return Dataset.from_dict({"text": formatted_texts})


def _training_config() -> SFTConfig:
    """Build SFT hyper-parameters (trl 0.9+ uses SFTConfig instead of TrainingArguments)."""
    return SFTConfig(
        output_dir=OUTPUT_DIR,
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUM,
        learning_rate=LR,
        lr_scheduler_type="cosine",
        warmup_steps=10,
        logging_steps=5,
        save_steps=50,
        save_total_limit=2,
        fp16=False,  # fp16 is unsupported on MPS
        bf16=False,
        dataloader_pin_memory=False,
        report_to="none",
        optim="adamw_torch",
        dataset_text_field="text",
        max_length=MAX_SEQ_LEN,
    )


def train():
    """End-to-end LoRA fine-tuning: load model → attach LoRA → format data → train → save adapters."""
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    model, tokenizer = _load_model_and_tokenizer()
    model = _apply_lora(model)
    train_dataset = _build_dataset(tokenizer)

    trainer = SFTTrainer(
        model=model,
        args=_training_config(),
        train_dataset=train_dataset,
        processing_class=tokenizer,
        callbacks=[JsonlLogCallback()],
    )
    print(f"\n开始训练(epochs={NUM_EPOCHS},有效 batch={BATCH_SIZE * GRAD_ACCUM})...")
    trainer.train()

    # Save only the (small) LoRA adapter weights plus the tokenizer for inference.
    lora_save_path = os.path.join(OUTPUT_DIR, "lora_weights")
    model.save_pretrained(lora_save_path)
    tokenizer.save_pretrained(lora_save_path)
    print(f"\nLoRA 权重保存到:{lora_save_path}")
    print("训练完成!运行 python eval.py 评估效果。")


if __name__ == "__main__":
    train()