RLLM #1232 (Open)

Changes from all commits (39 commits)
6095032
skywork and some qwen changes
Jul 22, 2025
f3d876a
Removing think tokens
Jul 22, 2025
d7e9fd0
Merge branch 'ot_merge' into swarna/skyworkv2
swarnaHub Jul 22, 2025
164458b
Fixing GRMs
Jul 23, 2025
b162c05
Merge branch 'swarna/skyworkv2' of github.com:facebookresearch/fairse…
Jul 23, 2025
4fc9aea
Black
Jul 23, 2025
a6ad0a8
Merge branch online_training
Jul 29, 2025
f829d82
Import issue
Jul 29, 2025
9392f0d
add missing sw import
Aug 6, 2025
55dc622
Different configs for pairwise GRM
Aug 7, 2025
ee17161
Minor fix and more logging
Aug 7, 2025
18ff4c4
Online dpo: pairwise GRM should sample at least two rollouts
Aug 7, 2025
585b744
zero reward for rollouts not involved in pairwise judgments
Aug 7, 2025
510bdf2
simplifying
Aug 7, 2025
38aaf53
SequenceBatch seq_lens type ensure to be a list
Aug 9, 2025
a6ab8b0
add pairwise J1 with reference answer
Aug 13, 2025
2004533
fix None ref answer
Aug 18, 2025
bfc255b
Pairwise with pivot changes
Aug 19, 2025
1f942d7
New pivot changes + cleanup
Aug 20, 2025
5eee4ee
Fix
Aug 20, 2025
5cdb6b9
Making pair type configurable
Aug 20, 2025
e14421d
Config change
Aug 21, 2025
8831f36
update prompt
Aug 24, 2025
9fc9dbb
some more logging
Aug 27, 2025
d14fb90
Merge branch 'swarna/skyworkv2' of github.com:facebookresearch/fairse…
Aug 27, 2025
b1ba0e2
Fixing typo in comment
Aug 27, 2025
d47ef15
kwise judgment support
Sep 2, 2025
d474070
Adding support for acemath
Sep 3, 2025
1162d60
Skywork-RM from hf
Sep 4, 2025
4ea811d
add parsed ref
Sep 6, 2025
6703f1b
update prompt template
Sep 7, 2025
fe84d9e
all comparisons in k-wise
Sep 10, 2025
e7137ac
Jacklanchantin/qwen (#1260)
jacklanchantin Sep 30, 2025
dfb958a
octothinker assets
Sep 30, 2025
a746129
Changes
Sep 30, 2025
1116f02
Merging
Sep 30, 2025
9355ce6
Minor changes
Oct 16, 2025
474a537
Logging judge input
Oct 22, 2025
8211664
Tracking a second reward (for debugging)
Oct 29, 2025
32 changes: 31 additions & 1 deletion src/fairseq2/assets/cards/models/llama.yaml
@@ -168,4 +168,34 @@ model_arch: llama3_1_8b
checkpoint: "hg://deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
tokenizer: "hg://deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
tokenizer_family: llama
use_v2_tokenizer: true

---

name: octothinker_8b_hybrid
model_family: llama
model_arch: llama3_1_8b
checkpoint: /datasets/pretrained-llms/OctoThinker-8B-Hybrid-Base/
tokenizer: /datasets/pretrained-llms/OctoThinker-8B-Hybrid-Base/
tokenizer_family: llama
use_v2_tokenizer: true

---

name: octothinker_8b_long
model_family: llama
model_arch: llama3_1_8b
checkpoint: /datasets/pretrained-llms/OctoThinker-8B-Long-Base/
tokenizer: /datasets/pretrained-llms/OctoThinker-8B-Long-Base/
tokenizer_family: llama
use_v2_tokenizer: true

---

name: octothinker_8b_short
model_family: llama
model_arch: llama3_1_8b
checkpoint: /datasets/pretrained-llms/OctoThinker-8B-Short-Base/
tokenizer: /datasets/pretrained-llms/OctoThinker-8B-Short-Base/
tokenizer_family: llama
use_v2_tokenizer: true
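Unlike the DeepSeek card above, the three OctoThinker cards point at local checkpoint directories rather than `hg://` URIs, so they resolve only on hosts where those paths are mounted. As a hedged usage sketch (not part of this diff; it assumes fairseq2's `load_model` helper is the card-resolving entry point):

```python
# Hypothetical sketch: load a model by the card name defined above.
# Assumes fairseq2.models.load_model resolves asset cards by name and that
# the /datasets/pretrained-llms/... checkpoint paths exist on this machine.
from fairseq2.models import load_model

model = load_model("octothinker_8b_hybrid")
```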
30 changes: 30 additions & 0 deletions src/fairseq2/recipes/lm/__init__.py
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
@@ -48,6 +48,12 @@
from fairseq2.recipes.lm._online_finetune._generative_judge import (
    J1PairwiseScoreExtractorHandler as J1PairwiseScoreExtractorHandler,
)
from fairseq2.recipes.lm._online_finetune._generative_judge import (
    J1KwiseScoreExtractor as J1KwiseScoreExtractor,
)
from fairseq2.recipes.lm._online_finetune._generative_judge import (
    J1KwiseScoreExtractorHandler as J1KwiseScoreExtractorHandler,
)
from fairseq2.recipes.lm._online_finetune._generative_judge import (
    J1PointwiseExtractor as J1PointwiseExtractor,
)
@@ -84,6 +90,12 @@
from fairseq2.recipes.lm._online_finetune._remote_model import (
    NoEnvGeneralVerifierPipeline as NoEnvGeneralVerifierPipeline,
)
from fairseq2.recipes.lm._online_finetune._remote_model import (
    NoEnvAceMathRMPipeline as NoEnvAceMathRMPipeline,
)
from fairseq2.recipes.lm._online_finetune._remote_model import (
    NoEnvSkyworkRMPipeline as NoEnvSkyworkRMPipeline,
)
from fairseq2.recipes.lm._online_finetune._remote_model import (
    RemoteModelHandler as RemoteModelHandler,
)
@@ -93,6 +105,18 @@
from fairseq2.recipes.lm._online_finetune._rewards import (
    AtheneVerifierHandler as AtheneVerifierHandler,
)
from fairseq2.recipes.lm._online_finetune._rewards import (
    SkyworkVerifier as SkyworkVerifier,
)
from fairseq2.recipes.lm._online_finetune._rewards import (
    SkyworkVerifierHandler as SkyworkVerifierHandler,
)
from fairseq2.recipes.lm._online_finetune._rewards import (
    AceMathVerifier as AceMathVerifier,
)
from fairseq2.recipes.lm._online_finetune._rewards import (
    AceMathVerifierHandler as AceMathVerifierHandler,
)
from fairseq2.recipes.lm._online_finetune._rewards import (
    GenerativePairwiseVerifier as GenerativePairwiseVerifier,
)
@@ -105,6 +129,12 @@
from fairseq2.recipes.lm._online_finetune._rewards import (
    GenerativePointwiseVerifierHandler as GenerativePointwiseVerifierHandler,
)
from fairseq2.recipes.lm._online_finetune._rewards import (
    GenerativeKwiseVerifier as GenerativeKwiseVerifier,
)
from fairseq2.recipes.lm._online_finetune._rewards import (
    GenerativeKwiseVerifierHandler as GenerativeKwiseVerifierHandler,
)
from fairseq2.recipes.lm._online_finetune._rewards import GSM8kVerifier as GSM8kVerifier
from fairseq2.recipes.lm._online_finetune._rewards import (
    GSM8kVerifierHandler as GSM8kVerifierHandler,
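All of the added imports use the `name as name` alias form. The redundancy is deliberate: it is the explicit re-export idiom from PEP 484, which keeps these names part of the package's public API when type checkers run with mypy's `no_implicit_reexport`. A minimal illustration using one of the names added here:

```python
# The redundant-looking alias marks an explicit re-export:
from fairseq2.recipes.lm._online_finetune._rewards import (
    SkyworkVerifier as SkyworkVerifier,  # visible to strict type checkers
)
```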
73 changes: 64 additions & 9 deletions src/fairseq2/recipes/lm/_online_finetune/_common.py
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
@@ -8,6 +8,7 @@

import contextlib
import io
import re
from dataclasses import dataclass
from typing import List, cast

@@ -17,14 +18,8 @@
from torch import Tensor
from vllm import RequestOutput

from fairseq2.data import (
    CollateOptionsOverride,
    Collater,
    SequenceData,
)
from fairseq2.datasets import (
    SequenceBatch,
)
from fairseq2.data import CollateOptionsOverride, Collater, SequenceData
from fairseq2.datasets import SequenceBatch
from fairseq2.datasets.preference import PreferenceBatch
from fairseq2.datasets.prompt import PromptBatch
from fairseq2.gang import Gang, Gangs
@@ -93,9 +88,13 @@

seq_data = cast(SequenceData, collater(to_collate))

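# Normalize seq_lens to a plain Python list; the collater can return either
# a Tensor or a list, and SequenceBatch expects a list.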
seq_lens = seq_data["seqs"]["seq_lens"]
assert isinstance(seq_lens, (Tensor, list))
if isinstance(seq_lens, Tensor):
    seq_lens = seq_lens.tolist()
batch = SequenceBatch(
seq_data["seqs"]["seqs"],
seq_data["seqs"]["seq_lens"],
seq_lens,
target_mask=seq_data["target_loss_mask"]["seqs"],
)
batch.to(device)
@@ -395,6 +394,8 @@
    prompt = prompt_batch.meta_info.get("prompt_raw")[0]
elif "raw_prompt" in prompt_batch.meta_info:
    prompt = prompt_batch.meta_info.get("raw_prompt")[0]
elif "problem" in prompt_batch.meta_info:
    prompt = prompt_batch.meta_info.get("problem")[0]
else:
    # raw text prompt doesn't exist for this dataset
    prompt = "DUMMY PROMPT"
@@ -416,6 +417,48 @@
return rollout_lengths


def strip_think_tokens(rollouts: List[SequenceData]):
    count_stripped, count_not_stripped, total_count, think_present = 0, 0, 0, 0
    for sample in rollouts:
        for rollout in sample.outputs:
            rollout_text = rollout.text
            if "<think>" in rollout_text:
                think_present += 1
                if rollout.finish_reason == "length":
                    count_not_stripped += 1
                if rollout.finish_reason == "stop":
                    count_stripped += 1
            total_count += 1
            rollout.text = re.sub(
                r"<think>.*?</think>", "", rollout_text, flags=re.DOTALL
            ).strip()

    # Guard against an empty rollout list when computing fractions.
    denom = max(total_count, 1)
    log.info(f"Total count: {total_count}")
    log.info(f"Think present: {think_present}")
    log.info(f"Fraction stripped: {count_stripped / denom}")
    log.info(f"Fraction not stripped: {count_not_stripped / denom}")

    return rollouts
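The regex only removes a complete `<think>…</think>` block (non-greedy, matching across newlines), so a rollout cut off by the length limit keeps its dangling thought text; that is what the stripped vs. not-stripped counters separate. A self-contained sketch of the behavior:

```python
import re

completed = "<think>Work through the sum...</think>\nThe answer is 42."
truncated = "<think>Work through the sum..."  # hit max length, no closing tag

for text in (completed, truncated):
    print(repr(re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()))
# -> 'The answer is 42.'
# -> '<think>Work through the sum...'
```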

def get_failed_to_parse_answers(reward_output: dict, batch_size: int) -> float:
    if "answers" in reward_output:
        log.info(f"Answers: {reward_output['answers']}")
        failed_to_parse = sum(
            answer is None
            for rollouts in reward_output["answers"]
            for answer in rollouts
        )
        return failed_to_parse / batch_size
    else:
        return 0.0
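Note the denominator: failures are counted per rollout but divided by the prompt-batch size, so with several rollouts per prompt the returned value can exceed 1.0. A toy check:

```python
# 2 prompts x 2 rollouts each; 3 of the 4 parsed answers are None.
reward_output = {"answers": [["42", None], [None, None]]}
batch_size = 2

failed = sum(a is None for rollouts in reward_output["answers"] for a in rollouts)
print(failed / batch_size)  # 1.5: rollout-level failures over prompt count
```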

def strip_for_octothinker(rollouts: List[SequenceData]):
    for sample in rollouts:
        for rollout in sample.outputs:
            rollout_text = rollout.text
            if "\nUser:" in rollout_text:
                rollout_text = rollout_text[: rollout_text.find("\nUser:")]
            rollout.text = rollout_text

    return rollouts
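A standalone illustration: completions from base models sometimes run on into a fabricated next user turn, and everything from the first `\nUser:` marker onward is cut:

```python
text = "The integral evaluates to 2.\nUser: Can you do another one?"
if "\nUser:" in text:
    text = text[: text.find("\nUser:")]
print(text)  # -> The integral evaluates to 2.
```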


class StatefulRolloutBag:
"""A stateful container for managing and reusing model rollouts across multiple micro-batches.

@@ -504,11 +547,23 @@
@torch.inference_mode()
def update_avg_reward(metric_bag: MetricBag, avg_reward):
    metric_bag.get(Mean, "avg_reward").update(avg_reward, weight=1)


@torch.inference_mode()
def update_avg_second_reward(metric_bag: MetricBag, avg_reward):
    metric_bag.get(Mean, "avg_second_reward").update(avg_reward, weight=1)


@torch.inference_mode()
def update_reward_matches(metric_bag: MetricBag, reward_matches):
    metric_bag.get(Mean, "reward_matches").update(reward_matches, weight=1)


@torch.inference_mode()
def update_std_reward(metric_bag: MetricBag, std_reward):
    metric_bag.get(Mean, "std_reward").update(std_reward, weight=1)


@torch.inference_mode()
def update_failed_to_parse_answers(metric_bag: MetricBag, failed_to_parse_answers):
    metric_bag.get(Mean, "failed_to_parse_answers").update(
        failed_to_parse_answers, weight=1
    )


@torch.inference_mode()