diff --git a/src/fairseq2/recipes/lm/_online_finetune/_common.py b/src/fairseq2/recipes/lm/_online_finetune/_common.py
index 3e1a2bd2a..746a7430a 100644
--- a/src/fairseq2/recipes/lm/_online_finetune/_common.py
+++ b/src/fairseq2/recipes/lm/_online_finetune/_common.py
@@ -17,14 +17,8 @@
 from torch import Tensor
 from vllm import RequestOutput

-from fairseq2.data import (
-    CollateOptionsOverride,
-    Collater,
-    SequenceData,
-)
-from fairseq2.datasets import (
-    SequenceBatch,
-)
+from fairseq2.data import CollateOptionsOverride, Collater, SequenceData
+from fairseq2.datasets import SequenceBatch
 from fairseq2.datasets.preference import PreferenceBatch
 from fairseq2.datasets.prompt import PromptBatch
 from fairseq2.gang import Gang, Gangs
@@ -365,6 +359,55 @@ def combine_prompts_responses_for_scoring(
     return responses


+def get_vllm_logprobs(
+    vllm_outputs: List[RequestOutput],
+    gangs,
+    rollout_start_end: tuple[int, int] | None = None,
+):
+    """Compute per-token logprobs for selected continuations across a list of requests.
+
+    For each RequestOutput (one prompt) and each of its sampled continuations we
+    concatenate the prompt logprobs (skipping the first entry) with the generation
+    logprobs. All resulting sequences are then right-padded with 0.0 to the global
+    maximum length and stacked into a single tensor.
+
+    Parameters
+    ----------
+    vllm_outputs:
+        List of vLLM RequestOutput objects (one per prompt).
+    gangs:
+        Fairseq2 gangs object (unused, kept for parity/extensibility).
+    rollout_start_end:
+        Optional (start, end) slice specifying which continuation indices to include
+        per prompt (used for micro-batching when forward_group_size < group_size).
+
+    Returns
+    -------
+    Tensor
+        Shape ``(num_selected_continuations, max_seq_len)`` with 0.0 padding.
+    """
+    sequences: List[Tensor] = []
+    for request in vllm_outputs:
+        prompt_logprobs = [
+            list(d.values())[0].logprob for d in request.prompt_logprobs[1:]
+        ]
+        outputs = request.outputs
+        if rollout_start_end is not None:  # micro-batching
+            s, e = rollout_start_end
+            outputs = outputs[s:e]
+        for output in outputs:
+            gen_logprobs = [list(d.values())[0].logprob for d in output.logprobs]
+            seq = torch.tensor(prompt_logprobs + gen_logprobs)
+            sequences.append(seq)
+
+    max_len = max(t.size(0) for t in sequences)
+    padded = torch.zeros(len(sequences), max_len)
+    for i, t in enumerate(sequences):
+        padded[i, : t.size(0)] = t
+
+    return padded
+
+
 def convert_vllm_output_to_ref_score(vllm_outputs: List[RequestOutput], gangs):
     ref_scores = []
     for req_output in vllm_outputs:
diff --git a/src/fairseq2/recipes/lm/_online_finetune/_grpo.py b/src/fairseq2/recipes/lm/_online_finetune/_grpo.py
index 58cd226a7..f7ff7c2fe 100644
--- a/src/fairseq2/recipes/lm/_online_finetune/_grpo.py
+++ b/src/fairseq2/recipes/lm/_online_finetune/_grpo.py
@@ -17,9 +17,7 @@
 from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator

 from fairseq2.context import RuntimeContext
-from fairseq2.datasets import (
-    SequenceBatch,
-)
+from fairseq2.datasets import SequenceBatch
 from fairseq2.datasets.preference import PreferenceBatch
 from fairseq2.datasets.prompt import PromptBatch
 from fairseq2.gang import Gang, Gangs
@@ -38,6 +36,7 @@
     compute_token_level_entropy,
     generate_rollouts,
     get_rollout_lengths,
+    get_vllm_logprobs,
     log_rollouts,
     update_avg_reward,
     update_avg_reward_len_norm,
@@ -77,6 +76,7 @@ def prepare_grpo_batch(
     reward_output: dict,
     gangs: Gang,
     rollout_start_end: tuple[int],
+    adv_std_normalization: bool,
 ):

     prompt_rollouts = []
@@ -107,9 +107,12 @@

     # gangs.root.barrier()
     rewards = torch.tensor(rewards, device=gangs.dp.device).float()  # [Batch, Rollouts]
-    rewards_normalized = (rewards - rewards.mean(dim=1, keepdim=True)) / (
-        rewards.std(dim=1, keepdim=True) + 1e-6
-    )  # small epsilon to compensate 0 std
+
+    rewards_normalized = rewards - rewards.mean(dim=1, keepdim=True)
+    if adv_std_normalization:  # normalize advantages with std
+        rewards_normalized = rewards_normalized / (
+            rewards.std(dim=1, keepdim=True) + 1e-6
+        )  # small epsilon to compensate 0 std

     rewards_normalized = rewards_normalized[
         :, rollout_start_end[0] : rollout_start_end[1]
@@ -279,6 +282,7 @@ def __call__(
             rollout_start_end=self._rollout_bag.get_rollout_start_end(
                 self._config.loss_config.forward_group_size
            ),
+            adv_std_normalization=self._config.loss_config.adv_std_normalization,
         )
         # grpo_batch, reward_output = self._reward.prepare_grpo_batch(prompt_batch, rollouts)
         # loss_zeroer is used when entire batch has no valid prefrence pair
@@ -296,7 +300,20 @@ def __call__(
             grpo_input_batch_seqs, grpo_input_batch_seqs_layout
         )

-        logps = self._gather_lprobs(grpo_model_logits, grpo_target_batch)
+        model_logps = self._gather_lprobs(grpo_model_logits, grpo_target_batch)
+        rollout_window = self._rollout_bag.get_rollout_start_end(
+            self._config.loss_config.forward_group_size
+        )
+        vllm_logps = get_vllm_logprobs(
+            rollouts, self._gangs, rollout_start_end=rollout_window
+        ).to(model_logps.device)
+
+        if vllm_logps.size(0) != model_logps.size(0):
+            raise RuntimeError(
+                "Mismatch between vLLM and model logprobs row counts after slicing: "
+                f"model={model_logps.size(0)}, vllm={vllm_logps.size(0)}. "
+                "Ensure rollout slicing aligns with forward_group_size and group_size."
+            )

         tgt_logit_entropy = compute_token_level_entropy(
             grpo_model_logits, grpo_target_batch.target_mask
@@ -312,16 +329,21 @@ def __call__(
             prompt_rollout_seqs,
             prompt_rollout_layout,
         ) = grpo_batch.prompt_rollouts.as_input()

-        ref_logps = compute_reference_logps(
-            self._gangs,
-            self._reference_model,
-            prompt_rollout_seqs,
-            prompt_rollout_layout,
-            grpo_batch.prompt_lengths,
-        )
-        _grpo_objective = self._compute_grpo_objective(
-            logps, ref_logps, grpo_batch.rewards, grpo_target_batch
+        # if beta > 0, compute reference logprobs
+        if self._config.loss_config.beta > 0:
+            ref_logps = compute_reference_logps(
+                self._gangs,
+                self._reference_model,
+                prompt_rollout_seqs,
+                prompt_rollout_layout,
+                grpo_batch.prompt_lengths,
+            )
+        else:
+            ref_logps = None
+
+        _grpo_objective, total_tokens = self._compute_grpo_objective(
+            model_logps, vllm_logps, ref_logps, grpo_batch.rewards, grpo_target_batch
         )

         grpo_loss = -_grpo_objective + max_entropy_regularizer
@@ -352,7 +374,10 @@ def __call__(

         loss = grpo_loss

-        return loss, prompt_batch.batch_size
+        if self._config.loss_config.loss_token_mean:
+            return loss, total_tokens
+        else:
+            return loss, prompt_batch.batch_size

     def _gather_lprobs(self, logits: Tensor, target: SequenceBatch) -> Tensor:
         assert target.target_mask is not None
@@ -365,7 +390,8 @@ def _gather_lprobs(self, logits: Tensor, target: SequenceBatch) -> Tensor:

     def _compute_grpo_objective(
         self,
-        logps,
+        model_logps,
+        vllm_logps,
         ref_logps,
         advantages: Tensor,  # outcome based only for now
         target_batch: SequenceBatch,
@@ -373,25 +399,41 @@
         batch_size = advantages.size(0)
         num_rollouts = advantages.size(1)

-        logps = logps.view(batch_size, num_rollouts, -1)
-        ref_logps = ref_logps.view(batch_size, num_rollouts, -1)
+        model_logps = model_logps.view(batch_size, num_rollouts, -1)
+        vllm_logps = vllm_logps.view(batch_size, num_rollouts, -1)

-        # kl penalty
-        kl = (ref_logps - logps).exp() - (ref_logps - logps) - 1.0
+        per_token_scaled_advantage = (
+            model_logps - model_logps.detach()
+        ).exp() * advantages[:, :, None]

-        per_token_scaled_advantage = (logps - logps.detach()).exp() * advantages[
-            :, :, None
-        ]
-        # per_token_scaled_advantage = logps * advantages[:,:,None]
+        if self._config.loss_config.tis_imp_ratio_cap > 0:
+            tis_imp_ratio = torch.exp(model_logps - vllm_logps)
+            tis_imp_ratio = torch.clamp(
+                tis_imp_ratio, max=self._config.loss_config.tis_imp_ratio_cap
+            )
+            per_token_scaled_advantage = per_token_scaled_advantage * tis_imp_ratio

-        per_token_loss = per_token_scaled_advantage - self._config.loss_config.beta * kl
+        if self._config.loss_config.beta > 0:
+            ref_logps = ref_logps.view(batch_size, num_rollouts, -1)
+
+            # kl penalty
+            kl = (ref_logps - model_logps).exp() - (ref_logps - model_logps) - 1.0
+            per_token_loss = (
+                per_token_scaled_advantage - self._config.loss_config.beta * kl
+            )
+        else:
+            per_token_loss = per_token_scaled_advantage

         target_mask = target_batch.target_mask.view(batch_size, num_rollouts, -1)

+        total_tokens = target_mask.sum().item()
+
         if self._config.loss_config.length_normalization:
             per_seq_loss = (
                 (per_token_loss * target_mask).sum(dim=-1) / target_mask.sum(dim=-1)
             ).mean(dim=1)
+        elif self._config.loss_config.loss_token_mean:
+            per_seq_loss = per_token_loss * target_mask
         else:
             per_seq_loss = ((per_token_loss * target_mask).sum(dim=-1)).mean(dim=1)

@@ -401,7 +443,7 @@

         # self._gangs.root.barrier()

-        return per_seq_loss.sum()
+        return per_seq_loss.sum(), total_tokens

     @override
     def set_step_nr(self, step_nr: int) -> None:
@@ -442,12 +484,21 @@ class GrpoLossConfig:
     length_normalization: bool = True
     """If True, normalize loss by sequence length. If False, use sequence-level loss."""

+    adv_std_normalization: bool = True
+    """If True, normalize advantages with standard deviation."""
+
     log_rollouts: bool = False
     """Log sample rollouts during training/validation."""

+    loss_token_mean: bool = False
+    """If True, average loss over tokens. If False, sum over tokens."""
+
     validation_vllm_sampling_params: Dict[str, Any] = field(default_factory=lambda: {})
     """VLLM sampling params for validation. If empty, training params will be used."""

+    tis_imp_ratio_cap: float = 2.0
+    """Maximum cap for the truncated importance sampling ratio. If <= 0, no cap is applied."""
+

 @dataclass(kw_only=True)
 class GrpoFinetuneConfig:
@@ -500,7 +551,7 @@ def create(
         log.info(f"GRPO loss config:\n{config}")

         reference_model = vllm_actors[config.vllm_reference_model_actor_name]
-        if config.vllm_sync.sync_ref_model_every_n_steps != -1:
+        if reference_model and config.vllm_sync.sync_ref_model_every_n_steps != -1:
             if reference_model and reference_model.update_process_groups is None:
                 raise ValueError(
                     f"Reference model actor must have update process group if we sync weights"
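
Note (not part of the patch): a minimal usage sketch for the new get_vllm_logprobs helper, assuming the patched fairseq2 checkout is importable. It uses types.SimpleNamespace stand-ins that expose only the attributes the helper actually reads (prompt_logprobs, outputs, logprobs, logprob); in the recipe itself real vLLM RequestOutput objects are passed instead.

from types import SimpleNamespace

from fairseq2.recipes.lm._online_finetune._common import get_vllm_logprobs


def _lp(value):
    # One {token_id: Logprob-like} dict per position, mirroring vLLM's layout.
    return {0: SimpleNamespace(logprob=value)}


def _fake_request(prompt_lps, continuations):
    # vLLM leaves the first prompt position without a logprob; the helper skips it.
    return SimpleNamespace(
        prompt_logprobs=[None] + [_lp(v) for v in prompt_lps],
        outputs=[SimpleNamespace(logprobs=[_lp(v) for v in gen]) for gen in continuations],
    )


requests = [
    _fake_request([-0.5, -0.2], [[-1.0, -1.1], [-0.3]]),  # 2 rollouts -> lengths 4 and 3
    _fake_request([-0.7], [[-0.4, -0.6, -0.8]]),          # 1 rollout  -> length 4
]

padded = get_vllm_logprobs(requests, gangs=None)
print(padded.shape)  # torch.Size([3, 4]); shorter rows are right-padded with 0.0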
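
Note (not part of the patch): a self-contained illustration of what the new adv_std_normalization flag toggles in prepare_grpo_batch. With the flag off, only the per-prompt mean reward is subtracted; with it on, the centered rewards are additionally divided by the per-prompt std, and the 1e-6 term guards prompt groups whose rollouts all received the same reward. The tensor values are toy data.

import torch

rewards = torch.tensor(
    [[1.0, 0.0, 0.0, 1.0],   # mixed group
     [1.0, 1.0, 1.0, 1.0]]   # zero-variance group
)  # [Batch, Rollouts]

# adv_std_normalization = False: mean-centering only
advantages = rewards - rewards.mean(dim=1, keepdim=True)

# adv_std_normalization = True: additionally divide by the group std
advantages_std = advantages / (rewards.std(dim=1, keepdim=True) + 1e-6)

print(advantages)      # [[ 0.5, -0.5, -0.5,  0.5], [ 0., 0., 0., 0.]]
print(advantages_std)  # roughly +-0.866 for the first row, exact zeros for the second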
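
Note (not part of the patch): a sketch of the truncated importance sampling correction controlled by tis_imp_ratio_cap, using random stand-in tensors in the same (batch, rollouts, seq) layout as _compute_grpo_objective. The per-token ratio exp(model_logps - vllm_logps) re-weights the advantage for the gap between the trainer policy and the (possibly stale) vLLM sampler, and clamping it from above bounds the variance of that correction.

import torch

batch, rollouts, seq = 2, 4, 5
tis_imp_ratio_cap = 2.0

model_logps = torch.randn(batch, rollouts, seq)                     # trainer policy logprobs
vllm_logps = model_logps + 0.1 * torch.randn(batch, rollouts, seq)  # rollout (sampler) logprobs
advantages = torch.randn(batch, rollouts)

per_token_scaled_advantage = (
    model_logps - model_logps.detach()
).exp() * advantages[:, :, None]

if tis_imp_ratio_cap > 0:
    # importance ratio between trainer and sampler, truncated at the cap
    tis_imp_ratio = torch.exp(model_logps - vllm_logps)
    tis_imp_ratio = torch.clamp(tis_imp_ratio, max=tis_imp_ratio_cap)
    per_token_scaled_advantage = per_token_scaled_advantage * tis_imp_ratio

print(per_token_scaled_advantage.shape)  # torch.Size([2, 4, 5])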