
Commit 606459b

jacklanchantin and Jack Lanchantin authored
Add truncated importance sampling and DrGRPO args (#1394)
* drgrpo
* get vllm logps
* Update _wandb.py
* remove beta check
* format
* revert
* add importance sampling correction
* dont run ref model forward if beta==0
* add tis ratio clamp = 2
* clean up
* configs
* clean up
* default
* var name
* var name
* only use tis_imp_ratio_cap
* revert unrelated files
* clean up
* fix type hint
* black/isort
* Allow batched inputs for get_vllm_logprobs
* allow batch_sz > 1
* Modify condition for reference log probabilities
* fix batch>1, microbatching

---------

Co-authored-by: Jack Lanchantin <[email protected]>
1 parent 221c83b commit 606459b

File tree

2 files changed: +131 −37 lines


src/fairseq2/recipes/lm/_online_finetune/_common.py

Lines changed: 51 additions & 8 deletions
@@ -17,14 +17,8 @@
 from torch import Tensor
 from vllm import RequestOutput
 
-from fairseq2.data import (
-    CollateOptionsOverride,
-    Collater,
-    SequenceData,
-)
-from fairseq2.datasets import (
-    SequenceBatch,
-)
+from fairseq2.data import CollateOptionsOverride, Collater, SequenceData
+from fairseq2.datasets import SequenceBatch
 from fairseq2.datasets.preference import PreferenceBatch
 from fairseq2.datasets.prompt import PromptBatch
 from fairseq2.gang import Gang, Gangs
@@ -365,6 +359,55 @@ def combine_prompts_responses_for_scoring(
     return responses
 
 
+def get_vllm_logprobs(
+    vllm_outputs: List[RequestOutput],
+    gangs,
+    rollout_start_end: tuple[int, int] | None = None,
+):
+    """Compute per-token logprobs for selected continuations across a list of requests.
+
+    For each RequestOutput (one prompt) and each of its sampled continuations we
+    concatenate the prompt logprobs (skipping the first entry) with the generation
+    logprobs. All resulting sequences are then right-padded with 0.0 to the global
+    maximum length and stacked into a single tensor.
+
+    Parameters
+    ----------
+    vllm_outputs:
+        List of vLLM RequestOutput objects (one per prompt).
+    gangs:
+        Fairseq2 gangs object (unused, kept for parity/extensibility).
+    rollout_start_end:
+        Optional (start, end) slice specifying which continuation indices to include
+        per prompt (used for micro-batching when forward_group_size < group_size).
+
+    Returns
+    -------
+    Tensor
+        Shape ``(num_selected_continuations, max_seq_len)`` with 0.0 padding.
+    """
+    sequences: List[Tensor] = []
+    for request in vllm_outputs:
+        prompt_logprobs = [
+            list(d.values())[0].logprob for d in request.prompt_logprobs[1:]
+        ]
+        outputs = request.outputs
+        if rollout_start_end is not None:  # micro-batching
+            s, e = rollout_start_end
+            outputs = outputs[s:e]
+        for output in outputs:
+            gen_logprobs = [list(d.values())[0].logprob for d in output.logprobs]
+            seq = torch.tensor(prompt_logprobs + gen_logprobs)
+            sequences.append(seq)
+
+    max_len = max(t.size(0) for t in sequences)
+    padded = torch.zeros(len(sequences), max_len)
+    for i, t in enumerate(sequences):
+        padded[i, : t.size(0)] = t
+
+    return padded
+
+
 def convert_vllm_output_to_ref_score(vllm_outputs: List[RequestOutput], gangs):
     ref_scores = []
     for req_output in vllm_outputs:
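The new helper only sees what vLLM recorded, so rollouts must be generated with both prompt and generation logprobs enabled. The following is a minimal, hypothetical sketch (not part of this commit) of exercising it with vLLM's LLM/SamplingParams API; the model name and sampling values are illustrative, and gangs is passed as None since the function does not use it.

from vllm import LLM, SamplingParams

from fairseq2.recipes.lm._online_finetune._common import get_vllm_logprobs

llm = LLM(model="facebook/opt-125m")  # illustrative model choice

sampling_params = SamplingParams(
    n=4,                # group_size rollouts per prompt
    temperature=1.0,
    max_tokens=64,
    logprobs=0,         # record the logprob of each sampled token
    prompt_logprobs=0,  # record per-token prompt logprobs as well
)

outputs = llm.generate(["Write a haiku about the sea."], sampling_params)

# All 4 rollouts of the single prompt: shape (4, max_seq_len), right-padded with 0.0.
# The first prompt position is skipped inside the helper; vLLM reports no logprob for it.
all_logps = get_vllm_logprobs(outputs, gangs=None)

# Micro-batching: only rollouts 0..1 of each prompt, mirroring
# forward_group_size < group_size in the GRPO recipe.
first_half = get_vllm_logprobs(outputs, gangs=None, rollout_start_end=(0, 2))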

src/fairseq2/recipes/lm/_online_finetune/_grpo.py

Lines changed: 80 additions & 29 deletions
@@ -17,9 +17,7 @@
 from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
 
 from fairseq2.context import RuntimeContext
-from fairseq2.datasets import (
-    SequenceBatch,
-)
+from fairseq2.datasets import SequenceBatch
 from fairseq2.datasets.preference import PreferenceBatch
 from fairseq2.datasets.prompt import PromptBatch
 from fairseq2.gang import Gang, Gangs
@@ -38,6 +36,7 @@
     compute_token_level_entropy,
     generate_rollouts,
     get_rollout_lengths,
+    get_vllm_logprobs,
     log_rollouts,
     update_avg_reward,
     update_avg_reward_len_norm,
@@ -77,6 +76,7 @@ def prepare_grpo_batch(
     reward_output: dict,
     gangs: Gang,
    rollout_start_end: tuple[int],
+    adv_std_normalization: bool,
 ):
 
     prompt_rollouts = []
@@ -107,9 +107,12 @@
     # gangs.root.barrier()
 
     rewards = torch.tensor(rewards, device=gangs.dp.device).float()  # [Batch, Rollouts]
-    rewards_normalized = (rewards - rewards.mean(dim=1, keepdim=True)) / (
-        rewards.std(dim=1, keepdim=True) + 1e-6
-    )  # small epsilon to compensate 0 std
+
+    rewards_normalized = rewards - rewards.mean(dim=1, keepdim=True)
+    if adv_std_normalization:  # normalize advantages with std
+        rewards_normalized = rewards_normalized / (
+            rewards.std(dim=1, keepdim=True) + 1e-6
+        )  # small epsilon to compensate 0 std
 
     rewards_normalized = rewards_normalized[
         :, rollout_start_end[0] : rollout_start_end[1]
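The adv_std_normalization flag is what separates standard GRPO advantages from the Dr. GRPO variant, which only mean-centers the group rewards and skips the per-group standard-deviation rescaling. A standalone sketch of the two cases, using made-up rewards for 2 prompts with 4 rollouts each:

import torch

# Toy reward values, shape [Batch, Rollouts].
rewards = torch.tensor([[1.0, 0.0, 1.0, 0.0],
                        [0.5, 0.5, 1.0, 0.0]])

centered = rewards - rewards.mean(dim=1, keepdim=True)

# adv_std_normalization=True (default): classic GRPO group normalization.
adv_grpo = centered / (rewards.std(dim=1, keepdim=True) + 1e-6)

# adv_std_normalization=False: Dr. GRPO keeps only the mean-centered rewards.
adv_drgrpo = centered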
@@ -279,6 +282,7 @@ def __call__(
             rollout_start_end=self._rollout_bag.get_rollout_start_end(
                 self._config.loss_config.forward_group_size
             ),
+            adv_std_normalization=self._config.loss_config.adv_std_normalization,
         )
 
         # grpo_batch, reward_output = self._reward.prepare_grpo_batch(prompt_batch, rollouts) # loss_zeroer is used when entire batch has no valid prefrence pair
@@ -296,7 +300,20 @@
             grpo_input_batch_seqs, grpo_input_batch_seqs_layout
         )
 
-        logps = self._gather_lprobs(grpo_model_logits, grpo_target_batch)
+        model_logps = self._gather_lprobs(grpo_model_logits, grpo_target_batch)
+        rollout_window = self._rollout_bag.get_rollout_start_end(
+            self._config.loss_config.forward_group_size
+        )
+        vllm_logps = get_vllm_logprobs(
+            rollouts, self._gangs, rollout_start_end=rollout_window
+        ).to(model_logps.device)
+
+        if vllm_logps.size(0) != model_logps.size(0):
+            raise RuntimeError(
+                "Mismatch between vLLM and model logprobs row counts after slicing: "
+                f"model={model_logps.size(0)}, vllm={vllm_logps.size(0)}. "
+                "Ensure rollout slicing aligns with forward_group_size and group_size."
+            )
 
         tgt_logit_entropy = compute_token_level_entropy(
             grpo_model_logits, grpo_target_batch.target_mask
@@ -312,16 +329,21 @@
             prompt_rollout_seqs,
             prompt_rollout_layout,
         ) = grpo_batch.prompt_rollouts.as_input()
-        ref_logps = compute_reference_logps(
-            self._gangs,
-            self._reference_model,
-            prompt_rollout_seqs,
-            prompt_rollout_layout,
-            grpo_batch.prompt_lengths,
-        )
 
-        _grpo_objective = self._compute_grpo_objective(
-            logps, ref_logps, grpo_batch.rewards, grpo_target_batch
+        # if beta > 0, compute reference logprobs
+        if self._config.loss_config.beta > 0:
+            ref_logps = compute_reference_logps(
+                self._gangs,
+                self._reference_model,
+                prompt_rollout_seqs,
+                prompt_rollout_layout,
+                grpo_batch.prompt_lengths,
+            )
+        else:
+            ref_logps = None
+
+        _grpo_objective, total_tokens = self._compute_grpo_objective(
+            model_logps, vllm_logps, ref_logps, grpo_batch.rewards, grpo_target_batch
         )
 
         grpo_loss = -_grpo_objective + max_entropy_regularizer
@@ -352,7 +374,10 @@
 
         loss = grpo_loss
 
-        return loss, prompt_batch.batch_size
+        if self._config.loss_config.loss_token_mean:
+            return loss, total_tokens
+        else:
+            return loss, prompt_batch.batch_size
 
     def _gather_lprobs(self, logits: Tensor, target: SequenceBatch) -> Tensor:
         assert target.target_mask is not None
@@ -365,33 +390,50 @@ def _gather_lprobs(self, logits: Tensor, target: SequenceBatch) -> Tensor:
 
     def _compute_grpo_objective(
         self,
-        logps,
+        model_logps,
+        vllm_logps,
         ref_logps,
         advantages: Tensor,  # outcome based only for now
         target_batch: SequenceBatch,
     ) -> tuple[Tensor, Tensor, Tensor]:
 
         batch_size = advantages.size(0)
         num_rollouts = advantages.size(1)
-        logps = logps.view(batch_size, num_rollouts, -1)
-        ref_logps = ref_logps.view(batch_size, num_rollouts, -1)
+        model_logps = model_logps.view(batch_size, num_rollouts, -1)
+        vllm_logps = vllm_logps.view(batch_size, num_rollouts, -1)
 
-        # kl penalty
-        kl = (ref_logps - logps).exp() - (ref_logps - logps) - 1.0
+        per_token_scaled_advantage = (
+            model_logps - model_logps.detach()
+        ).exp() * advantages[:, :, None]
 
-        per_token_scaled_advantage = (logps - logps.detach()).exp() * advantages[
-            :, :, None
-        ]
-        # per_token_scaled_advantage = logps * advantages[:,:,None]
+        if self._config.loss_config.tis_imp_ratio_cap > 0:
+            tis_imp_ratio = torch.exp(model_logps - vllm_logps)
+            tis_imp_ratio = torch.clamp(
+                tis_imp_ratio, max=self._config.loss_config.tis_imp_ratio_cap
+            )
+            per_token_scaled_advantage = per_token_scaled_advantage * tis_imp_ratio
 
-        per_token_loss = per_token_scaled_advantage - self._config.loss_config.beta * kl
+        if self._config.loss_config.beta > 0:
+            ref_logps = ref_logps.view(batch_size, num_rollouts, -1)
+
+            # kl penalty
+            kl = (ref_logps - model_logps).exp() - (ref_logps - model_logps) - 1.0
+            per_token_loss = (
+                per_token_scaled_advantage - self._config.loss_config.beta * kl
+            )
+        else:
+            per_token_loss = per_token_scaled_advantage
 
         target_mask = target_batch.target_mask.view(batch_size, num_rollouts, -1)
 
+        total_tokens = target_mask.sum().item()
+
         if self._config.loss_config.length_normalization:
             per_seq_loss = (
                 (per_token_loss * target_mask).sum(dim=-1) / target_mask.sum(dim=-1)
             ).mean(dim=1)
+        elif self._config.loss_config.loss_token_mean:
+            per_seq_loss = per_token_loss * target_mask
         else:
             per_seq_loss = ((per_token_loss * target_mask).sum(dim=-1)).mean(dim=1)
 
@@ -401,7 +443,7 @@ def _compute_grpo_objective(
 
         # self._gangs.root.barrier()
 
-        return per_seq_loss.sum()
+        return per_seq_loss.sum(), total_tokens
 
     @override
     def set_step_nr(self, step_nr: int) -> None:
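The objective change above distills to three ingredients: the usual GRPO surrogate exp(logp - logp.detach()) * advantage, an optional truncated importance sampling (TIS) weight exp(model_logp - vllm_logp) clamped at tis_imp_ratio_cap to correct for the mismatch between the vLLM sampling policy and the trainer policy, and a KL penalty that is only computed (and only needs a reference forward pass) when beta > 0. A self-contained sketch of the per-token term, with random tensors standing in for real logprobs; shapes and values are illustrative:

import torch

B, R, T = 2, 4, 8                    # prompts, rollouts per prompt, tokens
model_logps = torch.randn(B, R, T)   # in the recipe: trainer forward pass (carries grad)
vllm_logps = torch.randn(B, R, T)    # in the recipe: recorded by the vLLM sampler
advantages = torch.randn(B, R)

beta = 0.0                           # beta == 0 skips the reference model entirely
tis_imp_ratio_cap = 2.0

# GRPO surrogate: exp(logp - logp.detach()) equals 1 but keeps the gradient of logp.
per_token = (model_logps - model_logps.detach()).exp() * advantages[:, :, None]

if tis_imp_ratio_cap > 0:
    # Truncated importance sampling: pi_theta / pi_vllm per token, clamped from above.
    tis_ratio = torch.exp(model_logps - vllm_logps).clamp(max=tis_imp_ratio_cap)
    per_token = per_token * tis_ratio

if beta > 0:
    ref_logps = torch.randn(B, R, T)  # only needed when a KL penalty is applied
    kl = (ref_logps - model_logps).exp() - (ref_logps - model_logps) - 1.0
    per_token = per_token - beta * kl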
@@ -442,12 +484,21 @@ class GrpoLossConfig:
     length_normalization: bool = True
     """If True, normalize loss by sequence length. If False, use sequence-level loss."""
 
+    adv_std_normalization: bool = True
+    """If True, normalize advantages with standard deviation."""
+
     log_rollouts: bool = False
     """Log sample rollouts during training/validation."""
 
+    loss_token_mean: bool = False
+    """If True, average loss over tokens. If False, sum over tokens."""
+
     validation_vllm_sampling_params: Dict[str, Any] = field(default_factory=lambda: {})
     """VLLM sampling params for validation. If empty, training params will be used."""
 
+    tis_imp_ratio_cap: float = 2.0
+    """Maximum cap for the truncated importance sampling ratio. If <= 0, no cap is applied."""
+
 
 @dataclass(kw_only=True)
 class GrpoFinetuneConfig:
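Taken together, the new fields support a Dr. GRPO-style setup: beta=0.0 skips the reference-model forward pass, adv_std_normalization=False drops the std rescaling of advantages, loss_token_mean=True (paired with length_normalization=False) normalizes the summed token losses by the total token count returned from the objective, and tis_imp_ratio_cap keeps the clamped vLLM importance correction active. A hypothetical construction, leaving every field not shown in this diff at its default:

loss_config = GrpoLossConfig(
    beta=0.0,                     # no KL penalty -> reference logprobs never computed
    adv_std_normalization=False,  # mean-center advantages only (Dr. GRPO)
    length_normalization=False,   # take the token-level path instead of per-seq mean
    loss_token_mean=True,         # divide the loss by the total token count
    tis_imp_ratio_cap=2.0,        # clamp exp(model_logp - vllm_logp) at 2; <= 0 disables TIS
)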
@@ -500,7 +551,7 @@ def create(
         log.info(f"GRPO loss config:\n{config}")
 
         reference_model = vllm_actors[config.vllm_reference_model_actor_name]
-        if config.vllm_sync.sync_ref_model_every_n_steps != -1:
+        if reference_model and config.vllm_sync.sync_ref_model_every_n_steps != -1:
             if reference_model and reference_model.update_process_groups is None:
                 raise ValueError(
                     f"Reference model actor must have update process group if we sync weights"
