 from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
 
 from fairseq2.context import RuntimeContext
-from fairseq2.datasets import (
-    SequenceBatch,
-)
+from fairseq2.datasets import SequenceBatch
 from fairseq2.datasets.preference import PreferenceBatch
 from fairseq2.datasets.prompt import PromptBatch
 from fairseq2.gang import Gang, Gangs
@@ -38,6 +36,7 @@
     compute_token_level_entropy,
     generate_rollouts,
     get_rollout_lengths,
+    get_vllm_logprobs,
     log_rollouts,
     update_avg_reward,
     update_avg_reward_len_norm,
@@ -77,6 +76,7 @@ def prepare_grpo_batch(
     reward_output: dict,
     gangs: Gang,
     rollout_start_end: tuple[int],
+    adv_std_normalization: bool,
 ):
 
     prompt_rollouts = []
@@ -107,9 +107,12 @@ def prepare_grpo_batch(
     # gangs.root.barrier()
 
     rewards = torch.tensor(rewards, device=gangs.dp.device).float()  # [Batch, Rollouts]
-    rewards_normalized = (rewards - rewards.mean(dim=1, keepdim=True)) / (
-        rewards.std(dim=1, keepdim=True) + 1e-6
-    )  # small epsilon to compensate 0 std
+
+    rewards_normalized = rewards - rewards.mean(dim=1, keepdim=True)
+    if adv_std_normalization:  # normalize advantages with std
+        rewards_normalized = rewards_normalized / (
+            rewards.std(dim=1, keepdim=True) + 1e-6
+        )  # small epsilon to compensate 0 std
 
     rewards_normalized = rewards_normalized[
         :, rollout_start_end[0] : rollout_start_end[1]
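
For reference, a minimal standalone sketch of the group-relative advantage computation introduced in this hunk: the group mean is always subtracted, while division by the group standard deviation becomes optional behind `adv_std_normalization`. The function name is illustrative; the shapes and epsilon follow the patch.

```python
import torch

def group_relative_advantages(
    rewards: torch.Tensor,  # [batch, rollouts], one scalar reward per rollout
    adv_std_normalization: bool = True,
) -> torch.Tensor:
    # Center rewards within each prompt's rollout group.
    advantages = rewards - rewards.mean(dim=1, keepdim=True)
    if adv_std_normalization:
        # Optional GRPO-style scaling; epsilon guards against zero std.
        advantages = advantages / (rewards.std(dim=1, keepdim=True) + 1e-6)
    return advantages

# Example: 2 prompts x 4 rollouts with binary rewards.
rewards = torch.tensor([[1.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0]])
print(group_relative_advantages(rewards, adv_std_normalization=False))
```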
@@ -279,6 +282,7 @@ def __call__(
             rollout_start_end=self._rollout_bag.get_rollout_start_end(
                 self._config.loss_config.forward_group_size
             ),
+            adv_std_normalization=self._config.loss_config.adv_std_normalization,
         )
 
         # grpo_batch, reward_output = self._reward.prepare_grpo_batch(prompt_batch, rollouts) # loss_zeroer is used when entire batch has no valid prefrence pair
@@ -296,7 +300,20 @@ def __call__(
             grpo_input_batch_seqs, grpo_input_batch_seqs_layout
         )
 
-        logps = self._gather_lprobs(grpo_model_logits, grpo_target_batch)
+        model_logps = self._gather_lprobs(grpo_model_logits, grpo_target_batch)
+        rollout_window = self._rollout_bag.get_rollout_start_end(
+            self._config.loss_config.forward_group_size
+        )
+        vllm_logps = get_vllm_logprobs(
+            rollouts, self._gangs, rollout_start_end=rollout_window
+        ).to(model_logps.device)
+
+        if vllm_logps.size(0) != model_logps.size(0):
+            raise RuntimeError(
+                "Mismatch between vLLM and model logprobs row counts after slicing: "
+                f"model={model_logps.size(0)}, vllm={vllm_logps.size(0)}. "
+                "Ensure rollout slicing aligns with forward_group_size and group_size."
+            )
 
         tgt_logit_entropy = compute_token_level_entropy(
             grpo_model_logits, grpo_target_batch.target_mask
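
The vLLM log-probabilities gathered here serve as behavior-policy logprobs for the truncated importance sampling (TIS) correction applied further down in `_compute_grpo_objective`. A minimal sketch of that correction, with the cap taken from the new `tis_imp_ratio_cap` config field:

```python
import torch

def tis_weight(
    model_logps: torch.Tensor,  # logprobs under the training policy
    vllm_logps: torch.Tensor,   # logprobs under the vLLM rollout policy
    cap: float = 2.0,           # tis_imp_ratio_cap; <= 0 disables the correction
) -> torch.Tensor:
    # Per-token importance ratio between trainer and rollout policies,
    # truncated from above so stale or mismatched rollouts cannot blow up the update.
    ratio = torch.exp(model_logps - vllm_logps)
    return torch.clamp(ratio, max=cap)

# exp(-0.5) ~= 0.61 passes through; exp(1.5) ~= 4.48 is capped at 2.0.
print(tis_weight(torch.tensor([-1.0, -0.5]), torch.tensor([-0.5, -2.0])))
```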
@@ -312,16 +329,21 @@ def __call__(
             prompt_rollout_seqs,
             prompt_rollout_layout,
         ) = grpo_batch.prompt_rollouts.as_input()
-        ref_logps = compute_reference_logps(
-            self._gangs,
-            self._reference_model,
-            prompt_rollout_seqs,
-            prompt_rollout_layout,
-            grpo_batch.prompt_lengths,
-        )
 
-        _grpo_objective = self._compute_grpo_objective(
-            logps, ref_logps, grpo_batch.rewards, grpo_target_batch
+        # if beta > 0, compute reference logprobs
+        if self._config.loss_config.beta > 0:
+            ref_logps = compute_reference_logps(
+                self._gangs,
+                self._reference_model,
+                prompt_rollout_seqs,
+                prompt_rollout_layout,
+                grpo_batch.prompt_lengths,
+            )
+        else:
+            ref_logps = None
+
+        _grpo_objective, total_tokens = self._compute_grpo_objective(
+            model_logps, vllm_logps, ref_logps, grpo_batch.rewards, grpo_target_batch
         )
 
         grpo_loss = -_grpo_objective + max_entropy_regularizer
@@ -352,7 +374,10 @@ def __call__(
 
         loss = grpo_loss
 
-        return loss, prompt_batch.batch_size
+        if self._config.loss_config.loss_token_mean:
+            return loss, total_tokens
+        else:
+            return loss, prompt_batch.batch_size
 
     def _gather_lprobs(self, logits: Tensor, target: SequenceBatch) -> Tensor:
         assert target.target_mask is not None
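
The new `loss_token_mean` branch only changes the second element of the return value. Assuming (not verified here) that the fairseq2 training loop divides the summed loss by that second value, the two modes behave roughly as sketched:

```python
import torch

def normalize_step_loss(summed_loss: torch.Tensor, normalizer: int) -> torch.Tensor:
    # Hypothetical view of the downstream normalization: with
    # loss_token_mean=True the normalizer is total_tokens (per-token mean),
    # otherwise it is prompt_batch.batch_size (per-prompt mean).
    return summed_loss / max(normalizer, 1)

print(normalize_step_loss(torch.tensor(120.0), 600))  # token mean: 0.2
print(normalize_step_loss(torch.tensor(120.0), 8))    # per-prompt mean: 15.0
```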
@@ -365,33 +390,50 @@ def _gather_lprobs(self, logits: Tensor, target: SequenceBatch) -> Tensor:
 
     def _compute_grpo_objective(
         self,
-        logps,
+        model_logps,
+        vllm_logps,
         ref_logps,
         advantages: Tensor,  # outcome based only for now
         target_batch: SequenceBatch,
     ) -> tuple[Tensor, Tensor, Tensor]:
 
         batch_size = advantages.size(0)
         num_rollouts = advantages.size(1)
-        logps = logps.view(batch_size, num_rollouts, -1)
-        ref_logps = ref_logps.view(batch_size, num_rollouts, -1)
+        model_logps = model_logps.view(batch_size, num_rollouts, -1)
+        vllm_logps = vllm_logps.view(batch_size, num_rollouts, -1)
 
-        # kl penalty
-        kl = (ref_logps - logps).exp() - (ref_logps - logps) - 1.0
+        per_token_scaled_advantage = (
+            model_logps - model_logps.detach()
+        ).exp() * advantages[:, :, None]
 
-        per_token_scaled_advantage = (logps - logps.detach()).exp() * advantages[
-            :, :, None
-        ]
-        # per_token_scaled_advantage = logps * advantages[:,:,None]
+        if self._config.loss_config.tis_imp_ratio_cap > 0:
+            tis_imp_ratio = torch.exp(model_logps - vllm_logps)
+            tis_imp_ratio = torch.clamp(
+                tis_imp_ratio, max=self._config.loss_config.tis_imp_ratio_cap
+            )
+            per_token_scaled_advantage = per_token_scaled_advantage * tis_imp_ratio
 
-        per_token_loss = per_token_scaled_advantage - self._config.loss_config.beta * kl
+        if self._config.loss_config.beta > 0:
+            ref_logps = ref_logps.view(batch_size, num_rollouts, -1)
+
+            # kl penalty
+            kl = (ref_logps - model_logps).exp() - (ref_logps - model_logps) - 1.0
+            per_token_loss = (
+                per_token_scaled_advantage - self._config.loss_config.beta * kl
+            )
+        else:
+            per_token_loss = per_token_scaled_advantage
 
         target_mask = target_batch.target_mask.view(batch_size, num_rollouts, -1)
 
+        total_tokens = target_mask.sum().item()
+
         if self._config.loss_config.length_normalization:
             per_seq_loss = (
                 (per_token_loss * target_mask).sum(dim=-1) / target_mask.sum(dim=-1)
             ).mean(dim=1)
+        elif self._config.loss_config.loss_token_mean:
+            per_seq_loss = per_token_loss * target_mask
         else:
             per_seq_loss = ((per_token_loss * target_mask).sum(dim=-1)).mean(dim=1)
 
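
The beta-gated branch keeps GRPO's k3 KL estimator, exp(r) - r - 1 with r = ref_logps - model_logps, which is non-negative per token and estimates KL(pi || pi_ref) from on-policy samples. A standalone sketch of the estimator and of the masked per-sequence mean used in the `length_normalization=True` path (function names are illustrative):

```python
import torch

def k3_kl(ref_logps: torch.Tensor, model_logps: torch.Tensor) -> torch.Tensor:
    # k3 estimator of KL(pi || pi_ref): exp(r) - r - 1 with r = log pi_ref - log pi.
    r = ref_logps - model_logps
    return r.exp() - r - 1.0

def masked_seq_mean(per_token_loss: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # length_normalization=True: average over valid tokens per sequence,
    # then average over the rollout dimension (dim=1).
    return ((per_token_loss * mask).sum(dim=-1) / mask.sum(dim=-1)).mean(dim=1)

# Shapes follow the patch: [batch, rollouts, seq].
lp = torch.randn(2, 4, 8)
print(k3_kl(lp, lp).abs().max())  # identical policies -> KL estimate is 0
```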
@@ -401,7 +443,7 @@ def _compute_grpo_objective(
 
         # self._gangs.root.barrier()
 
-        return per_seq_loss.sum()
+        return per_seq_loss.sum(), total_tokens
 
     @override
     def set_step_nr(self, step_nr: int) -> None:
@@ -442,12 +484,21 @@ class GrpoLossConfig:
     length_normalization: bool = True
     """If True, normalize loss by sequence length. If False, use sequence-level loss."""
 
+    adv_std_normalization: bool = True
+    """If True, normalize advantages by the standard deviation of rewards within each rollout group."""
+
     log_rollouts: bool = False
     """Log sample rollouts during training/validation."""
 
+    loss_token_mean: bool = False
+    """If True, average the loss over target tokens; if False, aggregate per sequence (see ``length_normalization``)."""
+
     validation_vllm_sampling_params: Dict[str, Any] = field(default_factory=lambda: {})
     """VLLM sampling params for validation. If empty, training params will be used."""
 
+    tis_imp_ratio_cap: float = 2.0
+    """Maximum value for the truncated importance sampling (TIS) ratio between trainer and vLLM logprobs. If <= 0, the TIS correction is disabled."""
+
 
 @dataclass(kw_only=True)
 class GrpoFinetuneConfig:
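
Purely as a usage illustration (the recipe wiring around the config is assumed, not shown in this diff), the new fields would be toggled on the loss config like this:

```python
# Hypothetical instantiation; all other GrpoLossConfig fields keep their defaults.
loss_config = GrpoLossConfig(
    adv_std_normalization=False,  # center rewards only, skip std scaling
    loss_token_mean=True,         # normalize the loss by the target-token count
    tis_imp_ratio_cap=2.0,        # clamp the trainer/vLLM importance ratio at 2.0
)
```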
@@ -500,7 +551,7 @@ def create(
         log.info(f"GRPO loss config:\n{config}")
 
         reference_model = vllm_actors[config.vllm_reference_model_actor_name]
-        if config.vllm_sync.sync_ref_model_every_n_steps != -1:
+        if reference_model and config.vllm_sync.sync_ref_model_every_n_steps != -1:
             if reference_model and reference_model.update_process_groups is None:
                 raise ValueError(
                     f"Reference model actor must have update process group if we sync weights"