 from contextlib import nullcontext
-from typing import Any, Dict, Optional
+from typing import Any, Optional

 import ray
 import torch
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
-from coati.distributed.reward.verifiable_reward import VerifiableReward
-from coati.distributed.utils import calc_action_log_probs
+from coati.distributed.utils import memory_efficient_logprob
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer

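Note: `memory_efficient_logprob` (imported above) replaces `calc_action_log_probs` at every call site in this commit. Its implementation lives in `coati.distributed.utils` and is not part of this diff; the following is only a minimal sketch of the chunked pattern such a helper typically follows, with all names, signatures, and the `chunk_size` parameter being assumptions rather than the library API.

import torch

def chunked_action_logprob(logits, input_ids, num_actions, chunk_size=1024):
    # Hypothetical sketch (not the coati implementation): gather per-token
    # log-probs of the last `num_actions` generated tokens without ever
    # materializing the full [batch, seq, vocab] log-softmax at once.
    logits = logits[:, -num_actions - 1 : -1, :]   # positions that predict the response tokens
    targets = input_ids[:, -num_actions:]          # the response tokens themselves
    pieces = []
    for start in range(0, num_actions, chunk_size):
        chunk = logits[:, start : start + chunk_size, :].float()
        log_probs = torch.log_softmax(chunk, dim=-1)
        picked = log_probs.gather(-1, targets[:, start : start + chunk_size].unsqueeze(-1))
        pieces.append(picked.squeeze(-1))
    return torch.cat(pieces, dim=1)                # [batch, num_actions]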
@@ -40,6 +38,8 @@ def __init__(
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
+        enable_profiling: bool = False,
+        n_behind: int = 0,
     ):
         print(f"Using GRPO config: {grpo_config}")
         if (
@@ -62,12 +62,15 @@ def __init__(
             batch_size,
             model_config,
             plugin_config,
+            generate_config,
             minibatch_size,
             save_interval=save_interval,
             save_dir=save_dir,
+            enable_profiling=enable_profiling,
+            n_behind=n_behind,
         )
-        path = model_config.pop("path")
-        self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
+        self.path = model_config.pop("path")
+        self.policy_model = AutoModelForCausalLM.from_pretrained(self.path, **model_config)
         self.policy_model.train()
         self.policy_model.gradient_checkpointing_enable()
         self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
@@ -95,12 +98,7 @@ def __init__(
             loss_variation=grpo_config.get("loss_variation", "sample_level"),
         )

-        # Reference model is initialized from policy model.
-        if self.policy_loss_fn.beta > 0:
-            self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
-            self.reference_model.eval()
-
-        self.tokenizer = AutoTokenizer.from_pretrained(path)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.path)
         self.pad_token_id = self.tokenizer.pad_token_id
         self.num_generations = num_generations
         self.filter_range = grpo_config.get("filter_range", None)
@@ -119,20 +117,7 @@ def __init__(
119117 "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
120118 )
121119 # Initialize verifiable reward.
122- response_format_tags = grpo_config .get ("response_format_tags" , None )
123- reward_model_kwargs = {
124- k : v
125- for k , v in grpo_config .items ()
126- if k in ["soft_over_length_punishment" , "max_new_tokens" , "cache_length" ]
127- }
128- self .reward_model = VerifiableReward (
129- reward_fns = [
130- math_reward_fn if grpo_config .get ("reward_fn_type" ) == "think_answer_tags" else boxed_math_reward_fn
131- ],
132- tokenizer = self .tokenizer ,
133- tags = response_format_tags ,
134- ** reward_model_kwargs ,
135- )
120+ grpo_config .get ("response_format_tags" , None )
136121 self .global_step = 0
137122
138123 self .lr_scheduler = CosineAnnealingWarmupLR (
@@ -158,7 +143,10 @@ def setup(self):
         self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
             self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
         )
+        # Reference model is initialized from policy model.
         if self.policy_loss_fn.beta > 0:
+            self.reference_model = AutoModelForCausalLM.from_pretrained(self.path, **self.model_config)
+            self.reference_model.eval()
             self.reference_model, *_ = self.booster.boost(self.reference_model)
         self.plugin.logger.set_level("ERROR")

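For readability, the resulting order of operations in setup() condenses to the sketch below (the same calls as the hunk above, with `self.` attributes elided; not a drop-in snippet): boost the policy first, then build and boost the frozen reference copy only when the KL coefficient beta is positive.

policy_model, optimizer, _, _, lr_scheduler = booster.boost(policy_model, optimizer, lr_scheduler=lr_scheduler)
if beta > 0:  # KL regularization enabled, so a reference model is needed
    reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
    reference_model.eval()  # frozen snapshot of the policy checkpoint
    reference_model, *_ = booster.boost(reference_model)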
@@ -295,12 +283,11 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                     )

                     if self.booster.plugin.stage_manager.is_last_stage():
-                        reference_model_logits = reference_model_outputs["outputs"]["logits"]
-                        reference_action_log_probs = calc_action_log_probs(
-                            reference_model_logits / self.generate_config["temperature"],
+                        reference_action_log_probs = memory_efficient_logprob(
+                            reference_model_outputs["outputs"]["logits"],
                             input_ids_forward_micro_batch,
                             num_action,
-                            self.plugin.shard_config,
+                            shard_config=self.plugin.shard_config,
                         )
                     else:
                         # Dummy reference logprobs for data iterator.
@@ -323,11 +310,11 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:

                 def _criterion(outputs, inputs):
                     action_logits = outputs.logits
-                    action_log_probs = calc_action_log_probs(
-                        action_logits / self.generate_config["temperature"],
+                    action_log_probs = memory_efficient_logprob(
+                        action_logits,
                         inputs["input_ids"],
                         num_action,
-                        self.plugin.shard_config,
+                        shard_config=self.plugin.shard_config,
                     )
                     if "reference_action_log_probs" in inputs:
                         per_token_kl = (
@@ -370,16 +357,15 @@ def _criterion(outputs, inputs):
                         mean_kl.append(kl)
                     mean_loss.append(all_reduce_mean(loss, self.plugin).data)
             else:
-
                 policy_model_logits = self.policy_model(
                     input_ids=input_ids_forward_micro_batch,
                     attention_mask=attention_mask_forward_micro_batch,
                 ).logits
-                action_log_probs = calc_action_log_probs(
+                action_log_probs = memory_efficient_logprob(
                     policy_model_logits / self.generate_config["temperature"],
                     input_ids_forward_micro_batch,
                     num_action,
-                    self.plugin.shard_config,
+                    shard_config=self.plugin.shard_config,
                 )

                 if self.policy_loss_fn.beta > 0:
@@ -388,11 +374,11 @@ def _criterion(outputs, inputs):
                         input_ids=input_ids_forward_micro_batch,
                         attention_mask=attention_mask_forward_micro_batch,
                     ).logits
-                    reference_action_log_probs = calc_action_log_probs(
+                    reference_action_log_probs = memory_efficient_logprob(
                         reference_model_logits / self.generate_config["temperature"],
                         input_ids_forward_micro_batch,
                         num_action,
-                        self.plugin.shard_config,
+                        shard_config=self.plugin.shard_config,
                     )
                     per_token_kl = (
                         torch.exp(reference_action_log_probs - action_log_probs)
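The per_token_kl expression cut off at this hunk boundary appears to be the standard k3 estimator of KL(policy || reference), exp(r) - r - 1 with r = ref_logp - act_logp. As a standalone reference sketch (not code from this file):

import torch

def per_token_kl_k3(ref_logp: torch.Tensor, act_logp: torch.Tensor) -> torch.Tensor:
    # k3 estimator: exp(r) - r - 1, with r = log p_ref - log p_policy.
    # Non-negative for every token and unbiased for KL(policy || reference)
    # when tokens are sampled from the policy.
    r = ref_logp - act_logp
    return torch.exp(r) - r - 1.0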
@@ -498,40 +484,6 @@ def _criterion(outputs, inputs):
         else:
             return None

-    def calculate_reward(self, rollout: Dict[str, Any]) -> Dict[str, Any]:
-        """
-        Calculate the group reward for the given rollout group.
-
-        Args:
-            rollout_group (Dict[str, Any]):
-                a group of samples generated by the model from the same prompt
-                contain the following keys:
-                "input_ids": torch.Tensor, [num_of_generation, prompt_length + response_length]
-                "attention_mask": torch.Tensor, [num_of_generation, prompt_length + response_length]
-                "action_mask": torch.Tensor, [num_of_generation, response_length]
-                "action_log_probs": torch.Tensor, [num_of_generation, response_length]
-                "response_idx": int, torch.Tensor, [num_of_generation, 2]
-                "gt_answer": torch.Tensor, [num_of_generation, 128]
-                "temperature": torch.Tensor, [] (scalar)
-
-        Returns:
-            Dict[str, Any]: The new group data with calculated reward.
-        """
-        reward_model_output = self.reward_model(
-            rollout["input_ids"],
-            gt_answer=rollout["gt_answer"],
-            response_idx=rollout["response_idx"],
-        )
-        # [num_of_generation]
-        reward = torch.tensor([value[0] for value in reward_model_output]).to(rollout["input_ids"].device)
-        format_acc = torch.tensor([value[1] for value in reward_model_output]).to(rollout["input_ids"].device)
-        ans_acc = torch.tensor([value[2] for value in reward_model_output]).to(rollout["input_ids"].device)
-
-        rollout["reward"] = reward.view((-1, 1))
-        rollout["format_acc"] = format_acc.view((-1, 1))
-        rollout["ans_acc"] = ans_acc.view((-1, 1))
-        return rollout
-
     def state_dict(self):
         self.policy_model._force_wait_all_gather()
         model = self.policy_model.unwrap()