Skip to content

Commit 779ccf2

Browse files
authored
[BREAKING] Refactor Scheduler and GRPOTrainer for Flexible Multi-Turn Training (#5307)
* wip * wip * revert prompt_ids * remove tokenizer in reward * encode ids * wip replace ids * fix adv * wip * wip * wip * wip * wip * wip * refactor v1 * rename completion id * fix typo & bugs * compute loss for dynamic batch size * fix tiny bugs * dynamic rollout advantages * fix score_completions * fix split mini batch * docstring for split mini batches * fix gather device * fix rollout async infer * thinking tips scheduler * resolve dynamic sampling" * wip for chunk loss * version2 * fix gemini * batch metrics * fix merge ouput * remove tests * fix server rollout & same prompt bewteen process * fix * merge main * update * revert chmod * update docs * revert make docs * update images * update images * global inputs for reward model * pass loss scale * tool call scheduler * move toolcall scheduler to external plugin * update deepeyes * update toolcall scheduler * update deepeyes script * fix script * use safer ast literal_eval * compatible with sp * fix advantages & sort outputs * fix sp * get trajectory inputs * lint * multi turn reward example * update multi turn docs * restrict rollout async engine * update docs * check dynamic num and simplify logic for normal training * flag dynamic num samples * fix docstring typo * fix chunked inputs * log last turn metrics * last turn metrics * more profiling * fix multi turn script * fix docstring * fix engine * fix log completion * exp link for script * log num_turns * fix args * align num of device of script to exp
1 parent 3497f86 commit 779ccf2

File tree

38 files changed

+3075
-1199
lines changed

38 files changed

+3075
-1199
lines changed

docs/resources/grpo_multi_turn.png

1.45 MB
Loading

docs/source/Instruction/GRPO/DeveloperGuide/GYM环境训练.md

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# GYM环境训练
22

3-
注意:该 feature 需要使用 ms-swift>=3.7 且目前仅支持纯文本模型
3+
**注意** GYM环境训练逻辑已在 ms-swift 3.8 中进行重构,如果您的 ms-swift 版本低于该版本,请参考对应版本的文档。
44

55
## Gym接口
66

@@ -105,12 +105,17 @@ RolloutResponseChoice(
105105
messages=None)
106106
"""
107107
```
108-
`rollout` 命令中使用参数 `use_gym_env` 来指定使用gym作为训练的环境接口
108+
GYM环境训练可以视作一种特殊的多轮训练,区别在于使用GYM环境训练,奖励信息通过环境直接获取。
109+
110+
`rollout` 命令中使用参数 `use_gym_env` 来指定使用gym作为训练的环境接口。我们提供了兼容GYM环境的多轮规划器参考实现,见[内置多轮调度器实现](https://github.com/modelscope/ms-swift/blob/main/swift/plugin/multi_turn.py)中的 GymScheduler 类
111+
112+
109113
```bash
110114
CUDA_VISIBLE_DEVICES=0 \
111115
swift rollout \
112116
--model xxx \
113117
--use_gym_env true \
118+
--multi_turn_scheduler gym_scheduler \
114119
--max_turns xxx
115120
```
116121

@@ -133,14 +138,11 @@ swift rollout \
133138
```json
134139
{"messages": [{"role": "system", "content": "你是个有用无害的助手"}, {"role": "user", "content": "告诉我明天的天气"}],"env_config":{"name":"custom_env","other_config":"xxxx"},"ctx_config":{"name":"custom_ctx","other_config":"xxxx"}}
135140
```
136-
2. gym 环境目前仅兼容纯文本模型和 AsyncEngine
141+
2. 默认仅对最后一轮response进行训练,如果gym涉及到多轮response生成,使用参数`--loss_scale default`对所有轮次的response进行训练,具体参考[文档](./多轮训练.md#损失掩码)
137142

138-
3. 默认仅对最后一轮response进行训练,如果gym涉及到多轮response生成,使用参数`--loss_scale default`对所有轮次的response进行训练,具体参考[文档](./多轮训练.md#损失掩码)
139-
140-
4. 数据流程
143+
3. 数据流程
141144
整个gym数据流程如下:
142145
<img src="../../../../resources/gym_env.png" width="400" />
143146

144-
145-
5. 奖励日志
147+
4. 奖励日志
146148
由于gym的奖励是在step函数内计算完成,所以需要手动通过`info`返回日志,最终的记录会放在completions.jsonl中的`trajectory_infos`字段.
Lines changed: 122 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -1,79 +1,81 @@
11
# 多轮训练
22

3-
注意:该 feature 需要使用 ms-swift>=3.6
3+
**注意** 多轮训练逻辑已在 ms-swift 3.8 中进行重构,如果您的 ms-swift 版本低于该版本,请参考对应版本的文档。
44

5-
在强化学习训练场景中,模型采样可能需要与环境进行多轮交互(如工具调用、外部API访问等)。这种交互式训练要求模型能够根据环境反馈信息进行连续推理。本文档将详细介绍如何在 GRPO 训练中自定义多轮训练流程。
5+
在强化学习训练场景中,模型采样可能需要与环境进行多轮交互(如工具调用)。这种交互式训练要求模型能够根据环境反馈信息进行连续推理。本文档将详细介绍如何在 GRPO 训练中自定义多轮训练流程。
66

7+
以下是多轮训练示例图,模型可能涉及多轮 rollout,包括环境交互、工具调用等步骤:
78

8-
根据环境反馈插入方式不同,多轮可以分为:
9-
10-
- 新一轮推理:环境反馈结果作为 query,模型进行新一轮对话轮次进行响应
11-
- 当轮续写:环境反馈结果插入模型当前回复中,模型在此基础上继续续写后续内容
12-
13-
14-
我们可以自定义并通过参数 `multi_turn_scheduler` 设置多轮采样的规划器来实现多轮采样逻辑
15-
```
16-
--multi_turn_scheduler xxx
17-
--max_turns xxx
18-
```
19-
两种方式的实现例子可以参考[最佳实践](#最佳实践)
9+
![多轮示例图](../../../../resources/grpo_multi_turn.png)
2010

2111
## 多轮规划器 MultiTurnScheduler
22-
多轮规划器是多轮训练的核心组件,其工作流程如下图所示:
2312

13+
`MultiTurnScheduler` 是一个抽象基类,提供了默认的多轮对话管理逻辑,其工作流程如下图所示:
2414

2515
<img src="https://raw.githubusercontent.com/modelscope/ms-swift/main/docs/resources/multiturn_pipeline.png " width="300" />
2616

17+
多轮规划器主要承担两大核心功能:
18+
- **终止条件判断**:通过 `check_finished` 方法判断当前轮次推理是否应该结束
19+
- **推理请求构造**:通过 `step` 方法构建下一轮推理的请求对象
2720

28-
多轮规划器主要承担两大功能:
29-
- 终止条件判断:通过 check_finished 方法判断当前轮次推理是否应该结束
30-
- 推理请求构造:通过 step 方法构建下一轮推理的请求对象
21+
抽象基类 `MultiTurnScheduler` 的核心方法如下:
3122

32-
抽象基类 MultiTurnScheduler 代码如下
3323
```python
3424
class MultiTurnScheduler(ABC):
3525

3626
def __init__(self, max_turns: Optional[int] = None, *args, **kwargs):
3727
self.max_turns = max_turns
3828

39-
@abstractmethod
40-
def step(self, infer_request: 'RolloutInferRequest', result: 'RolloutResponseChoice',
41-
current_turn: int) -> Union['RolloutInferRequest', Tuple['RolloutInferRequest', Dict]]:
42-
pass
43-
44-
def check_finished(self, infer_request: 'RolloutInferRequest', result: 'RolloutResponseChoice',
29+
def step(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompletionResponseChoice',
30+
current_turn: int) -> Dict:
31+
"""
32+
处理对话轮次之间的转换。
33+
34+
Args:
35+
infer_request: 当前推理请求
36+
response_choice: 当前轮次的响应
37+
current_turn: 当前轮次数
38+
39+
Returns:
40+
Dict[str, Any]: 包含推理结果的字典,结构如下:
41+
- infer_request (必需): 下一轮的推理请求对象
42+
- response_token_ids (可选): 每个 rollout 轮次的响应 token IDs
43+
- response_loss_mask (可选): 每个 rollout 轮次响应的损失掩码
44+
- rollout_infos (可选): 额外信息数据
45+
"""
46+
raise NotImplementedError
47+
48+
def check_finished(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompletionResponseChoice',
4549
current_turn: int) -> bool:
46-
if result.finish_reason == 'length':
50+
"""
51+
检查多轮 rollout 是否应该结束的默认终止逻辑。
52+
53+
默认终止条件:
54+
1. 当响应达到长度限制时 (finish_reason == 'length')
55+
2. 当对话达到最大轮数时 (如果设置了 max_turns)
56+
57+
Args:
58+
infer_request: 推理请求对象
59+
response_choice: 包含生成结果的响应选择,包括 finish_reason
60+
current_turn: 当前对话轮数
61+
62+
Returns:
63+
bool: True 表示终止对话,False 表示继续
64+
"""
65+
if response_choice.finish_reason == 'length':
4766
return True
4867
if self.max_turns and current_turn >= self.max_turns:
4968
return True
5069
return False
5170
```
5271

53-
> 如果你想要奖励函数获取多轮交互中的信息,请在 step 方法中返回额外的 dict 对象, 在奖励函数中的 kwargs中,获取 `multi_turn_infos`
72+
`step``check_finished` 方法接收的参数说明:
73+
- **infer_request**: 当前的推理请求
74+
- **response_choice**: 当前轮次的推理结果
75+
- **current_turn**: 当前推理轮次(从 1 开始)
5476

55-
```python
56-
class Scheduler():
57-
def step(self, infer_request: 'RolloutInferRequest', result: 'RolloutResponseChoice',
58-
current_turn: int) -> Union['RolloutInferRequest', Tuple['RolloutInferRequest', Dict]]:
59-
...
60-
return infer_request, extra_dict
77+
<details><summary>入参示例(点击展开)</summary>
6178

62-
class RewardFunction():
63-
def __call__(self, completions, **kwargs):
64-
infos = kwargs.get('multi_turn_infos', {})
65-
...
66-
```
67-
68-
step 和 check_finished 方法接收参数:
69-
- infer_request: 上轮的推理请求,包括
70-
- `messages` 键包含了模型的交互历史(注意:已包括当前模型推理结果)
71-
- 多模态信息,如 `images`
72-
- `data_dict` 包含了数据集中的其他列
73-
- result: 上轮的推理结果,
74-
- current_turn: 当前推理轮次 (从1开始)
75-
76-
入参示例
7779
```python
7880
infer_request
7981
"""
@@ -93,9 +95,9 @@ RolloutInferRequest(
9395
}
9496
)
9597
"""
96-
result
98+
response_choice
9799
"""
98-
RolloutResponseChoice(
100+
ChatCompletionResponseChoice(
99101
index=0,
100102
message=ChatMessage(
101103
role='assistant',
@@ -104,78 +106,110 @@ RolloutResponseChoice(
104106
logprobs=None,
105107
messages=None)
106108
"""
107-
# result.messages will be copied at the end of multi-turn inference.
109+
# response_choice.messages will be copied at the end of multi-turn inference.
108110
```
111+
</details>
109112

110-
默认的 `check_finished` 逻辑会在两种情况下停止推理
113+
<br>
114+
<br>
111115

116+
默认的 `check_finished` 逻辑会在以下两种情况下停止推理:
112117
- 模型回复被截断,即超出了 `max_completion_length`
113118
- 模型推理轮数超出了限制的最大轮数
114119

120+
完整的默认多轮 rollout 逻辑请参考该类的 `run` 方法,我们也可以通过重载`run` 方法来实现自定义多轮逻辑。
115121

116-
推荐使用 AsyncEngine 来实现高效的批量数据异步多轮采样(只支持 external server mode),AsyncEngine 在多轮推理时能够减小推理过程中的计算气泡(如图)
117-
118-
<img src="https://raw.githubusercontent.com/modelscope/ms-swift/main/docs/resources/asyncengine.png" width="400" />
122+
## 设置多轮训练参数
119123

124+
在 swift rollout 命令中,设置 multi_turn_scheduler 参数指定规划器
120125

121-
`rollout` 命令中使用参数 `use_async_engine` 来指定engine的种类
122126
```bash
123-
CUDA_VISIBLE_DEVICES=0 \
124127
swift rollout \
125-
--model xxx \
128+
--model Qwen/Qwen3-1.7B \
126129
--use_async_engine true \
127-
--multi_turn_scheduler xxx \
128-
--max_turns xxx
130+
--multi_turn_scheduler thinking_tips_scheduler \
131+
--vllm_max_model_len 32768 \
132+
--vllm_gpu_memory_utilization 0.8 \
133+
--max_turns 3
129134
```
130135

131-
通过参数`external_plugins`, 我们可以将本地的多轮规划器注册进 ms-swift 中,具体实现参考[代码](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/plugin/plugin.py)
132136

133-
多轮训练脚本参考
137+
> 通过参数 `external_plugins`,我们可以将本地的多轮规划器注册到 ms-swift 中,具体实现请参考[代码](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/plugin/plugin.py)
134138
135-
- [server mode](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/external/vllm_multi_turn.sh)
136-
- [colocate mode](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/vllm_multi_turn.sh)
139+
多轮训练脚本请参考[脚本](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/external/vllm_multi_turn.sh)
137140

138141

139-
## 最佳实践
140-
[插件代码示例](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/plugin/plugin.py)中提供了两种多轮规划器的例子,实现在数学问题中提示模型再次思考并给出答案,分别对应两种多轮推理:
142+
对于多轮 rollout,我们使用 AsyncEngine 来实现高效的批量数据异步多轮采样。AsyncEngine 在多轮推理时能够减少推理过程中的计算气泡:
141143

142-
- 第一种方式(新一轮推理):新插入一轮对话,提示模型的答案错误,需要重新思考(math_tip_trick_multi_turn)
143-
- 第二种方式(续写):回溯到模型的思考阶段,并加入思考错误的提示 (math_tip_trick)
144+
<img src="https://raw.githubusercontent.com/modelscope/ms-swift/main/docs/resources/asyncengine.png" width="400" />
144145

146+
`rollout` 命令中使用参数 `use_async_engine` 来指定 engine 的种类(默认使用 async engine):
145147

146-
## 注意事项
147148

148-
### 奖励函数
149-
注意在奖励函数中,接受的 `completions` 参数为最后一轮模型回复,如果奖励函数需要根据模型多轮回复计算奖励,需要获取 `messages` 键来获取完整的多轮对话记录
149+
## 高级设置
150150

151-
```python
152-
class Reward(ORM):
151+
### 自定义多轮交互逻辑
152+
在以上默认逻辑中,我们用一条轨迹来计算多轮 rollout 的损失,这里需要假设多轮交互的过程中,模型的历史信息没有收到改变。
153153

154-
def __call__(completions, **kwargs):
155-
print(kwargs.keys())
156-
# dict_keys(['problem', 'solution', 'messages', 'is_truncated'])
157-
messages = kwargs.get('messages')
158-
...
159-
```
154+
而在一些多轮场景中,我们可以需要在多轮 rollout 过程中动态地修改模型的历史信息(比如压缩历史信息),此时,我们需要将每轮的 rollout 单独作为一条轨迹进行训练。
160155

156+
比较常见的一种场景是对于思考类模型,在实际推理过程中,模型通常只会保留最后一轮的思考内容,而忽略历史模型回复中的思考内容。
157+
158+
对于这类场景,我们需要重写多轮规划器中的交互逻辑,即重载 `run` 方法,从而单独返回每一轮的 Rollout 的结果。
159+
160+
框架内置的 `ThinkingModelTipsScheduler` 类展示了如何通过重写 `run()` 方法来实现完全自定义的多轮推理逻辑。请参考[内置多轮调度器实现](https://github.com/modelscope/ms-swift/blob/main/swift/plugin/multi_turn.py)
161+
162+
**注意**: 这种情况下,相同轨迹的数据会拆分为多条数据,在奖励相关的处理中,需要对相同轨迹的数据分配同样的reward。
163+
164+
可以在kwargs中获取 trajectory_inputs 获取完整轨迹的数据,具体实现参考[MultiTurnThinkingTips类](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/plugin/plugin.py)
165+
166+
### 返回 response token ids
167+
在默认的多轮交互流程中,规划器先把模型生成的文本字符串返回给 trainer,trainer 再将其重新 encode 为 token id,用于后续训练。为了避免这一步重复编码的开销,你可以让规划器直接返回 response_token_ids,省去 trainer 侧的再次 encode。
168+
169+
具体做法如下:
170+
171+
- 在 response_choice 对象中读取 token_ids 属性,即可获得本次 rollout 生成的 token 序列。
172+
- 在 step/run 方法的返回值里加入 response_token_ids,trainer 便能直接使用这些 token id 参与训练,无需重新编码。
173+
174+
具体实现可以参考[ThinkingModelTipsScheduler](https://github.com/modelscope/ms-swift/blob/main/swift/plugin/multi_turn.py)
161175

162176
### 损失掩码
163177

164178
在工具调用或环境交互返回结果时,若需将返回内容作为模型响应的一部分,建议对这些插入内容进行掩码处理,以确保模型在训练过程中不会对这些外部生成的内容计算损失。
165179

166-
这里需要通过设置参数 loss_scale ,实现自定义掩码逻辑,具体参考[定制化loss_scale文档](../../../Customization/插件化.md#定制化loss_scale)
180+
我们可以通过两种方式设置损失掩码
181+
182+
**第一种:设置 loss_scale**
183+
184+
ms-swift 提供 loss_scale 参数来对模型回复部分的内容进行损失缩放设置。比如设置`--loss_scale last_round`,可以将非最后一轮的模型回复的损失置零。我们也可以实现自定义 loss_scale,具体请参考[定制化 loss_scale 文档](../../../Customization/插件化.md#定制化loss_scale)
185+
186+
> 注:在GRPO中,loss_scale 只提供掩码功能,不提供缩放功能。
167187
168-
默认 loss_scale 值:
188+
**第二种:设置loss_mask**
169189

170-
grpo训练(即设置`multi_turn_scheduler`),loss_scale 默认为`default`,即对 messages 中的 每一轮 response 进行训练
171-
> 如果数据集中本身包含 assistant response 也会被计算入内,如果想要排除数据集中的response , 需要自定义 loss_scale
190+
`step`或者`run`方法中设置 response_loss_mask, 可以在规划器中自定义损失掩码。前提需要返回response token ids,返回的 response_loss_mask 需要与 response token ids等长。当返回 response_loss_mask 时,loss_scale 参数失效。
172191

173-
如果只想只计算最后一轮 response(rollout结果)损失,请修改为`last_round`
192+
response_loss_mask 返回可以参考[ToolCallScheduler类](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/plugin/plugin.py)
174193

194+
### 奖励函数相关
175195

176-
注意 loss_scale 可以用于
196+
在奖励函数中获取多轮 Rollout 中的信息
197+
198+
`step`或者`run`方法中,返回 `rollout_infos` 对象,在奖励函数的 kwargs 中获取 `rollout_infos`
199+
200+
```python
201+
class Scheduler():
202+
def step(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompletionResponseChoice',
203+
current_turn: int) -> Dict:
204+
...
205+
return {'infer_request': infer_request, 'rollout_infos': extra_dict}
206+
207+
class RewardFunction():
208+
def __call__(self, completions, **kwargs):
209+
infos = kwargs.get('rollout_infos', {})
210+
...
211+
```
177212

178-
1. 标注需要训练的 tokens (0为不训练)
179-
2. 放缩 tokens 的训练权重
213+
### 在 Scheduler 中获取额外的数据集信息
180214

181-
而 GRPO 中暂不支持 loss_scale 的权重设置
215+
在训练侧设置参数`--vllm_server_pass_dataset`,可将数据集中的其他列传入多轮规划器。在`infer_request.data_dict`中获取

docs/source/Instruction/GRPO/GetStarted/GRPO.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,13 @@ swift rollout \
234234
- {reward_func_name}:特定奖励
235235
- entropy:entropy token 均值,在设置`log_entropy`时记录
236236

237-
设置 `report_to wandb/swanlab` 将训练动态推送到对应的平台
237+
设置 `report_to wandb/swanlab` 将训练动态Table推送到对应的平台
238+
239+
如果需要在Table中额外记录其他列,请在 `GRPOTrainer._generate_and_score_completions` 方法中,设置 metrics_to_gather 字典。
240+
241+
默认自动检测
242+
- `image`:视觉数据集图像输入。(暂时只支持wandb)
243+
- `solution`:数据集中的 solution 列。
238244

239245
## FAQ
240246
**1. 训练过程中 loss 等于0 / 接近0 / 小于0**

docs/source/Instruction/命令行参数.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,7 @@ reward模型参数将在PPO、GRPO中使用。
511511
- vllm_server_host:vLLM server host地址,默认为None,使用外部vLLM server时使用。
512512
- vllm_server_port vLLM server 服务端口,默认为8000。
513513
- vllm_server_timeout 连接vLLM server的超时时间,默认为 240s。
514+
- vllm_server_pass_dataset: 透传额外的数据集信息到vLLM server,用于多轮训练。
514515
- async_generate: 异步rollout以提高训练速度,注意开启时采样会使用上一轮更新的模型进行采样,不支持多轮场景。默认`false`.
515516
- vllm_mode colocate 参数(更多参数支持参考[vLLM参数](#vLLM参数)。)
516517
- vllm_gpu_memory_utilization: vllm透传参数,默认为0.9。

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,7 @@ The meanings of the following parameters can be referenced [here](https://huggin
522522
- vllm_server_host: The host address of the vLLM server. Default is None. This is used when connecting to an external vLLM server.
523523
- vllm_server_port: The service port of the vLLM server. Default is 8000.
524524
- vllm_server_timeout: The connection timeout for the vLLM server. Default is 240 seconds.
525+
- vllm_server_pass_dataset: pass additional dataset information through to the vLLM server for multi-turn training.
525526
- async_generate: Use async rollout to improve train speed. Note that rollout will use the model updated in the previous round when enabled. Multi-turn scenarios are not supported. Default is `false`.
526527
- vllm_mode colocate parameter (For more parameter support, refer to the [vLLM Arguments](#vLLM-Arguments).)
527528
- vllm_gpu_memory_utilization: vLLM passthrough parameter, default is 0.9.

0 commit comments

Comments
 (0)