Draft

Commits (42)
fd7267f
drgrpo
Oct 14, 2025
cb6f7a9
get vllm logps
Oct 14, 2025
d6acc63
Update _wandb.py
jacklanchantin Oct 14, 2025
7b72df9
remove beta check
Oct 14, 2025
7fc3b2f
Merge branch 'jacklanchantin/drgrpo' of github.com:facebookresearch/f…
Oct 14, 2025
502fa69
format
Oct 14, 2025
79382d3
revert
Oct 14, 2025
97e8dca
add importance sampling correction
Oct 16, 2025
54c9d98
dont run ref model forward if beta==0
Oct 20, 2025
acb0840
add tis ratio clamp = 2
Oct 20, 2025
50d21dd
clean up
Oct 20, 2025
ccfa63b
configs
Oct 21, 2025
bb49312
clean up
Oct 22, 2025
bd4b073
default
Oct 22, 2025
6919a4c
var name
Oct 22, 2025
d910891
var name
Oct 22, 2025
b762625
only use tis_imp_ratio_cap
Oct 22, 2025
5dff68a
batched inputs
Oct 23, 2025
cce97ce
use tis_drgrpo files
Oct 23, 2025
178fb69
size
Oct 24, 2025
536ce2b
match tis_grpo
Oct 24, 2025
a036e92
fix batching/microbatching bugs
Oct 24, 2025
ca043a5
black/isort
Oct 24, 2025
55dc39a
Merge branch 'online_training' of github.com:facebookresearch/fairseq…
Oct 29, 2025
bdf6e4b
revert qwen card
Oct 29, 2025
cdbec3c
bypass reference_model if None
Oct 29, 2025
2645498
add SelfAugmentingExtractor for llm judge
Oct 31, 2025
7011bf9
sa judge
Nov 1, 2025
bfede6f
new metrics
Nov 3, 2025
9539d29
.
Nov 5, 2025
1f8bc99
ppl
Nov 6, 2025
b59f10b
clip outputs
jacklanchantin Nov 12, 2025
d500503
.
jacklanchantin Nov 13, 2025
b152d8b
fix rank=0 bug
jacklanchantin Nov 13, 2025
634a039
grpo
jacklanchantin Nov 13, 2025
071ff0f
remove ppl
jacklanchantin Nov 13, 2025
cd5ee91
remove ppl
jacklanchantin Nov 13, 2025
5cc0610
tokenizer
jacklanchantin Nov 13, 2025
7f7b50d
set VLLM_ALLOW_INSECURE_SERIALIZATION=1 for newer vllm versions and a…
lydiadli Oct 29, 2025
14d2571
comment out unused
jacklanchantin Nov 13, 2025
fcb24b2
logging
jacklanchantin Nov 14, 2025
8d40e31
logging
jacklanchantin Nov 14, 2025
8 changes: 5 additions & 3 deletions src/fairseq2/assets/cards/models/llama.yaml
@@ -66,15 +66,17 @@ num_shards: 8

name: llama3
model_family: llama
checkpoint: "https://ai.meta.com/llama/;gated=true"
tokenizer: "https://ai.meta.com/llama/;gated=true"
checkpoint: "/datasets/pretrained-llms/Llama-3.1-8B/"
tokenizer: "/datasets/pretrained-llms/Llama-3.1-8B/"
tokenizer_family: llama
use_v2_tokenizer: true

---

name: llama3_instruct
base: llama3
checkpoint: "/datasets/pretrained-llms/Llama-3.1-8B-Instruct/"
tokenizer: "/datasets/pretrained-llms/Llama-3.1-8B-Instruct/"
use_eot: true # instruct tokenizer to use EOT instead of EOS

---
@@ -168,4 +170,4 @@ model_arch: llama3_1_8b
checkpoint: "hg://deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
tokenizer: "hg://deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
tokenizer_family: llama
use_v2_tokenizer: true
use_v2_tokenizer: true
4 changes: 2 additions & 2 deletions src/fairseq2/assets/cards/models/qwen.yaml
@@ -139,8 +139,8 @@ use_im_end: true
name: qwen3_8b_base
model_family: qwen
model_arch: qwen3_8b
checkpoint: "hg://qwen/qwen3-8b-base"
tokenizer: "hg://qwen/qwen3-8b-base"
checkpoint: "/checkpoint/data/jacklanchantin/pretrained-llms/Qwen3-8B-Base/"
tokenizer: "/checkpoint/data/jacklanchantin/pretrained-llms/Qwen3-8B-Base/"
tokenizer_family: qwen

---
2 changes: 1 addition & 1 deletion src/fairseq2/recipes/_trainer.py
@@ -189,7 +189,7 @@ def __init__(
device_stat_tracker: DeviceStatTracker,
wall_watch: Stopwatch,
progress_reporter: ProgressReporter,
fp16_loss_scale: tuple[float, float] = (128.0, 0.0001),
fp16_loss_scale: tuple[float, float] = (65536, 0.0001),
no_sync_grad_accumulation: bool = False,
max_grad_norm: float | None = None,
grad_check: bool = False,
2 changes: 1 addition & 1 deletion src/fairseq2/recipes/config.py
@@ -165,7 +165,7 @@ class TrainerSection:
max_grad_norm: float | None = None
"""The maximum gradient norm. If ``None``, no clipping will be applied."""

fp16_loss_scale: tuple[float, float] = (128.0, 0.0001)
fp16_loss_scale: tuple[float, float] = (65536, 0.0001)
"""The initial and minimum loss scale for fp16 training."""

gc_every_n_steps: int | None = None
6 changes: 6 additions & 0 deletions src/fairseq2/recipes/lm/__init__.py
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
@@ -54,6 +54,12 @@
from fairseq2.recipes.lm._online_finetune._generative_judge import (
J1PointwiseExtractorHandler as J1PointwiseExtractorHandler,
)
from fairseq2.recipes.lm._online_finetune._generative_judge import (
SelfAugmentingExtractor as SelfAugmentingExtractor,
)
from fairseq2.recipes.lm._online_finetune._generative_judge import (
SelfAugmentingExtractorHandler as SelfAugmentingExtractorHandler,
)
from fairseq2.recipes.lm._online_finetune._generative_judge import (
JudgmentExtractorHandler as JudgmentExtractorHandler,
)
59 changes: 57 additions & 2 deletions src/fairseq2/recipes/lm/_online_finetune/_common.py
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
@@ -361,6 +361,7 @@

def get_vllm_logprobs(
vllm_outputs: List[RequestOutput],
model_logps: Tensor,
gangs,
rollout_start_end: tuple[int, int] | None = None,
):
@@ -404,7 +405,10 @@
padded = torch.zeros(len(sequences), max_len)
for i, t in enumerate(sequences):
padded[i, : t.size(0)] = t


# clip outputs to be same size as model_logps
if padded.size() != model_logps.size():
padded = padded[:, : model_logps.size(1)]
return padded
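
A small shape example of the padding-and-clip logic above, using hypothetical tensors (the variable names, lengths, and values are illustrative only, not taken from the recipe):

import torch

# Two rollouts whose vLLM logprob vectors hold 6 and 4 entries, while the
# trainer's forward pass scored only 5 positions per row.
vllm_logps = [torch.randn(6), torch.randn(4)]
model_logps = torch.zeros(2, 5)

padded = torch.zeros(len(vllm_logps), max(t.size(0) for t in vllm_logps))
for i, t in enumerate(vllm_logps):
    padded[i, : t.size(0)] = t  # right-pad with zeros, as in get_vllm_logprobs

if padded.size() != model_logps.size():
    padded = padded[:, : model_logps.size(1)]  # drop the extra trailing column

assert padded.shape == model_logps.shape  # both are now (2, 5)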


@@ -459,6 +463,46 @@
return rollout_lengths


def get_think_rollout_lengths(rollouts: List[SequenceData]):
"""Get the lengths of tokens before the </think> tag in rollouts.

This function calculates the approximate number of tokens generated before
the </think> closing tag in each rollout. It uses a proportional approximation
based on character positions to estimate token counts.

Args:
rollouts: List of SequenceData containing rollout outputs

Returns:
List of token lengths before </think> tag for rollouts that contain the tag
"""
think_rollout_lengths = []
think_tag = "</think>"

for rollout in rollouts:
for sample in rollout.outputs:
rollout_text = sample.text
if think_tag in rollout_text:
# Find the position of </think> in the text
think_end_pos = rollout_text.find(think_tag) + len(think_tag)
# Count tokens up to and including </think>
# We need to find how many tokens correspond to the text before </think>
# Since we have token_ids, we'll approximate by finding the proportion
text_before_think = rollout_text[:think_end_pos]
total_text = rollout_text
total_tokens = len(sample.token_ids)
# Approximate token count proportionally (rough estimate)
# A better approach would be to tokenize text_before_think, but we use approximation
think_token_length = (
int((len(text_before_think) / len(total_text)) * total_tokens)
if len(total_text) > 0
else 0
)
think_rollout_lengths.append(think_token_length)

return think_rollout_lengths
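
A quick numeric sanity check of the proportional approximation above, with made-up values:

# Suppose a rollout is 200 characters long, the text up to and including
# "</think>" covers the first 120 characters, and vLLM emitted 50 tokens.
total_chars, think_chars, total_tokens = 200, 120, 50
think_token_length = int((think_chars / total_chars) * total_tokens)  # 30
# This is only a rough character-proportional estimate; tokenizing
# rollout_text[:think_end_pos] directly would be exact but costs an extra
# tokenizer call per rollout.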


class StatefulRolloutBag:
"""A stateful container for managing and reusing model rollouts across multiple micro-batches.

@@ -559,6 +603,13 @@
metric_bag.get(Mean, "avg_rollout_length").update(avg_rollout_length, weight=1)


@torch.inference_mode()
def update_avg_think_rollout_length(metric_bag: MetricBag, avg_think_rollout_length):
metric_bag.get(Mean, "avg_think_rollout_length").update(
avg_think_rollout_length, weight=1
)


@torch.inference_mode()
def update_avg_reward_len_norm(metric_bag: MetricBag, avg_reward_len_norm):
metric_bag.get(Mean, "avg_reward_len_norm").update(avg_reward_len_norm, weight=1)
@@ -599,7 +650,7 @@


@torch.inference_mode()
def update_grpo_loss(metric_bag: MetricBag, batch: PromptBatch, loss: Tensor) -> None:
def update_grpo_loss(metric_bag: MetricBag, batch: PromptBatch, loss: Tensor, tis_imp_ratio: Tensor) -> None:
"""Update the GRPO loss metric.

:param batch:
Expand All @@ -611,6 +662,10 @@
loss / batch.batch_size, weight=batch.batch_size
)

metric_bag.get(Mean, "tis_imp_ratio").update(tis_imp_ratio)




def compute_reference_logps(
gangs: Gangs,
108 changes: 103 additions & 5 deletions src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py
@@ -1,12 +1,12 @@
POINTWISE_J1_PROMPT = """
You are given a user question and a response from an AI assistant. Your task is to act as an impartial judge and evaluate how well the response fulfills the user's instructions. You will be shown multiple responses to the same prompt, but only one at a time. Evaluate each response independently.
You are given a user question and a response from an AI assistant. Your task is to act as an impartial judge and evaluate how well the response fulfills the user's instructions. You will be shown multiple responses to the same prompt, but only one at a time. Evaluate each response independently.

Think carefully about how to assess the quality of the response, and enclose your reasoning within <think> and </think> tags. Your reasoning should include your evaluation criteria, a clear understanding of what an ideal response would look like for this particular question, and a concrete example of such an ideal or reference answer if possible. Then compare the assistant's response to your ideal or reference answer, explaining how it aligns with or deviates from your expectations. Be specific and avoid vague or overly general judgments. Remain as objective as possible.
Think carefully about how to assess the quality of the response, and enclose your reasoning within <think> and </think> tags. Your reasoning should include your evaluation criteria, a clear understanding of what an ideal response would look like for this particular question, and a concrete example of such an ideal or reference answer if possible. Then compare the assistant's response to your ideal or reference answer, explaining how it aligns with or deviates from your expectations. Be specific and avoid vague or overly general judgments. Remain as objective as possible.

Finally, assign the assistant's response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision. A higher score should indicate a higher-quality response. Enclose the score within <score> and </score> tags.

Format your output like this:
<think> your_thinking_process </think>
Format your output like this:
<think> your_thinking_process </think>
<score> your_score </score>

Below are the user's question and the assistant's response:
@@ -71,6 +71,26 @@
[The End of Assistant B's Answer]
"""

SELF_AUGMENTING_PROMPT = """
You are given a ground truth text, and a generated text from an AI assistant. Your task is to act as an impartial judge and evaluate how well the response matches the ground truth text. It doesn't have to match word for word, but it should be very similar.

Think carefully about how to assess how well the generated text matches the ground truth. Your reasoning should include your evaluation criteria.

Finally, assign the assistant's generation a binary score, either 0 or 1. A 0 indicates that the generated text does not match the ground truth text, and a 1 indicates that it matches well.

Format your score as \\boxed{{SCORE}} where SCORE is either 0 or 1.

Below are the ground truth text and the assistant's Generation:

[Start of Ground Truth Text]
{ground_truth}
[End of Ground Truth Text]

[Start of Assistant's Generation]
{generation}
[End of Assistant's Generation]
"""


import re
from abc import ABC, abstractmethod
@@ -83,7 +103,7 @@

class JudgmentExtractorHandler(ABC):
@abstractmethod
def create(self): ...

Check failure on line 106 (GitHub Actions / Lint Python / Lint): Function is missing a return type annotation

@property
@abstractmethod
@@ -113,7 +133,7 @@
def prompt(self) -> str: ...

@abstractmethod
def format_prompt(self, prompt_text, **kwargs: Any) -> str: ...

Check failure on line 136 (GitHub Actions / Lint Python / Lint): Function is missing a type annotation for one or more arguments

"""
Format the prompt text and additional arguments into a string suitable for input to the reward model.
@@ -126,7 +146,7 @@
"""

@abstractmethod
def extract(self, generation) -> float | str: ...

Check failure on line 149 (GitHub Actions / Lint Python / Lint): Function is missing a type annotation for one or more arguments

"""
Extract the final scalar reward score from the model's response.
@@ -145,7 +165,7 @@
"""

@abstractmethod
def aggregate(self, judgments) -> float | str: ...

Check failure on line 168 (GitHub Actions / Lint Python / Lint): Function is missing a type annotation for one or more arguments

"""
Aggregate multiple responses (judgments) from the reward model into a single value.
@@ -161,28 +181,28 @@


class GeneralVerifierExtractorHandler(JudgmentExtractorHandler):
def __init__(self):

Check failure on line 184 (GitHub Actions / Lint Python / Lint): Function is missing a return type annotation
pass

@override
def create(self):

Check failure on line 188 (GitHub Actions / Lint Python / Lint): Function is missing a return type annotation
return GeneralVerifierExtractor()

@property
@override
def name(self):

Check failure on line 193 (GitHub Actions / Lint Python / Lint): Function is missing a return type annotation
return "general_verifier_extractor"

@property
@override
def config_kls(self):

Check failure on line 198 (GitHub Actions / Lint Python / Lint): Function is missing a return type annotation
return None


class GeneralVerifierExtractor(JudgmentExtractor):
def __init__(self):

Check failure on line 203 (GitHub Actions / Lint Python / Lint): Function is missing a return type annotation
try:
from math_verify import parse

Check failure on line 205 (GitHub Actions / Lint Python / Lint): Cannot find implementation or library stub for module named "math_verify"
from math_verify.parser import (
ExprExtractionConfig,
LatexExtractionConfig,
@@ -248,6 +268,82 @@
return round(avg_score / len(judgments), 4)


class SelfAugmentingExtractorHandler(JudgmentExtractorHandler):
def __init__(self):
pass

@override
def create(self):
return SelfAugmentingExtractor()

@property
@override
def name(self):
return "self_augmenting_extractor"

@property
@override
def config_kls(self):
return None


class SelfAugmentingExtractor(JudgmentExtractor):
def __init__(
self,
):
pass

@override
def prompt(self):
return SELF_AUGMENTING_PROMPT


def remove_think_tags(self, rollout_text):
tag = "</think>"
count = rollout_text.count(tag)
if count == 1:
# Find the position after the tag and return everything after it
index = rollout_text.find(tag) + len(tag)
return rollout_text[index:]
else:
return "" # set rollout to empty string if it doesn't contain thought or has multiple

@override
def format_prompt(self, tokenizer, prompt_text, rollout_text, reference_answer, dp_gangs):
# if dp_gangs.rank == 0
# breakpoint()
# dp_gangs.root.barrier()

rollout_text = self.remove_think_tags(rollout_text)

content = self.prompt().format(ground_truth=reference_answer, generation=rollout_text)

# log.info(f"Judge prompt = {content}")
wrapped_text = [{"role": "user", "content": content}]
chat_str = tokenizer.apply_chat_template(
wrapped_text, tokenize=False, add_generation_prompt=True
)
return chat_str

@override
def extract(self, generation):
# pattern = r'\\boxed\{(-?\d+)\}'
pattern = r'\\boxed\{([01])\}'
match = re.search(pattern, generation)
if match:
score = float(match.group(1))
else:
score = 0.0
return score

@override
def aggregate(self, judgments):
avg_score = 0.0
for score in judgments:
avg_score += score

return round(avg_score / len(judgments), 4)
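
A short usage sketch of the new extractor; the judge outputs below are made up for illustration:

extractor = SelfAugmentingExtractor()

# extract() looks for a \boxed{0} or \boxed{1} verdict in the judge's output
# and falls back to 0.0 when no boxed score is found.
scores = [
    extractor.extract(r"The generation matches the ground truth. \boxed{1}"),
    extractor.extract("no boxed verdict here"),
]
print(extractor.aggregate(scores))  # 0.5

# remove_think_tags() strips everything up to and including </think>, so only
# the final answer is inserted into SELF_AUGMENTING_PROMPT and judged.
print(extractor.remove_think_tags("<think>scratch work</think>final answer"))
# -> "final answer"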

class J1PointwiseExtractorHandler(JudgmentExtractorHandler):
def __init__(self):
pass
@@ -268,7 +364,9 @@


class J1PointwiseExtractor(JudgmentExtractor):
def __init__(self):
def __init__(
self,
):
pass

@override