59 changes: 51 additions & 8 deletions src/fairseq2/recipes/lm/_online_finetune/_common.py
@@ -17,14 +17,8 @@
from torch import Tensor
from vllm import RequestOutput

from fairseq2.data import (
CollateOptionsOverride,
Collater,
SequenceData,
)
from fairseq2.datasets import (
SequenceBatch,
)
from fairseq2.data import CollateOptionsOverride, Collater, SequenceData
from fairseq2.datasets import SequenceBatch
from fairseq2.datasets.preference import PreferenceBatch
from fairseq2.datasets.prompt import PromptBatch
from fairseq2.gang import Gang, Gangs
@@ -365,6 +359,55 @@ def combine_prompts_responses_for_scoring(
return responses


def get_vllm_logprobs(
    vllm_outputs: List[RequestOutput],
    gangs: Gangs,
    rollout_start_end: tuple[int, int] | None = None,
) -> Tensor:
"""Compute per-token logprobs for selected continuations across a list of requests.

For each RequestOutput (one prompt) and each of its sampled continuations we
concatenate the prompt logprobs (skipping the first entry) with the generation
logprobs. All resulting sequences are then right-padded with 0.0 to the global
maximum length and stacked into a single tensor.

Parameters
----------
vllm_outputs:
List of vLLM RequestOutput objects (one per prompt).
gangs:
Fairseq2 gangs object (unused, kept for parity/extensibility).
rollout_start_end:
Optional (start, end) slice specifying which continuation indices to include
per prompt (used for micro-batching when forward_group_size < group_size).

Returns
-------
Tensor
Shape ``(num_selected_continuations, max_seq_len)`` with 0.0 padding.
"""
sequences: List[Tensor] = []
for request in vllm_outputs:
prompt_logprobs = [
list(d.values())[0].logprob for d in request.prompt_logprobs[1:]
]
outputs = request.outputs
if rollout_start_end is not None: # micro-batching
s, e = rollout_start_end
outputs = outputs[s:e]
for output in outputs:
gen_logprobs = [list(d.values())[0].logprob for d in output.logprobs]
seq = torch.tensor(prompt_logprobs + gen_logprobs)
sequences.append(seq)

max_len = max(t.size(0) for t in sequences)
padded = torch.zeros(len(sequences), max_len)
for i, t in enumerate(sequences):
padded[i, : t.size(0)] = t

return padded
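
# Minimal usage sketch (hypothetical variables; assumes the vLLM requests were
# issued with SamplingParams(logprobs=0, prompt_logprobs=0) so that per-token
# logprobs are populated on the outputs):
#
#   outputs = llm.generate(prompts, sampling_params)  # list[RequestOutput]
#   all_logps = get_vllm_logprobs(outputs, gangs)
#   micro = get_vllm_logprobs(outputs, gangs, rollout_start_end=(0, 2))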


def convert_vllm_output_to_ref_score(vllm_outputs: List[RequestOutput], gangs):
ref_scores = []
for req_output in vllm_outputs:
109 changes: 80 additions & 29 deletions src/fairseq2/recipes/lm/_online_finetune/_grpo.py
@@ -17,9 +17,7 @@
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator

from fairseq2.context import RuntimeContext
from fairseq2.datasets import (
SequenceBatch,
)
from fairseq2.datasets import SequenceBatch
from fairseq2.datasets.preference import PreferenceBatch
from fairseq2.datasets.prompt import PromptBatch
from fairseq2.gang import Gang, Gangs
@@ -38,6 +36,7 @@
compute_token_level_entropy,
generate_rollouts,
get_rollout_lengths,
get_vllm_logprobs,
log_rollouts,
update_avg_reward,
update_avg_reward_len_norm,
@@ -77,6 +76,7 @@ def prepare_grpo_batch(
reward_output: dict,
gangs: Gang,
    rollout_start_end: tuple[int, int],
adv_std_normalization: bool,
):

prompt_rollouts = []
@@ -107,9 +107,12 @@
# gangs.root.barrier()

rewards = torch.tensor(rewards, device=gangs.dp.device).float() # [Batch, Rollouts]
rewards_normalized = (rewards - rewards.mean(dim=1, keepdim=True)) / (
rewards.std(dim=1, keepdim=True) + 1e-6
) # small epsilon to compensate 0 std

    rewards_normalized = rewards - rewards.mean(dim=1, keepdim=True)
    if adv_std_normalization:  # normalize advantages with std
        rewards_normalized = rewards_normalized / (
            rewards.std(dim=1, keepdim=True) + 1e-6
        )  # small epsilon to guard against zero std
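    # Worked example (hypothetical numbers): for a prompt with rollout rewards
    # [1.0, 0.0, 1.0, 0.0], mean-centering yields [0.5, -0.5, 0.5, -0.5]; with
    # adv_std_normalization the row std (~0.577) rescales this to roughly
    # [0.87, -0.87, 0.87, -0.87].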

rewards_normalized = rewards_normalized[
:, rollout_start_end[0] : rollout_start_end[1]
@@ -279,6 +282,7 @@ def __call__(
rollout_start_end=self._rollout_bag.get_rollout_start_end(
self._config.loss_config.forward_group_size
),
adv_std_normalization=self._config.loss_config.adv_std_normalization,
)

        # grpo_batch, reward_output = self._reward.prepare_grpo_batch(prompt_batch, rollouts) # loss_zeroer is used when entire batch has no valid preference pair
@@ -296,7 +300,20 @@
grpo_input_batch_seqs, grpo_input_batch_seqs_layout
)

logps = self._gather_lprobs(grpo_model_logits, grpo_target_batch)
model_logps = self._gather_lprobs(grpo_model_logits, grpo_target_batch)
rollout_window = self._rollout_bag.get_rollout_start_end(
self._config.loss_config.forward_group_size
)
vllm_logps = get_vllm_logprobs(
rollouts, self._gangs, rollout_start_end=rollout_window
).to(model_logps.device)

if vllm_logps.size(0) != model_logps.size(0):
raise RuntimeError(
"Mismatch between vLLM and model logprobs row counts after slicing: "
f"model={model_logps.size(0)}, vllm={vllm_logps.size(0)}. "
"Ensure rollout slicing aligns with forward_group_size and group_size."
)

tgt_logit_entropy = compute_token_level_entropy(
grpo_model_logits, grpo_target_batch.target_mask
@@ -312,16 +329,21 @@
prompt_rollout_seqs,
prompt_rollout_layout,
) = grpo_batch.prompt_rollouts.as_input()
ref_logps = compute_reference_logps(
self._gangs,
self._reference_model,
prompt_rollout_seqs,
prompt_rollout_layout,
grpo_batch.prompt_lengths,
)

_grpo_objective = self._compute_grpo_objective(
logps, ref_logps, grpo_batch.rewards, grpo_target_batch
# if beta > 0, compute reference logprobs
if self._config.loss_config.beta > 0:
ref_logps = compute_reference_logps(
self._gangs,
self._reference_model,
prompt_rollout_seqs,
prompt_rollout_layout,
grpo_batch.prompt_lengths,
)
else:
ref_logps = None

_grpo_objective, total_tokens = self._compute_grpo_objective(
model_logps, vllm_logps, ref_logps, grpo_batch.rewards, grpo_target_batch
)

grpo_loss = -_grpo_objective + max_entropy_regularizer
@@ -352,7 +374,10 @@ def __call__(

loss = grpo_loss

return loss, prompt_batch.batch_size
if self._config.loss_config.loss_token_mean:
return loss, total_tokens
else:
return loss, prompt_batch.batch_size

def _gather_lprobs(self, logits: Tensor, target: SequenceBatch) -> Tensor:
assert target.target_mask is not None
@@ -365,33 +390,50 @@ def _gather_lprobs(self, logits: Tensor, target: SequenceBatch) -> Tensor:

def _compute_grpo_objective(
self,
logps,
model_logps,
vllm_logps,
ref_logps,
advantages: Tensor, # outcome based only for now
target_batch: SequenceBatch,
    ) -> tuple[Tensor, int]:

batch_size = advantages.size(0)
num_rollouts = advantages.size(1)
logps = logps.view(batch_size, num_rollouts, -1)
ref_logps = ref_logps.view(batch_size, num_rollouts, -1)
model_logps = model_logps.view(batch_size, num_rollouts, -1)
vllm_logps = vllm_logps.view(batch_size, num_rollouts, -1)

# kl penalty
kl = (ref_logps - logps).exp() - (ref_logps - logps) - 1.0
per_token_scaled_advantage = (
model_logps - model_logps.detach()
).exp() * advantages[:, :, None]

per_token_scaled_advantage = (logps - logps.detach()).exp() * advantages[
:, :, None
]
# per_token_scaled_advantage = logps * advantages[:,:,None]
if self._config.loss_config.tis_imp_ratio_cap > 0:
tis_imp_ratio = torch.exp(model_logps - vllm_logps)
tis_imp_ratio = torch.clamp(
tis_imp_ratio, max=self._config.loss_config.tis_imp_ratio_cap
)
per_token_scaled_advantage = per_token_scaled_advantage * tis_imp_ratio
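        # Truncated importance sampling (TIS) note: exp(model_logps - vllm_logps)
        # reweights each token for the mismatch between the current policy and
        # the (possibly stale) vLLM policy that produced the rollouts; the
        # one-sided clamp bounds the variance of that correction. E.g. with a
        # hypothetical gap model_logps - vllm_logps = 1.5 and
        # tis_imp_ratio_cap = 2.0, the raw ratio exp(1.5) ≈ 4.48 is clamped to 2.0.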

per_token_loss = per_token_scaled_advantage - self._config.loss_config.beta * kl
if self._config.loss_config.beta > 0:
ref_logps = ref_logps.view(batch_size, num_rollouts, -1)

# kl penalty
kl = (ref_logps - model_logps).exp() - (ref_logps - model_logps) - 1.0
per_token_loss = (
per_token_scaled_advantage - self._config.loss_config.beta * kl
)
else:
per_token_loss = per_token_scaled_advantage

target_mask = target_batch.target_mask.view(batch_size, num_rollouts, -1)

total_tokens = target_mask.sum().item()

if self._config.loss_config.length_normalization:
per_seq_loss = (
(per_token_loss * target_mask).sum(dim=-1) / target_mask.sum(dim=-1)
).mean(dim=1)
elif self._config.loss_config.loss_token_mean:
per_seq_loss = per_token_loss * target_mask
else:
per_seq_loss = ((per_token_loss * target_mask).sum(dim=-1)).mean(dim=1)
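        # Aggregation note: with length_normalization each sequence contributes
        # its mean per-token loss; with loss_token_mean the masked per-token
        # losses are kept and the caller normalizes by total_tokens; otherwise
        # each sequence contributes the sum of its per-token losses.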

Expand All @@ -401,7 +443,7 @@ def _compute_grpo_objective(

# self._gangs.root.barrier()

return per_seq_loss.sum()
return per_seq_loss.sum(), total_tokens

@override
def set_step_nr(self, step_nr: int) -> None:
@@ -442,12 +484,21 @@ class GrpoLossConfig:
length_normalization: bool = True
"""If True, normalize loss by sequence length. If False, use sequence-level loss."""

adv_std_normalization: bool = True
"""If True, normalize advantages with standard deviation."""

log_rollouts: bool = False
"""Log sample rollouts during training/validation."""

    loss_token_mean: bool = False
    """If True, normalize the loss by the total target-token count instead of
    the batch size. If False, sum per-token losses over each sequence."""

    validation_vllm_sampling_params: Dict[str, Any] = field(default_factory=lambda: {})
    """vLLM sampling params for validation. If empty, the training params are used."""

tis_imp_ratio_cap: float = 2.0
"""Maximum cap for the truncated importance sampling ratio. If <= 0, no cap is applied."""


@dataclass(kw_only=True)
class GrpoFinetuneConfig:
@@ -500,7 +551,7 @@ def create(
log.info(f"GRPO loss config:\n{config}")

reference_model = vllm_actors[config.vllm_reference_model_actor_name]
if config.vllm_sync.sync_ref_model_every_n_steps != -1:
if reference_model and config.vllm_sync.sync_ref_model_every_n_steps != -1:
if reference_model and reference_model.update_process_groups is None:
raise ValueError(
f"Reference model actor must have update process group if we sync weights"