@@ -1,14 +1,12 @@
 from contextlib import nullcontext
-from typing import Any, Dict, Optional
+from typing import Any, Optional

 import ray
 import torch
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
-from coati.distributed.reward.verifiable_reward import VerifiableReward
-from coati.distributed.utils import calc_action_log_probs
+from coati.distributed.utils import memory_efficient_logprob
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -40,6 +38,8 @@ def __init__(
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
+        enable_profiling: bool = False,
+        n_behind: int = 0,
     ):
         print(f"Using GRPO config: {grpo_config}")
         if (
@@ -62,12 +62,15 @@ def __init__(
             batch_size,
             model_config,
             plugin_config,
+            generate_config,
             minibatch_size,
             save_interval=save_interval,
             save_dir=save_dir,
+            enable_profiling=enable_profiling,
+            n_behind=n_behind,
         )
-        path = model_config.pop("path")
-        self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
+        self.path = model_config.pop("path")
+        self.policy_model = AutoModelForCausalLM.from_pretrained(self.path, **model_config)
         self.policy_model.train()
         self.policy_model.gradient_checkpointing_enable()
         self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
@@ -95,12 +98,7 @@ def __init__(
             loss_variation=grpo_config.get("loss_variation", "sample_level"),
         )

-        # Reference model is initialized from policy model.
-        if self.policy_loss_fn.beta > 0:
-            self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
-            self.reference_model.eval()
-
-        self.tokenizer = AutoTokenizer.from_pretrained(path)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.path)
         self.pad_token_id = self.tokenizer.pad_token_id
         self.num_generations = num_generations
         self.filter_range = grpo_config.get("filter_range", None)
@@ -119,20 +117,7 @@ def __init__(
                 "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
             )
         # Initialize verifiable reward.
-        response_format_tags = grpo_config.get("response_format_tags", None)
-        reward_model_kwargs = {
-            k: v
-            for k, v in grpo_config.items()
-            if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
-        }
-        self.reward_model = VerifiableReward(
-            reward_fns=[
-                math_reward_fn if grpo_config.get("reward_fn_type") == "think_answer_tags" else boxed_math_reward_fn
-            ],
-            tokenizer=self.tokenizer,
-            tags=response_format_tags,
-            **reward_model_kwargs,
-        )
+        grpo_config.get("response_format_tags", None)
         self.global_step = 0

         self.lr_scheduler = CosineAnnealingWarmupLR(
@@ -158,7 +143,10 @@ def setup(self):
         self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
             self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
         )
+        # Reference model is initialized from policy model.
         if self.policy_loss_fn.beta > 0:
+            self.reference_model = AutoModelForCausalLM.from_pretrained(self.path, **self.model_config)
+            self.reference_model.eval()
             self.reference_model, *_ = self.booster.boost(self.reference_model)
         self.plugin.logger.set_level("ERROR")

@@ -295,12 +283,11 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                             )

                         if self.booster.plugin.stage_manager.is_last_stage():
-                            reference_model_logits = reference_model_outputs["outputs"]["logits"]
-                            reference_action_log_probs = calc_action_log_probs(
-                                reference_model_logits / self.generate_config["temperature"],
+                            reference_action_log_probs = memory_efficient_logprob(
+                                reference_model_outputs["outputs"]["logits"],
                                 input_ids_forward_micro_batch,
                                 num_action,
-                                self.plugin.shard_config,
+                                shard_config=self.plugin.shard_config,
                             )
                         else:
                             # Dummy reference logprobs for data iterator.
@@ -323,11 +310,11 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:

                     def _criterion(outputs, inputs):
                         action_logits = outputs.logits
-                        action_log_probs = calc_action_log_probs(
-                            action_logits / self.generate_config["temperature"],
+                        action_log_probs = memory_efficient_logprob(
+                            action_logits,
                             inputs["input_ids"],
                             num_action,
-                            self.plugin.shard_config,
+                            shard_config=self.plugin.shard_config,
                         )
                         if "reference_action_log_probs" in inputs:
                             per_token_kl = (
@@ -370,16 +357,15 @@ def _criterion(outputs, inputs):
                             mean_kl.append(kl)
                         mean_loss.append(all_reduce_mean(loss, self.plugin).data)
                 else:
-
                     policy_model_logits = self.policy_model(
                         input_ids=input_ids_forward_micro_batch,
                         attention_mask=attention_mask_forward_micro_batch,
                     ).logits
-                    action_log_probs = calc_action_log_probs(
+                    action_log_probs = memory_efficient_logprob(
                         policy_model_logits / self.generate_config["temperature"],
                         input_ids_forward_micro_batch,
                         num_action,
-                        self.plugin.shard_config,
+                        shard_config=self.plugin.shard_config,
                     )

                     if self.policy_loss_fn.beta > 0:
@@ -388,11 +374,11 @@ def _criterion(outputs, inputs):
                                 input_ids=input_ids_forward_micro_batch,
                                 attention_mask=attention_mask_forward_micro_batch,
                             ).logits
-                        reference_action_log_probs = calc_action_log_probs(
+                        reference_action_log_probs = memory_efficient_logprob(
                             reference_model_logits / self.generate_config["temperature"],
                             input_ids_forward_micro_batch,
                             num_action,
-                            self.plugin.shard_config,
+                            shard_config=self.plugin.shard_config,
                         )
                         per_token_kl = (
                             torch.exp(reference_action_log_probs - action_log_probs)
@@ -498,40 +484,6 @@ def _criterion(outputs, inputs):
         else:
             return None

-    def calculate_reward(self, rollout: Dict[str, Any]) -> Dict[str, Any]:
-        """
-        Calculate the group reward for the given rollout group.
-
-        Args:
-            rollout_group (Dict[str, Any]):
-                a group of samples generated by the model from the same prompt
-                contain the following keys:
-                "input_ids": torch.Tensor, [num_of_generation, prompt_length + response_length]
-                "attention_mask": torch.Tensor, [num_of_generation, prompt_length + response_length]
-                "action_mask": torch.Tensor, [num_of_generation, response_length]
-                "action_log_probs": torch.Tensor, [num_of_generation, response_length]
-                "response_idx": int, torch.Tensor, [num_of_generation, 2]
-                "gt_answer": torch.Tensor, [num_of_generation, 128]
-                "temperature": torch.Tensor, [] (scalar)
-
-        Returns:
-            Dict[str, Any]: The new group data with calculated reward.
-        """
-        reward_model_output = self.reward_model(
-            rollout["input_ids"],
-            gt_answer=rollout["gt_answer"],
-            response_idx=rollout["response_idx"],
-        )
-        # [num_of_generation]
-        reward = torch.tensor([value[0] for value in reward_model_output]).to(rollout["input_ids"].device)
-        format_acc = torch.tensor([value[1] for value in reward_model_output]).to(rollout["input_ids"].device)
-        ans_acc = torch.tensor([value[2] for value in reward_model_output]).to(rollout["input_ids"].device)
-
-        rollout["reward"] = reward.view((-1, 1))
-        rollout["format_acc"] = format_acc.view((-1, 1))
-        rollout["ans_acc"] = ans_acc.view((-1, 1))
-        return rollout
-
     def state_dict(self):
         self.policy_model._force_wait_all_gather()
         model = self.policy_model.unwrap()
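For readers who want to see the shape of the `memory_efficient_logprob` calls this patch introduces, below is a minimal sketch of one way such a helper can be written. It is an illustrative assumption, not the actual `coati.distributed.utils.memory_efficient_logprob`: only the call signature `(logits, input_ids, num_actions, shard_config=...)` is taken from the hunks above; the chunked log-softmax, the `chunk_size` knob, and the `_sketch` name are invented here, and tensor-parallel vocabulary sharding (presumably what `shard_config` drives in the real helper) is ignored.

```python
# Illustrative sketch only -- NOT the coati implementation.
import torch
import torch.nn.functional as F


def memory_efficient_logprob_sketch(
    logits: torch.Tensor,      # [B, S, V]; position t predicts token t + 1
    input_ids: torch.Tensor,   # [B, S]
    num_actions: int,          # length of the generated (action) segment
    shard_config=None,         # unused here; assumed to enable TP-aware logprobs in the real helper
    chunk_size: int = 1024,
) -> torch.Tensor:
    """Per-token log-probs of the last `num_actions` tokens, computed in chunks."""
    targets = input_ids[:, -num_actions:]                 # [B, A]
    action_logits = logits[:, -num_actions - 1 : -1, :]   # [B, A, V]

    out = []
    for start in range(0, num_actions, chunk_size):
        end = min(start + chunk_size, num_actions)
        # log_softmax over a slice of the sequence bounds peak memory instead of
        # materializing the full [B, A, V] float log-prob tensor at once.
        chunk_logp = F.log_softmax(action_logits[:, start:end, :].float(), dim=-1)
        out.append(
            torch.gather(chunk_logp, -1, targets[:, start:end].unsqueeze(-1)).squeeze(-1)
        )
    return torch.cat(out, dim=1)                          # [B, A]
```

A tensor of that shape drops straight into the `torch.exp(reference_action_log_probs - action_log_probs)` KL term kept unchanged by the hunks above.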