From 6095032973bbc349e27275c29cdc69a8b0c451a3 Mon Sep 17 00:00:00 2001 From: swarna Date: Tue, 22 Jul 2025 00:12:58 +0000 Subject: [PATCH 01/33] skywork and some qwrn changes --- src/fairseq2/models/qwen/_hg.py | 2 + src/fairseq2/recipes/lm/__init__.py | 7 + .../recipes/lm/_online_finetune/_rewards.py | 168 ++++++++++++++++++ src/fairseq2/setup/_po_finetune_units.py | 5 + 4 files changed, 182 insertions(+) diff --git a/src/fairseq2/models/qwen/_hg.py b/src/fairseq2/models/qwen/_hg.py index b0317fabd..7a70d15c8 100644 --- a/src/fairseq2/models/qwen/_hg.py +++ b/src/fairseq2/models/qwen/_hg.py @@ -76,6 +76,8 @@ def _convert_parameter(name: str, r"decoder\.layers\.([0-9]+)\.ffn.output_proj\.": r"model.layers.\1.mlp.down_proj.", r"decoder\.layers\.([0-9]+)\.ffn.inner_proj\.": r"model.layers.\1.mlp.up_proj.", r"decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"model.layers.\1.input_layernorm.", + r"decoder\.layers\.([0-9]+)\.self_attn\.q_norm\.": r"model.layers.\1.self_attn.q_norm.", + r"decoder\.layers\.([0-9]+)\.self_attn\.k_norm\.": r"model.layers.\1.self_attn.k_norm.", r"decoder\.layer_norm\.": r"model.norm.", r"decoder_frontend.embed\.": r"model.embed_tokens.", r"final_proj\.": r"lm_head.", diff --git a/src/fairseq2/recipes/lm/__init__.py b/src/fairseq2/recipes/lm/__init__.py index dcdea9892..c959664b5 100644 --- a/src/fairseq2/recipes/lm/__init__.py +++ b/src/fairseq2/recipes/lm/__init__.py @@ -147,6 +147,13 @@ GSM8kVerifierHandler as GSM8kVerifierHandler, ) +from fairseq2.recipes.lm._online_finetune._rewards import ( + SkyworkVerifier as SkyworkVerifier, +) +from fairseq2.recipes.lm._online_finetune._rewards import ( + SkyworkVerifierHandler as SkyworkVerifierHandler, +) + from fairseq2.recipes.lm._online_finetune._rewards import ( AtheneVerifier as AtheneVerifier, ) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py index a65a2a765..c0d661454 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py @@ -274,6 +274,174 @@ def prepare_preference_batch( ) return batch, is_bad_batch, reward_output + +class SkyworkVerifierHandler(VLLMOutputRewardHandler): + def __init__(self): + pass + + @override + def create(self, reward_model, reward_name, reward_config, gangs, context): + if reward_config.tokenizer is not None: + tokenizer = reward_config.tokenizer + else: + tokenizer = "Skywork/Skywork-Reward-V2-Llama-3.1-8B" + + return SkyworkVerifier( + gangs, + context, + reward_model, + reward_name=reward_name, + answer_key=reward_config.answer_key, + prompt_key=reward_config.prompt_key, + tokenizer=tokenizer, + ) + + @property + @override + def name(self): + return "skywork_verifier" + + @property + @override + def config_kls(self): + return None + +class SkyworkVerifier(VLLMOutputReward): + def __init__(self, gangs, context, reward_model, reward_name, answer_key, prompt_key, tokenizer): + self.answer_key = answer_key + self.prompt_key = prompt_key + self._gangs = gangs + self._context = context + self.reward_model = reward_model + self.reward_name = reward_name + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) + + def wrap_text(self, prompt_text, rollout_text): + wrapped_text = [ + {"role": "user", "content": prompt_text}, + {"role": "assistant", "content": rollout_text}, + ] + chat_str = self.tokenizer.apply_chat_template(wrapped_text, tokenize=False) + if self.tokenizer.bos_token is not None and chat_str.startswith(self.tokenizer.bos_token): + chat_str = chat_str[len(self.tokenizer.bos_token):] + + return chat_str + + @override + def process_rollouts( + self, vllm_outputs: List[RequestOutput], prompt_batch: PromptBatch + ): + vllm_inputs = [] + batch_text = [] + batch_tokens = [] + + if vllm_outputs is None: + vllm_outputs = [None] * len(prompt_batch.prompts) + + text_prompts = prompt_batch.meta_info.get(self.prompt_key) + for i, (i_batch_request_output, prompt_text) in enumerate( + zip(vllm_outputs, text_prompts) + ): + + rollouts_text = [] + rollouts_tokens = [] + for rollout_output in i_batch_request_output.outputs: + rollout_text = rollout_output.text + vllm_input = self.wrap_text(prompt_text, rollout_text) + vllm_inputs.append(vllm_input) + rollouts_text.append(rollout_output.text) + rollouts_tokens.append(rollout_output.token_ids) + + batch_text.append(rollouts_text) + batch_tokens.append(rollouts_tokens) + + batch_rewards = generate_rewards( + vllm_inputs, dp_gang=self._gangs.dp, vllm_model=self.reward_model + ) + + # reshape batch_rewards to [Batch, Rollouts] + B, R = len(batch_text), len(batch_text[0]) # batch size, rollouts + batch_rewards = [batch_rewards[i * R : (i + 1) * R] for i in range(B)] + + return {"text": batch_text, "tokens": batch_tokens, "rewards": batch_rewards} + + def prepare_preference_batch( + self, prompt_batch: PromptBatch, rollouts + ) -> PreferenceBatch: + + reward_output = self.process_rollouts(rollouts, prompt_batch) + + chosen_batch = [] + rejected_batch = [] + prompt_lens = [] + dummy_batch_ids = [] # keep posiitons of dummy pairs here + + # choosing first rollouts with reward 1 as chosen and 0 as rejected (sort of random given that we sample rollouts randomly) + for i_batch, (i_batch_rewards, i_batch_tokens) in enumerate( + zip(reward_output["rewards"], reward_output["tokens"]) + ): + chosen_rollout_position = i_batch_rewards.index(max(i_batch_rewards)) + rejected_rollout_position = i_batch_rewards.index(min(i_batch_rewards)) + + if chosen_rollout_position == rejected_rollout_position: + # cant form preference pair when we dont have such rollouts + # this will be dummy batch and we zero out loss + dummy_batch_ids.append(i_batch) + + chosen_rollout_tokens = list(i_batch_tokens[chosen_rollout_position]) + rejected_rollout_tokens = list(i_batch_tokens[rejected_rollout_position]) + prompt_tokens = prompt_batch.prompts[i_batch] + + chosen_tokens = prompt_tokens + chosen_rollout_tokens + chosen_batch.append(chosen_tokens) + + rejected_tokens = prompt_tokens + rejected_rollout_tokens + rejected_batch.append(rejected_tokens) + + prompt_lens.append(len(prompt_tokens)) + + filter_batch = lambda batch: [ + item for index, item in enumerate(batch) if index not in dummy_batch_ids + ] + + if len(dummy_batch_ids) == len(reward_output["tokens"]): + # entire batch does not have a valid preference pair + # we use it as dummy batch and zero the loss in the end + is_bad_batch = True + else: + # removing dummy pairs from the batch + chosen_batch = filter_batch(chosen_batch) + rejected_batch = filter_batch(rejected_batch) + prompt_lens = filter_batch(prompt_lens) + is_bad_batch = False + + prompt_lens = torch.tensor(prompt_lens) + + chosen_batch = [ + torch.tensor(sequence, device=self._gangs.dp.device) + for sequence in chosen_batch + ] + chosen_batch = collate_with_target_mask( + chosen_batch, prompt_lens, device=self._gangs.dp.device + ) + + rejected_batch = [ + torch.tensor(sequence, device=self._gangs.dp.device) + for sequence in rejected_batch + ] + rejected_batch = collate_with_target_mask( + rejected_batch, prompt_lens, device=self._gangs.dp.device + ) + + batch = PreferenceBatch( + chosen=chosen_batch, + rejected=rejected_batch, + reference_score_chosen=None, + reference_score_rejected=None, + ) + + return batch, is_bad_batch, reward_output + class AtheneVerifierHandler(VLLMOutputRewardHandler): diff --git a/src/fairseq2/setup/_po_finetune_units.py b/src/fairseq2/setup/_po_finetune_units.py index f41850db8..077e87951 100644 --- a/src/fairseq2/setup/_po_finetune_units.py +++ b/src/fairseq2/setup/_po_finetune_units.py @@ -20,6 +20,7 @@ OnlineFinetuneUnitHandler, GSM8kVerifierHandler, MathVerifyHandler, + SkyworkVerifierHandler, AtheneVerifierHandler, GenerativePointwiseVerifierHandler, GenerativePairwiseVerifierHandler, @@ -86,6 +87,10 @@ def _register_online_finetune_units(context: RuntimeContext) -> None: # GSM8kVerifier handler = GSM8kVerifierHandler() registry.register(handler.name, handler) + + # SkyworkVerifier + handler = SkyworkVerifierHandler() + registry.register(handler.name, handler) # AtheneVerifier handler = AtheneVerifierHandler() From f3d876a1a6c1272895274107b9588b8b6f863ea1 Mon Sep 17 00:00:00 2001 From: swarna Date: Tue, 22 Jul 2025 20:59:42 +0000 Subject: [PATCH 02/33] Removing think tokens --- src/fairseq2/recipes/lm/_online_finetune/_common.py | 11 +++++++++++ src/fairseq2/recipes/lm/_online_finetune/_grpo.py | 3 +++ 2 files changed, 14 insertions(+) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_common.py b/src/fairseq2/recipes/lm/_online_finetune/_common.py index 39d408232..20e9657d7 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_common.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_common.py @@ -12,6 +12,7 @@ from typing import List, cast import ray +import re import torch import torch.nn as nn from torch import Tensor @@ -417,6 +418,16 @@ def get_rollout_lengths(rollouts: List[SequenceData]): rollout_lengths.append(token_ids_len) return rollout_lengths +def strip_think_tokens(rollouts: List[SequenceData]): + for sample in rollouts: + for rollout in sample.outputs: + rollout_text = rollout.text + rollout.text = re.sub( + r".*?", "", rollout_text, flags=re.DOTALL + ).strip() + + return rollouts + class StatefulRolloutBag: """A stateful container for managing and reusing model rollouts across multiple micro-batches. diff --git a/src/fairseq2/recipes/lm/_online_finetune/_grpo.py b/src/fairseq2/recipes/lm/_online_finetune/_grpo.py index 2b19e9e7c..92e344833 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_grpo.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_grpo.py @@ -47,6 +47,7 @@ compute_reference_logps, collate_with_target_mask, update_std_reward, + strip_think_tokens ) from fairseq2.recipes.lm._online_finetune._handler import OnlineFinetuneUnitHandler from fairseq2.recipes.lm._online_finetune._remote_model import ( @@ -209,6 +210,7 @@ def validate_reward( ) if self._config.loss_config.log_rollouts: log_rollouts(prompt_batch, rollouts, "Valid") + rollouts = strip_think_tokens(rollouts) reward_output = self._reward.process_rollouts(rollouts, prompt_batch) log.info(f"Rewards: {reward_output['rewards']}") avg_reward = torch.tensor(reward_output["rewards"]).float().mean() @@ -266,6 +268,7 @@ def __call__( if self._config.loss_config.log_rollouts: log_rollouts(prompt_batch, rollouts, "Train") + rollouts = strip_think_tokens(rollouts) reward_output = self._reward.process_rollouts(rollouts, prompt_batch) self._rollout_bag.save(rollouts, reward_output) From 164458b8dc33570d5de34d93f00e18c1da9cdb7f Mon Sep 17 00:00:00 2001 From: swarna Date: Wed, 23 Jul 2025 20:06:36 +0000 Subject: [PATCH 03/33] Fixing GRMs --- .../lm/_online_finetune/_generative_judge.py | 34 +++++++++---------- .../recipes/lm/_online_finetune/_rewards.py | 16 +++++---- 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py b/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py index 492f1c70c..c7ebbdec2 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py @@ -81,7 +81,7 @@ class JudgmentExtractorHandler(ABC): @abstractmethod - def create(self): ... + def create(self, tokenizer): ... @property @abstractmethod @@ -163,8 +163,8 @@ def __init__(self): pass @override - def create(self): - return GeneralVerifierExtractor() + def create(self, tokenizer): + return GeneralVerifierExtractor(tokenizer) @property @override @@ -178,7 +178,7 @@ def config_kls(self): class GeneralVerifierExtractor(JudgmentExtractor): - def __init__(self): + def __init__(self, tokenizer): try: from math_verify import parse from math_verify.parser import ( @@ -251,8 +251,8 @@ def __init__(self): pass @override - def create(self): - return J1PointwiseExtractor() + def create(self, tokenizer): + return J1PointwiseExtractor(tokenizer) @property @override @@ -266,15 +266,15 @@ def config_kls(self): class J1PointwiseExtractor(JudgmentExtractor): - def __init__(self): - pass + def __init__(self, tokenizer): + self.tokenizer = tokenizer @override def prompt(self): return POINTWISE_J1_PROMPT @override - def format_prompt(self, prompt_text, rollout_text, reference_answer): + def format_prompt(self, prompt_text, rollout_text): content = self.prompt().format(instruction=prompt_text, response=rollout_text) wrapped_text = [{"role": "user", "content": content}] chat_str = self.tokenizer.apply_chat_template( @@ -305,8 +305,8 @@ def __init__(self): pass @override - def create(self): - return J1PairwiseScoreExtractor() + def create(self, tokenizer): + return J1PairwiseScoreExtractor(tokenizer) @property @override @@ -320,8 +320,8 @@ def config_kls(self): class J1PairwiseScoreExtractor(JudgmentExtractor): - def __init__(self): - pass + def __init__(self, tokenizer): + self.tokenizer = tokenizer @override def prompt(self): @@ -375,8 +375,8 @@ def __init__(self): pass @override - def create(self): - return J1PairwisePreferenceExtractor() + def create(self, tokenizer): + return J1PairwisePreferenceExtractor(tokenizer) @property @override @@ -390,8 +390,8 @@ def config_kls(self): class J1PairwisePreferenceExtractor(JudgmentExtractor): - def __init__(self): - pass + def __init__(self, tokenizer): + self.tokenizer = tokenizer @override def prompt(self): diff --git a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py index c0d661454..66967ba6a 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py @@ -695,7 +695,7 @@ def __init__( JudgmentExtractorHandler ) judgment_extractor_handler = judgment_extractor_registry.get(judgment_extractor) - self.judgment_extractor = judgment_extractor_handler.create() + self.judgment_extractor = judgment_extractor_handler.create(self.tokenizer) @override def process_rollouts( @@ -716,12 +716,16 @@ def process_rollouts( rollouts_text = [] rollouts_tokens = [] - i_reference_answer = reference_answers[i] for rollout_output in i_batch_request_output.outputs: rollout_text = rollout_output.text - vllm_input = self.judgment_extractor.format_prompt( - prompt_text, rollout_text, i_reference_answer - ) + if reference_answers is None: + vllm_input = self.judgment_extractor.format_prompt( + prompt_text, rollout_text + ) + else: + vllm_input = self.judgment_extractor.format_prompt( + prompt_text, rollout_text, reference_answers[i] + ) vllm_inputs.append(vllm_input) rollouts_text.append(rollout_output.text) rollouts_tokens.append(rollout_output.token_ids) @@ -887,7 +891,7 @@ def __init__( JudgmentExtractorHandler ) judgment_extractor_handler = judgment_extractor_registry.get(judgment_extractor) - self.judgment_extractor = judgment_extractor_handler.create() + self.judgment_extractor = judgment_extractor_handler.create(self.tokenizer) @override def process_rollouts( From 4fc9aea46de33f923961f5bacb5344d77d3b4b40 Mon Sep 17 00:00:00 2001 From: swarna Date: Wed, 23 Jul 2025 20:35:28 +0000 Subject: [PATCH 04/33] Black --- .../recipes/lm/_online_finetune/_common.py | 3 +- .../lm/_online_finetune/_generative_judge.py | 21 ++++++---- .../recipes/lm/_online_finetune/_grpo.py | 2 +- .../recipes/lm/_online_finetune/_handler.py | 9 +++-- .../lm/_online_finetune/_remote_model.py | 9 +++-- .../recipes/lm/_online_finetune/_rewards.py | 39 +++++++++++++------ 6 files changed, 57 insertions(+), 26 deletions(-) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_common.py b/src/fairseq2/recipes/lm/_online_finetune/_common.py index a8f079a97..5f00ed83a 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_common.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_common.py @@ -418,6 +418,7 @@ def get_rollout_lengths(rollouts: List[SequenceData]): rollout_lengths.append(token_ids_len) return rollout_lengths + def strip_think_tokens(rollouts: List[SequenceData]): for sample in rollouts: for rollout in sample.outputs: @@ -425,7 +426,7 @@ def strip_think_tokens(rollouts: List[SequenceData]): rollout.text = re.sub( r".*?", "", rollout_text, flags=re.DOTALL ).strip() - + return rollouts diff --git a/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py b/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py index c7ebbdec2..a67b661c5 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py @@ -81,15 +81,18 @@ class JudgmentExtractorHandler(ABC): @abstractmethod - def create(self, tokenizer): ... + def create(self, tokenizer): + ... @property @abstractmethod - def name(self) -> str: ... + def name(self) -> str: + ... @property @abstractmethod - def config_kls(self) -> type[object]: ... + def config_kls(self) -> type[object]: + ... """ @@ -108,10 +111,12 @@ class JudgmentExtractor(ABC): """ @abstractmethod - def prompt(self) -> str: ... + def prompt(self) -> str: + ... @abstractmethod - def format_prompt(self, prompt_text, **kwargs: Any) -> str: ... + def format_prompt(self, prompt_text, **kwargs: Any) -> str: + ... """ Format the prompt text and additional arguments into a string suitable for input to the reward model. @@ -124,7 +129,8 @@ def format_prompt(self, prompt_text, **kwargs: Any) -> str: ... """ @abstractmethod - def extract(self, generation) -> float | str: ... + def extract(self, generation) -> float | str: + ... """ Extract the final scalar reward score from the model's response. @@ -143,7 +149,8 @@ def extract(self, generation) -> float | str: ... """ @abstractmethod - def aggregate(self, judgments) -> float | str: ... + def aggregate(self, judgments) -> float | str: + ... """ Aggregate multiple responses (judgments) from the reward model into a single value. diff --git a/src/fairseq2/recipes/lm/_online_finetune/_grpo.py b/src/fairseq2/recipes/lm/_online_finetune/_grpo.py index 92e344833..fee95282c 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_grpo.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_grpo.py @@ -47,7 +47,7 @@ compute_reference_logps, collate_with_target_mask, update_std_reward, - strip_think_tokens + strip_think_tokens, ) from fairseq2.recipes.lm._online_finetune._handler import OnlineFinetuneUnitHandler from fairseq2.recipes.lm._online_finetune._remote_model import ( diff --git a/src/fairseq2/recipes/lm/_online_finetune/_handler.py b/src/fairseq2/recipes/lm/_online_finetune/_handler.py index d959d0cf7..8210d9d7b 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_handler.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_handler.py @@ -19,15 +19,18 @@ class OnlineFinetuneUnitHandler(ABC): @abstractmethod def create( self, model: Model, gangs: Gangs, recipe_config: object, vllm_actors: object - ) -> TrainUnit[SequenceBatch]: ... + ) -> TrainUnit[SequenceBatch]: + ... @property @abstractmethod - def name(self) -> str: ... + def name(self) -> str: + ... @property @abstractmethod - def config_kls(self) -> type[object]: ... + def config_kls(self) -> type[object]: + ... class UnknownOnlineFinetuneUnitError(Exception): diff --git a/src/fairseq2/recipes/lm/_online_finetune/_remote_model.py b/src/fairseq2/recipes/lm/_online_finetune/_remote_model.py index 0ef2f5fe5..22ef1ddf1 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_remote_model.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_remote_model.py @@ -565,15 +565,18 @@ class RemoteModelHandler(ABC): @abstractmethod def create( self, gangs: Gangs, unit_config: object - ) -> Union[RemoteVllmModel, RemoteHFModel]: ... + ) -> Union[RemoteVllmModel, RemoteHFModel]: + ... @property @abstractmethod - def name(self) -> str: ... + def name(self) -> str: + ... @property @abstractmethod - def config_kls(self) -> type[object]: ... + def config_kls(self) -> type[object]: + ... class RemoteRayModelHandler(RemoteModelHandler): diff --git a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py index 66967ba6a..e6f5493e1 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py @@ -50,23 +50,28 @@ class VLLMOutputRewardHandler(ABC): @abstractmethod def create( self, reward_model: Any, gangs: Gangs, reward_config: object - ) -> VLLMOutputReward: ... + ) -> VLLMOutputReward: + ... @property @abstractmethod - def name(self) -> str: ... + def name(self) -> str: + ... @property @abstractmethod - def config_kls(self) -> type[object]: ... + def config_kls(self) -> type[object]: + ... class VLLMOutputReward(ABC): @abstractmethod - def process_rollouts(self, vllm_outputs: list[RequestOutput]): ... + def process_rollouts(self, vllm_outputs: list[RequestOutput]): + ... @abstractmethod - def prepare_preference_batch(self, prompt_batch: PromptBatch, rollouts): ... + def prepare_preference_batch(self, prompt_batch: PromptBatch, rollouts): + ... class GSM8kVerifierHandler(VLLMOutputRewardHandler): @@ -274,7 +279,8 @@ def prepare_preference_batch( ) return batch, is_bad_batch, reward_output - + + class SkyworkVerifierHandler(VLLMOutputRewardHandler): def __init__(self): pass @@ -305,9 +311,19 @@ def name(self): @override def config_kls(self): return None - + + class SkyworkVerifier(VLLMOutputReward): - def __init__(self, gangs, context, reward_model, reward_name, answer_key, prompt_key, tokenizer): + def __init__( + self, + gangs, + context, + reward_model, + reward_name, + answer_key, + prompt_key, + tokenizer, + ): self.answer_key = answer_key self.prompt_key = prompt_key self._gangs = gangs @@ -322,8 +338,10 @@ def wrap_text(self, prompt_text, rollout_text): {"role": "assistant", "content": rollout_text}, ] chat_str = self.tokenizer.apply_chat_template(wrapped_text, tokenize=False) - if self.tokenizer.bos_token is not None and chat_str.startswith(self.tokenizer.bos_token): - chat_str = chat_str[len(self.tokenizer.bos_token):] + if self.tokenizer.bos_token is not None and chat_str.startswith( + self.tokenizer.bos_token + ): + chat_str = chat_str[len(self.tokenizer.bos_token) :] return chat_str @@ -443,7 +461,6 @@ def prepare_preference_batch( return batch, is_bad_batch, reward_output - class AtheneVerifierHandler(VLLMOutputRewardHandler): def __init__(self): pass From f829d82418cfc70adeeab687506c70e8c7b20165 Mon Sep 17 00:00:00 2001 From: swarna Date: Tue, 29 Jul 2025 01:16:06 +0000 Subject: [PATCH 05/33] Import issue --- src/fairseq2/setup/_po_finetune_units.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/fairseq2/setup/_po_finetune_units.py b/src/fairseq2/setup/_po_finetune_units.py index 101c318a2..c16ab5414 100644 --- a/src/fairseq2/setup/_po_finetune_units.py +++ b/src/fairseq2/setup/_po_finetune_units.py @@ -11,6 +11,7 @@ from fairseq2.context import RuntimeContext from fairseq2.recipes.lm import ( # GroupDpoFinetuneUnitHandler, AtheneVerifierHandler, + SkyworkVerifierHandler, CpoFinetuneUnitHandler, DpoFinetuneUnitHandler, GeneralVerifierExtractorHandler, From 9392f0d00cfc6e4226113275d71cb845434a5340 Mon Sep 17 00:00:00 2001 From: chenxwh user Date: Wed, 6 Aug 2025 18:26:46 +0000 Subject: [PATCH 06/33] add missing sw import --- src/fairseq2/recipes/lm/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/fairseq2/recipes/lm/__init__.py b/src/fairseq2/recipes/lm/__init__.py index 907c8a312..c5d2eee26 100644 --- a/src/fairseq2/recipes/lm/__init__.py +++ b/src/fairseq2/recipes/lm/__init__.py @@ -87,6 +87,12 @@ from fairseq2.recipes.lm._online_finetune._remote_model import ( RemoteModelHandler as RemoteModelHandler, ) +from fairseq2.recipes.lm._online_finetune._rewards import ( + SkyworkVerifier as SkyworkVerifier, +) +from fairseq2.recipes.lm._online_finetune._rewards import ( + SkyworkVerifierHandler as SkyworkVerifierHandler, +) from fairseq2.recipes.lm._online_finetune._rewards import ( AtheneVerifier as AtheneVerifier, ) From 55dc622e5d6d89c2d2d10b68fef44afb1ce3cf92 Mon Sep 17 00:00:00 2001 From: swarna Date: Thu, 7 Aug 2025 00:00:25 +0000 Subject: [PATCH 07/33] Different configs for pairwise GRM --- .../recipes/lm/_online_finetune/_rewards.py | 148 +++++++++++++----- 1 file changed, 113 insertions(+), 35 deletions(-) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py index 892574180..1c52ffac9 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py @@ -10,6 +10,8 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Any +from fairseq2.logging import log +import random import torch from transformers import AutoTokenizer @@ -347,7 +349,7 @@ def wrap_text(self, prompt_text, rollout_text): @override def process_rollouts( - self, vllm_outputs: List[RequestOutput], prompt_batch: PromptBatch + self, vllm_outputs: list[RequestOutput], prompt_batch: PromptBatch ): vllm_inputs = [] batch_text = [] @@ -753,6 +755,8 @@ def process_rollouts( batch_judgments = generate_rewards_generative( vllm_inputs, dp_gang=self._gangs.dp, vllm_model=self.reward_model ) + + log.info(f"Sample judgment: {batch_judgments[0].outputs[0].text}") batch_rewards = [] for per_rollout_judgments in batch_judgments: @@ -909,6 +913,98 @@ def __init__( ) judgment_extractor_handler = judgment_extractor_registry.get(judgment_extractor) self.judgment_extractor = judgment_extractor_handler.create(self.tokenizer) + + def all_pairs(self, prompt_text, i_batch_request_output, vllm_inputs, prompt_pairwise_indices): + for a in range(len(i_batch_request_output.outputs)): + for b in range(len(i_batch_request_output.outputs)): + if a != b: + rollout_A_text = i_batch_request_output.outputs[a].text + rollout_B_text = i_batch_request_output.outputs[b].text + vllm_input = self.judgment_extractor.format_prompt( + prompt_text, rollout_A_text, rollout_B_text + ) + vllm_inputs.append(vllm_input) + prompt_pairwise_indices.append((a, b)) + + return vllm_inputs, prompt_pairwise_indices + + def pairs_with_reference(self, prompt_text, i_batch_request_output, vllm_inputs, prompt_pairwise_indices): + reference_idx = random.randint(0, len(i_batch_request_output.outputs)-1) + reference_rollout = i_batch_request_output.outputs[reference_idx].text + for a in range(len(i_batch_request_output.outputs)): + if a != reference_idx: + rollout_A_text = i_batch_request_output.outputs[a].text + rollout_B_text = reference_rollout + + to_swap = random.choice([True, False]) + if to_swap: + rollout_A_text, rollout_B_text = rollout_B_text, rollout_A_text + prompt_pairwise_indices.append((reference_idx, a)) + else: + prompt_pairwise_indices.append((a, reference_idx)) + + vllm_input = self.judgment_extractor.format_prompt( + prompt_text, rollout_A_text, rollout_B_text + ) + vllm_inputs.append(vllm_input) + + return vllm_inputs, prompt_pairwise_indices + + def random_pairs(self, prompt_text, i_batch_request_output, vllm_inputs, prompt_pairwise_indices): + all_pairs = [(i, j) for i in range(len(i_batch_request_output.outputs)) for j in range(len(i_batch_request_output.outputs)) if i != j] + random_pairs = random.sample(all_pairs, len(i_batch_request_output.outputs)) + + for a in range(len(i_batch_request_output.outputs)): + for b in range(len(i_batch_request_output.outputs)): + if (a, b) in random_pairs: + rollout_A_text = i_batch_request_output.outputs[a].text + rollout_B_text = i_batch_request_output.outputs[b].text + vllm_input = self.judgment_extractor.format_prompt( + prompt_text, rollout_A_text, rollout_B_text + ) + vllm_inputs.append(vllm_input) + prompt_pairwise_indices.append((a, b)) + + return vllm_inputs, prompt_pairwise_indices + + def convert_pairwise_rewards_to_pointwise(self, batch_pairwise_rewards, batch_pairwise_indices, batch_text, batch_type): + B, R = len(batch_text), len(batch_text[0]) # batch size, rollouts + batch_pointwise_rewards = [] + + for i in range(B): + # Extract the pairwise rewards for each input + if batch_type == "all_pairs": + prompt_pairwise_rewards = batch_pairwise_rewards[ + i * R * (R - 1) : (i + 1) * R * (R - 1) + ] + elif batch_type == "reference": + prompt_pairwise_rewards = batch_pairwise_rewards[ + i * (R - 1) : (i + 1) * (R - 1) + ] + elif batch_type == "random_pairs": + prompt_pairwise_rewards = batch_pairwise_rewards[ + i * R : (i + 1) * R + ] + + # Sum the rewards for each rollout and count how many times each rollout appears in pairwise judgments + prompt_pairwise_indices = batch_pairwise_indices[i] + prompt_rewards = [0.0] * R + counts = [0] * R + for index, rewards in zip(prompt_pairwise_indices, prompt_pairwise_rewards): + prompt_rewards[index[0]] += rewards[0] + prompt_rewards[index[1]] += rewards[1] + counts[index[0]] += 1 + counts[index[1]] += 1 + + # Compute average pointwise rewards + avg_prompt_rewards = [] + for prompt_reward, count in zip(prompt_rewards, counts): + if count > 0: + avg_prompt_rewards.append(round(prompt_reward / count, 4)) + + batch_pointwise_rewards.append(avg_prompt_rewards) + + return batch_pointwise_rewards @override def process_rollouts( @@ -918,6 +1014,8 @@ def process_rollouts( batch_text = [] batch_tokens = [] batch_pairwise_indices = [] + + batch_type = "random_pairs" if vllm_outputs is None: vllm_outputs = [None] * len(prompt_batch.prompts) @@ -937,16 +1035,12 @@ def process_rollouts( batch_tokens.append(rollouts_tokens) prompt_pairwise_indices = [] - for a in range(len(i_batch_request_output.outputs)): - for b in range(len(i_batch_request_output.outputs)): - if a != b: - rollout_A_text = i_batch_request_output.outputs[a].text - rollout_B_text = i_batch_request_output.outputs[b].text - vllm_input = self.judgment_extractor.format_prompt( - prompt_text, rollout_A_text, rollout_B_text - ) - vllm_inputs.append(vllm_input) - prompt_pairwise_indices.append((a, b)) + if batch_type == "all_pairs": + vllm_inputs, prompt_pairwise_indices = self.all_pairs(prompt_text, i_batch_request_output, vllm_inputs, prompt_pairwise_indices) + elif batch_type == "reference": + vllm_inputs, prompt_pairwise_indices = self.pairs_with_reference(prompt_text, i_batch_request_output, vllm_inputs, prompt_pairwise_indices) + elif batch_type == "random_pairs": + vllm_inputs, prompt_pairwise_indices = self.random_pairs(prompt_text, i_batch_request_output, vllm_inputs, prompt_pairwise_indices) batch_pairwise_indices.append(prompt_pairwise_indices) @@ -955,6 +1049,10 @@ def process_rollouts( dp_gang=self._gangs.dp, vllm_model=self.reward_model, ) + + log.info(f"Here: {len(batch_pairwise_judgments)}") + log.info(f"Here: {len(batch_pairwise_judgments[0].outputs)}") + log.info(f"Sample judgment: {batch_pairwise_judgments[0].outputs[0].text}") batch_pairwise_rewards = [] for per_rollout_judgments in batch_pairwise_judgments: @@ -966,30 +1064,10 @@ def process_rollouts( self.judgment_extractor.aggregate(per_rollout_rewards) ) - B, R = len(batch_text), len(batch_text[0]) # batch size, rollouts - - # Logic to convert pairwise scores into pointwise rewards - # Can be done differently too - batch_rewards = [] - for i in range(B): - prompt_pairwise_rewards = batch_pairwise_rewards[ - i * R * (R - 1) : (i + 1) * R * (R - 1) - ] - prompt_pairwise_indices = batch_pairwise_indices[i] - prompt_rewards = [0.0] * R - for index, rewards in zip(prompt_pairwise_indices, prompt_pairwise_rewards): - prompt_rewards[index[0]] += rewards[0] - prompt_rewards[index[1]] += rewards[1] - - # Average score over 2*(R-1) pairwise comparisons - if (R - 1) > 0: - prompt_rewards = [ - round(prompt_reward / (2 * (R - 1)), 4) - for prompt_reward in prompt_rewards - ] - - batch_rewards.append(prompt_rewards) - + batch_rewards = self.convert_pairwise_rewards_to_pointwise(batch_pairwise_rewards, batch_pairwise_indices, batch_text, batch_type) + + log.info(f"Batch Rewards: {batch_rewards}") + return {"text": batch_text, "tokens": batch_tokens, "rewards": batch_rewards} def prepare_preference_batch( From ee17161477ee07cb402520703e943267e897a9d1 Mon Sep 17 00:00:00 2001 From: swarna Date: Thu, 7 Aug 2025 00:03:04 +0000 Subject: [PATCH 08/33] Minor fix and more logging --- src/fairseq2/recipes/lm/_online_finetune/_grpo.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_grpo.py b/src/fairseq2/recipes/lm/_online_finetune/_grpo.py index 40c606a9c..3a49ab6bc 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_grpo.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_grpo.py @@ -190,15 +190,16 @@ def validate_reward( ) -> tuple[Tensor, int]: if self._gangs.dp.rank == 0: policy_sampling_params = copy(self._vllm_model.sampling_params) - # For a pairwise RM, need to sample at least two judgments - policy_sampling_params.n = ( - 2 if self._reward.reward_name == "generative_pairwise_verifier" else 1 - ) for ( k, v, ) in self._config.loss_config.validation_vllm_sampling_params.items(): policy_sampling_params.__setattr__(k, v) + + # For a pairwise RM, need to sample at least two judgments + policy_sampling_params.n = ( + 2 if self._reward.reward_name == "generative_pairwise_verifier" else 1 + ) else: policy_sampling_params = None rollouts = generate_rollouts( @@ -210,6 +211,8 @@ def validate_reward( if self._config.loss_config.log_rollouts: log_rollouts(prompt_batch, rollouts, "Valid") rollouts = strip_think_tokens(rollouts) + log.info(f"Sampling params: {len(rollouts[0].outputs)}") + log.info(f"Rollouts: {len(rollouts[0].outputs)}") reward_output = self._reward.process_rollouts(rollouts, prompt_batch) log.info(f"Rewards: {reward_output['rewards']}") avg_reward = torch.tensor(reward_output["rewards"]).float().mean() @@ -264,8 +267,8 @@ def __call__( dp_gang=self._gangs.dp, vllm_model=self._vllm_model, ) - if self._config.loss_config.log_rollouts: - log_rollouts(prompt_batch, rollouts, "Train") + # if self._config.loss_config.log_rollouts: + # log_rollouts(prompt_batch, rollouts, "Train") rollouts = strip_think_tokens(rollouts) reward_output = self._reward.process_rollouts(rollouts, prompt_batch) From 18ff4c44bfe29d0cfa2a17dec92b264580bf46a7 Mon Sep 17 00:00:00 2001 From: swarna Date: Thu, 7 Aug 2025 00:05:17 +0000 Subject: [PATCH 09/33] Online dpo: pairwise GRM should sample at least two rollouts --- src/fairseq2/recipes/lm/_online_finetune/_online_dpo.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_online_dpo.py b/src/fairseq2/recipes/lm/_online_finetune/_online_dpo.py index 8f01c2ef7..f53bbda9d 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_online_dpo.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_online_dpo.py @@ -142,12 +142,16 @@ def validate_reward( ) -> tuple[Tensor, int]: if self._gangs.dp.rank == 0: policy_sampling_params = copy(self._vllm_model.sampling_params) - policy_sampling_params.n = 1 for ( k, v, ) in self._config.loss_config.validation_vllm_sampling_params.items(): policy_sampling_params.__setattr__(k, v) + + # For a pairwise RM, need to sample at least two judgments + policy_sampling_params.n = ( + 2 if self._reward.reward_name == "generative_pairwise_verifier" else 1 + ) else: policy_sampling_params = None rollouts = generate_rollouts( From 585b744f55084d7ddc19dd7bce43154c0cc884a8 Mon Sep 17 00:00:00 2001 From: swarna Date: Thu, 7 Aug 2025 01:38:37 +0000 Subject: [PATCH 10/33] zero reward for rollouts not involved in pairwise judgments --- src/fairseq2/recipes/lm/_online_finetune/_rewards.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py index 1c52ffac9..17eeb42e0 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py @@ -997,7 +997,7 @@ def convert_pairwise_rewards_to_pointwise(self, batch_pairwise_rewards, batch_pa counts[index[1]] += 1 # Compute average pointwise rewards - avg_prompt_rewards = [] + avg_prompt_rewards = [0.0] * R for prompt_reward, count in zip(prompt_rewards, counts): if count > 0: avg_prompt_rewards.append(round(prompt_reward / count, 4)) From 510bdf2c15c67813cc7d2a1d3a49b168caf9b366 Mon Sep 17 00:00:00 2001 From: swarna Date: Thu, 7 Aug 2025 16:10:10 +0000 Subject: [PATCH 11/33] simplifying --- src/fairseq2/recipes/lm/_online_finetune/_rewards.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py index 17eeb42e0..bdf93084d 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py @@ -998,9 +998,9 @@ def convert_pairwise_rewards_to_pointwise(self, batch_pairwise_rewards, batch_pa # Compute average pointwise rewards avg_prompt_rewards = [0.0] * R - for prompt_reward, count in zip(prompt_rewards, counts): - if count > 0: - avg_prompt_rewards.append(round(prompt_reward / count, 4)) + for i in range(len(prompt_rewards)): + if counts[i] > 0: + avg_prompt_rewards[i] = round(prompt_rewards[i] / counts[i], 4) batch_pointwise_rewards.append(avg_prompt_rewards) From 38aaf5368b90b3d3f0509def1faa83f482a6bf19 Mon Sep 17 00:00:00 2001 From: chenxwh user Date: Sat, 9 Aug 2025 22:38:13 +0000 Subject: [PATCH 12/33] SequenceBatch seq_lens type ensure to be a list --- src/fairseq2/recipes/lm/_online_finetune/_common.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_common.py b/src/fairseq2/recipes/lm/_online_finetune/_common.py index 7db60cc6d..307abcc8e 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_common.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_common.py @@ -94,9 +94,13 @@ def collate_with_target_mask( seq_data = cast(SequenceData, collater(to_collate)) + seq_lens = seq_data["seqs"]["seq_lens"] + assert isinstance(seq_lens, Tensor) or isinstance(seq_lens, list) + if isinstance(seq_lens, Tensor): + seq_lens = seq_lens.tolist() batch = SequenceBatch( seq_data["seqs"]["seqs"], - seq_data["seqs"]["seq_lens"], + seq_lens, target_mask=seq_data["target_loss_mask"]["seqs"], ) batch.to(device) From a6ab8b0cb461abe0ce0ec04eb36dc22e0deee2d4 Mon Sep 17 00:00:00 2001 From: chenxwh user Date: Wed, 13 Aug 2025 17:36:45 +0000 Subject: [PATCH 13/33] add pairwsie J1 with reference answer --- src/fairseq2/recipes/lm/__init__.py | 6 ++ .../lm/_online_finetune/_generative_judge.py | 72 ++++++++++++++++++- .../recipes/lm/_online_finetune/_rewards.py | 25 ++++--- src/fairseq2/setup/_po_finetune_units.py | 4 ++ 4 files changed, 95 insertions(+), 12 deletions(-) diff --git a/src/fairseq2/recipes/lm/__init__.py b/src/fairseq2/recipes/lm/__init__.py index c5d2eee26..a761c9ed9 100644 --- a/src/fairseq2/recipes/lm/__init__.py +++ b/src/fairseq2/recipes/lm/__init__.py @@ -48,6 +48,12 @@ from fairseq2.recipes.lm._online_finetune._generative_judge import ( J1PairwiseScoreExtractorHandler as J1PairwiseScoreExtractorHandler, ) +from fairseq2.recipes.lm._online_finetune._generative_judge import ( + J1PairwiseScoreWithRefAnswerExtractor as J1PairwiseScoreWithRefAnswerExtractor, +) +from fairseq2.recipes.lm._online_finetune._generative_judge import ( + J1PairwiseScoreWithRefAnswerExtractorHandler as J1PairwiseScoreWithRefAnswerExtractorHandler, +) from fairseq2.recipes.lm._online_finetune._generative_judge import ( J1PointwiseExtractor as J1PointwiseExtractor, ) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py b/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py index 5c2d70198..56b6738f7 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py @@ -71,6 +71,35 @@ [The End of Assistant B's Answer] """ +PAIRWISE_WITH_SCORES_J1_PROMPT_WITH_REF_ANSWER = """ +You are given a user question, a reference answer, and two responses from two AI assistants. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. + +First, think about your evaluation process and provide your reasoning within and tags. This could include your evaluation criteria for a high-quality response to this specific user question, an analysis of the reference answer, a detailed comparison of the two responses, etc. Be explicit in your thought process, referencing your criteria and explaining how each response aligns with or deviates from them. + +Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. + +Finally, assign the assistant's response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision, with a higher score indicating a higher-quality response that better satisfies the criteria. Enclose the scores within the tags , and . + +Format your output like this: + your_thinking_process + your_score_a your_score_b + +Below are the user's question, reference answer and the two responses: + +[User Question] +{instruction} + +[Reference Answer] +{ref_answer} + +[The Start of Assistant A's Answer] +{response_A} +[The End of Assistant A's Answer] + +[The Start of Assistant B's Answer] +{response_B} +[The End of Assistant B's Answer] +""" import re from abc import ABC, abstractmethod @@ -337,7 +366,7 @@ def prompt(self): return PAIRWISE_WITH_SCORES_J1_PROMPT @override - def format_prompt(self, prompt_text, rollout_A_text, rollout_B_text): + def format_prompt(self, prompt_text, rollout_A_text, rollout_B_text, ref_answer=None): content = self.prompt().format( instruction=prompt_text, response_A=rollout_A_text, @@ -379,6 +408,47 @@ def aggregate(self, judgments): ) +class J1PairwiseScoreWithRefAnswerExtractorHandler(JudgmentExtractorHandler): + def __init__(self): + pass + + @override + def create(self, tokenizer): + return J1PairwiseScoreWithRefAnswerExtractor(tokenizer) + + @property + @override + def name(self): + return "j1_pairwise_score_extractor_with_ref_answer" + + @property + @override + def config_kls(self): + return None + + + +class J1PairwiseScoreWithRefAnswerExtractor(J1PairwiseScoreExtractor): + @override + def prompt(self): + return PAIRWISE_WITH_SCORES_J1_PROMPT_WITH_REF_ANSWER + + @override + def format_prompt(self, prompt_text, rollout_A_text, rollout_B_text, ref_answer): + assert ref_answer is not None, "Reference answer must be provided" + content = self.prompt().format( + instruction=prompt_text, + ref_answer=ref_answer, + response_A=rollout_A_text, + response_B=rollout_B_text, + ) + wrapped_text = [{"role": "user", "content": content}] + chat_str = self.tokenizer.apply_chat_template( + wrapped_text, tokenize=False, add_generation_prompt=True + ) + return chat_str + + class J1PairwisePreferenceExtractorHandler(JudgmentExtractorHandler): def __init__(self): pass diff --git a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py index bdf93084d..3e5e6e52d 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py @@ -914,21 +914,21 @@ def __init__( judgment_extractor_handler = judgment_extractor_registry.get(judgment_extractor) self.judgment_extractor = judgment_extractor_handler.create(self.tokenizer) - def all_pairs(self, prompt_text, i_batch_request_output, vllm_inputs, prompt_pairwise_indices): + def all_pairs(self, prompt_text, i_batch_request_output, vllm_inputs, prompt_pairwise_indices, reference_answer=None): for a in range(len(i_batch_request_output.outputs)): for b in range(len(i_batch_request_output.outputs)): if a != b: rollout_A_text = i_batch_request_output.outputs[a].text rollout_B_text = i_batch_request_output.outputs[b].text vllm_input = self.judgment_extractor.format_prompt( - prompt_text, rollout_A_text, rollout_B_text + prompt_text, rollout_A_text, rollout_B_text, reference_answer ) vllm_inputs.append(vllm_input) prompt_pairwise_indices.append((a, b)) return vllm_inputs, prompt_pairwise_indices - def pairs_with_reference(self, prompt_text, i_batch_request_output, vllm_inputs, prompt_pairwise_indices): + def pairs_with_reference(self, prompt_text, i_batch_request_output, vllm_inputs, prompt_pairwise_indices, reference_answer=None): reference_idx = random.randint(0, len(i_batch_request_output.outputs)-1) reference_rollout = i_batch_request_output.outputs[reference_idx].text for a in range(len(i_batch_request_output.outputs)): @@ -941,16 +941,15 @@ def pairs_with_reference(self, prompt_text, i_batch_request_output, vllm_inputs, rollout_A_text, rollout_B_text = rollout_B_text, rollout_A_text prompt_pairwise_indices.append((reference_idx, a)) else: - prompt_pairwise_indices.append((a, reference_idx)) - + prompt_pairwise_indices.append((a, reference_idx)) vllm_input = self.judgment_extractor.format_prompt( - prompt_text, rollout_A_text, rollout_B_text + prompt_text, rollout_A_text, rollout_B_text, reference_answer ) vllm_inputs.append(vllm_input) return vllm_inputs, prompt_pairwise_indices - def random_pairs(self, prompt_text, i_batch_request_output, vllm_inputs, prompt_pairwise_indices): + def random_pairs(self, prompt_text, i_batch_request_output, vllm_inputs, prompt_pairwise_indices, reference_answer=None): all_pairs = [(i, j) for i in range(len(i_batch_request_output.outputs)) for j in range(len(i_batch_request_output.outputs)) if i != j] random_pairs = random.sample(all_pairs, len(i_batch_request_output.outputs)) @@ -960,7 +959,7 @@ def random_pairs(self, prompt_text, i_batch_request_output, vllm_inputs, prompt_ rollout_A_text = i_batch_request_output.outputs[a].text rollout_B_text = i_batch_request_output.outputs[b].text vllm_input = self.judgment_extractor.format_prompt( - prompt_text, rollout_A_text, rollout_B_text + prompt_text, rollout_A_text, rollout_B_text, reference_answer ) vllm_inputs.append(vllm_input) prompt_pairwise_indices.append((a, b)) @@ -1021,6 +1020,10 @@ def process_rollouts( vllm_outputs = [None] * len(prompt_batch.prompts) text_prompts = prompt_batch.meta_info.get(self.prompt_key) + try: + reference_answers = prompt_batch.meta_info.get(self.answer_key) + except: + reference_answers = [None] * len(prompt_batch.prompts) for i, (i_batch_request_output, prompt_text) in enumerate( zip(vllm_outputs, text_prompts) ): @@ -1036,11 +1039,11 @@ def process_rollouts( prompt_pairwise_indices = [] if batch_type == "all_pairs": - vllm_inputs, prompt_pairwise_indices = self.all_pairs(prompt_text, i_batch_request_output, vllm_inputs, prompt_pairwise_indices) + vllm_inputs, prompt_pairwise_indices = self.all_pairs(prompt_text, i_batch_request_output, vllm_inputs, prompt_pairwise_indices, reference_answers[i]) elif batch_type == "reference": - vllm_inputs, prompt_pairwise_indices = self.pairs_with_reference(prompt_text, i_batch_request_output, vllm_inputs, prompt_pairwise_indices) + vllm_inputs, prompt_pairwise_indices = self.pairs_with_reference(prompt_text, i_batch_request_output, vllm_inputs, prompt_pairwise_indices, reference_answers[i]) elif batch_type == "random_pairs": - vllm_inputs, prompt_pairwise_indices = self.random_pairs(prompt_text, i_batch_request_output, vllm_inputs, prompt_pairwise_indices) + vllm_inputs, prompt_pairwise_indices = self.random_pairs(prompt_text, i_batch_request_output, vllm_inputs, prompt_pairwise_indices, reference_answers[i]) batch_pairwise_indices.append(prompt_pairwise_indices) diff --git a/src/fairseq2/setup/_po_finetune_units.py b/src/fairseq2/setup/_po_finetune_units.py index c16ab5414..c38429966 100644 --- a/src/fairseq2/setup/_po_finetune_units.py +++ b/src/fairseq2/setup/_po_finetune_units.py @@ -20,6 +20,7 @@ GrpoFinetuneUnitHandler, GSM8kVerifierHandler, J1PairwiseScoreExtractorHandler, + J1PairwiseScoreWithRefAnswerExtractorHandler, J1PointwiseExtractorHandler, JudgmentExtractorHandler, MathVerifyHandler, @@ -127,5 +128,8 @@ def _register_online_finetune_units(context: RuntimeContext) -> None: handler = J1PairwiseScoreExtractorHandler() registry.register(handler.name, handler) + handler = J1PairwiseScoreWithRefAnswerExtractorHandler() + registry.register(handler.name, handler) + handler = GeneralVerifierExtractorHandler() registry.register(handler.name, handler) From 2004533b28300ac3bdf391b1bf792100adc37b22 Mon Sep 17 00:00:00 2001 From: chenxwh user Date: Mon, 18 Aug 2025 09:27:30 +0000 Subject: [PATCH 14/33] fix None ref answer --- src/fairseq2/recipes/lm/_online_finetune/_rewards.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py index 3e5e6e52d..44bf787b0 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py @@ -1014,15 +1014,14 @@ def process_rollouts( batch_tokens = [] batch_pairwise_indices = [] - batch_type = "random_pairs" + batch_type = "reference" if vllm_outputs is None: vllm_outputs = [None] * len(prompt_batch.prompts) text_prompts = prompt_batch.meta_info.get(self.prompt_key) - try: - reference_answers = prompt_batch.meta_info.get(self.answer_key) - except: + reference_answers = prompt_batch.meta_info.get(self.answer_key) + if reference_answers is None: reference_answers = [None] * len(prompt_batch.prompts) for i, (i_batch_request_output, prompt_text) in enumerate( zip(vllm_outputs, text_prompts) From bfc255b0a23aec43eea903348803945f14b261ab Mon Sep 17 00:00:00 2001 From: swarna Date: Tue, 19 Aug 2025 21:35:17 +0000 Subject: [PATCH 15/33] Pairwise with pivot changes --- src/fairseq2/recipes/lm/__init__.py | 6 + .../recipes/lm/_online_finetune/_common.py | 13 ++ .../lm/_online_finetune/_generative_judge.py | 58 +++++++-- .../recipes/lm/_online_finetune/_rewards.py | 111 ++++++++++++++---- 4 files changed, 159 insertions(+), 29 deletions(-) diff --git a/src/fairseq2/recipes/lm/__init__.py b/src/fairseq2/recipes/lm/__init__.py index c5d2eee26..41af03bf7 100644 --- a/src/fairseq2/recipes/lm/__init__.py +++ b/src/fairseq2/recipes/lm/__init__.py @@ -99,6 +99,12 @@ from fairseq2.recipes.lm._online_finetune._rewards import ( AtheneVerifierHandler as AtheneVerifierHandler, ) +from fairseq2.recipes.lm._online_finetune._rewards import ( + SkyworkVerifier as SkyworkVerifier, +) +from fairseq2.recipes.lm._online_finetune._rewards import ( + SkyworkVerifierHandler as SkyworkVerifierHandler, +) from fairseq2.recipes.lm._online_finetune._rewards import ( GenerativePairwiseVerifier as GenerativePairwiseVerifier, ) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_common.py b/src/fairseq2/recipes/lm/_online_finetune/_common.py index 7db60cc6d..5bed1fe88 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_common.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_common.py @@ -418,12 +418,25 @@ def get_rollout_lengths(rollouts: List[SequenceData]): def strip_think_tokens(rollouts: List[SequenceData]): + count_stripped, count_not_stripped, total_count, think_present = 0, 0, 0, 0 for sample in rollouts: for rollout in sample.outputs: rollout_text = rollout.text + if "" in rollout_text: + think_present += 1 + if rollout.finish_reason == "length": + count_stripped += 1 + if rollout.finish_reason == "stop": + count_not_stripped += 1 + total_count +=1 rollout.text = re.sub( r".*?", "", rollout_text, flags=re.DOTALL ).strip() + + log.info(f"Total count: {total_count}") + log.info(f"Think present: {think_present}") + log.info(f"Count stripped: {count_stripped/total_count}") + log.info(f"Count not stripped: {count_not_stripped/total_count}") return rollouts diff --git a/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py b/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py index 5c2d70198..66b0be0c1 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py @@ -1,9 +1,28 @@ +# POINTWISE_J1_PROMPT = """ +# You are given a user question and a response from an AI assistant. Your task is to act as an impartial judge and evaluate how well the response fulfills the user's instructions. You will be shown multiple responses to the same prompt, but only one at a time. Evaluate each response independently. + +# Think carefully about how to assess the quality of the response, and enclose your reasoning within and tags. Your reasoning should include your evaluation criteria, a clear understanding of what an ideal response would look like for this particular question, and a concrete example of such an ideal or reference answer if possible. Then compare the assistant's response to your ideal or reference answer, explaining how it aligns with or deviates from your expectations. Be specific and avoid vague or overly general judgments. Remain as objective as possible. + +# Finally, assign the assistant's response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision. A higher score should indicate a higher-quality response. Enclose the score within and tags. + +# Format your output like this: +# your_thinking_process +# your_score + +# Below are the user's question and the assistant's response: + +# [User Question] +# {instruction} + +# [The Start of the Assistant's Answer] +# {response} +# [The End of the Assistant's Answer] +# """ + POINTWISE_J1_PROMPT = """ You are given a user question and a response from an AI assistant. Your task is to act as an impartial judge and evaluate how well the response fulfills the user's instructions. You will be shown multiple responses to the same prompt, but only one at a time. Evaluate each response independently. -Think carefully about how to assess the quality of the response, and enclose your reasoning within and tags. Your reasoning should include your evaluation criteria, a clear understanding of what an ideal response would look like for this particular question, and a concrete example of such an ideal or reference answer if possible. Then compare the assistant's response to your ideal or reference answer, explaining how it aligns with or deviates from your expectations. Be specific and avoid vague or overly general judgments. Remain as objective as possible. - -Finally, assign the assistant's response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision. A higher score should indicate a higher-quality response. Enclose the score within and tags. +Think carefully about how to assess the quality of the response and finally assign the assistant's response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision. A higher score should indicate a higher-quality response. Enclose the score within and tags. Format your output like this: your_thinking_process @@ -44,14 +63,37 @@ [The End of Assistant B's Answer] """ -PAIRWISE_WITH_SCORES_J1_PROMPT = """ -You are given a user question and two responses from two AI assistants. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. +# PAIRWISE_WITH_SCORES_J1_PROMPT = """ +# You are given a user question and two responses from two AI assistants. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. -First, provide your reasoning within and tags. This should include your evaluation criteria for a high-quality response, a detailed comparison of the two responses, and when helpful, a reference answer as part of your evaluation. Be explicit in your thought process, referencing your criteria and explaining how each response aligns with or deviates from them. +# First, provide your reasoning within and tags. This should include your evaluation criteria for a high-quality response, a detailed comparison of the two responses, and when helpful, a reference answer as part of your evaluation. Be explicit in your thought process, referencing your criteria and explaining how each response aligns with or deviates from them. -Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. +# Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. + +# Finally, assign the assistant's response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision, with a higher score indicating a higher-quality response that better satisfies the criteria. Enclose the scores within the tags , and . + +# Format your output like this: +# your_thinking_process +# your_score_a your_score_b + +# Below are the user's question and the two responses: + +# [User Question] +# {instruction} + +# [The Start of Assistant A's Answer] +# {response_A} +# [The End of Assistant A's Answer] + +# [The Start of Assistant B's Answer] +# {response_B} +# [The End of Assistant B's Answer] +# """ + +PAIRWISE_WITH_SCORES_J1_PROMPT = """ +You are given a user question and two responses from two AI assistants. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. -Finally, assign the assistant's response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision, with a higher score indicating a higher-quality response that better satisfies the criteria. Enclose the scores within the tags , and . +Think carefully about how to assess the quality of the responses and finally, assign each response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision, with a higher score indicating a higher-quality response that better satisfies the criteria. Enclose the scores within the tags , and . Format your output like this: your_thinking_process diff --git a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py index bdf93084d..79a7b38d7 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py @@ -914,7 +914,7 @@ def __init__( judgment_extractor_handler = judgment_extractor_registry.get(judgment_extractor) self.judgment_extractor = judgment_extractor_handler.create(self.tokenizer) - def all_pairs(self, prompt_text, i_batch_request_output, vllm_inputs, prompt_pairwise_indices): + def all_pairs(self, prompt_text, i_batch_request_output, vllm_inputs, batch_pairwise_indices): for a in range(len(i_batch_request_output.outputs)): for b in range(len(i_batch_request_output.outputs)): if a != b: @@ -924,11 +924,11 @@ def all_pairs(self, prompt_text, i_batch_request_output, vllm_inputs, prompt_pai prompt_text, rollout_A_text, rollout_B_text ) vllm_inputs.append(vllm_input) - prompt_pairwise_indices.append((a, b)) + batch_pairwise_indices.append((a, b)) - return vllm_inputs, prompt_pairwise_indices + return vllm_inputs, batch_pairwise_indices - def pairs_with_reference(self, prompt_text, i_batch_request_output, vllm_inputs, prompt_pairwise_indices): + def pairs_with_reference(self, prompt_text, i_batch_request_output, vllm_inputs, batch_pairwise_indices): reference_idx = random.randint(0, len(i_batch_request_output.outputs)-1) reference_rollout = i_batch_request_output.outputs[reference_idx].text for a in range(len(i_batch_request_output.outputs)): @@ -939,18 +939,42 @@ def pairs_with_reference(self, prompt_text, i_batch_request_output, vllm_inputs, to_swap = random.choice([True, False]) if to_swap: rollout_A_text, rollout_B_text = rollout_B_text, rollout_A_text - prompt_pairwise_indices.append((reference_idx, a)) + batch_pairwise_indices.append((reference_idx, a)) else: - prompt_pairwise_indices.append((a, reference_idx)) + batch_pairwise_indices.append((a, reference_idx)) vllm_input = self.judgment_extractor.format_prompt( prompt_text, rollout_A_text, rollout_B_text ) vllm_inputs.append(vllm_input) - return vllm_inputs, prompt_pairwise_indices + return vllm_inputs, batch_pairwise_indices - def random_pairs(self, prompt_text, i_batch_request_output, vllm_inputs, prompt_pairwise_indices): + def pairs_with_pivot(self, prompt_text, i_batch_request_output, vllm_inputs, batch_pairwise_indices, batch_pivot_pos): + pivot_idx = random.randint(0, len(i_batch_request_output.outputs)-1) + pivot_rollout = i_batch_request_output.outputs[pivot_idx].text + for a in range(len(i_batch_request_output.outputs)): + rollout_A_text = i_batch_request_output.outputs[a].text + rollout_B_text = pivot_rollout + + batch_pairwise_indices.append((a, pivot_idx)) + batch_pivot_pos.append(1) # specifies which position is the reference index + vllm_input = self.judgment_extractor.format_prompt( + prompt_text, rollout_A_text, rollout_B_text + ) + vllm_inputs.append(vllm_input) + + batch_pairwise_indices.append((pivot_idx, a)) + batch_pivot_pos.append(0) + vllm_input = self.judgment_extractor.format_prompt( + prompt_text, rollout_B_text, rollout_A_text + ) + vllm_inputs.append(vllm_input) + + + return vllm_inputs, batch_pairwise_indices, batch_pivot_pos + + def random_pairs(self, prompt_text, i_batch_request_output, vllm_inputs, batch_pairwise_indices): all_pairs = [(i, j) for i in range(len(i_batch_request_output.outputs)) for j in range(len(i_batch_request_output.outputs)) if i != j] random_pairs = random.sample(all_pairs, len(i_batch_request_output.outputs)) @@ -963,9 +987,9 @@ def random_pairs(self, prompt_text, i_batch_request_output, vllm_inputs, prompt_ prompt_text, rollout_A_text, rollout_B_text ) vllm_inputs.append(vllm_input) - prompt_pairwise_indices.append((a, b)) + batch_pairwise_indices.append((a, b)) - return vllm_inputs, prompt_pairwise_indices + return vllm_inputs, batch_pairwise_indices def convert_pairwise_rewards_to_pointwise(self, batch_pairwise_rewards, batch_pairwise_indices, batch_text, batch_type): B, R = len(batch_text), len(batch_text[0]) # batch size, rollouts @@ -977,17 +1001,25 @@ def convert_pairwise_rewards_to_pointwise(self, batch_pairwise_rewards, batch_pa prompt_pairwise_rewards = batch_pairwise_rewards[ i * R * (R - 1) : (i + 1) * R * (R - 1) ] - elif batch_type == "reference": + prompt_pairwise_indices = batch_pairwise_indices[ + i * R * (R - 1) : (i + 1) * R * (R - 1) + ] + elif batch_type == "pivot": prompt_pairwise_rewards = batch_pairwise_rewards[ i * (R - 1) : (i + 1) * (R - 1) ] + prompt_pairwise_indices = batch_pairwise_indices[ + i * (R - 1) : (i + 1) * (R - 1) + ] elif batch_type == "random_pairs": prompt_pairwise_rewards = batch_pairwise_rewards[ i * R : (i + 1) * R ] + prompt_pairwise_indices = batch_pairwise_indices[ + i * R : (i + 1) * R + ] # Sum the rewards for each rollout and count how many times each rollout appears in pairwise judgments - prompt_pairwise_indices = batch_pairwise_indices[i] prompt_rewards = [0.0] * R counts = [0] * R for index, rewards in zip(prompt_pairwise_indices, prompt_pairwise_rewards): @@ -1005,6 +1037,45 @@ def convert_pairwise_rewards_to_pointwise(self, batch_pairwise_rewards, batch_pa batch_pointwise_rewards.append(avg_prompt_rewards) return batch_pointwise_rewards + + def convert_pairwise_rewards_to_pointwise_new(self, batch_pairwise_rewards, batch_pairwise_indices, batch_pivot_pos, batch_text, batch_type): + B, R = len(batch_text), len(batch_text[0]) # batch size, rollouts + batch_pointwise_rewards = [] + + for i in range(B): + # Extract the pairwise rewards for each input + assert batch_type == "pivot" + prompt_pairwise_rewards = batch_pairwise_rewards[ + i * 2 * R : (i + 1) * 2 * R + ] + prompt_pairwise_indices = batch_pairwise_indices[ + i * 2 * R : (i + 1) * 2 * R + ] + prompt_pivot_pos = batch_pivot_pos[ + i * 2 * R : (i + 1) * 2 * R + ] + + # Sum the rewards for each rollout and count how many times each rollout appears in pairwise judgments + prompt_rewards = [0.0] * R + counts = [0] * R + for index, rewards, pivot_pos in zip(prompt_pairwise_indices, prompt_pairwise_rewards, prompt_pivot_pos): + non_pivot_pos = 1-pivot_pos + # Only compute rewards for the non_reference + # Rewards is a pair (score_A, score_B) + prompt_rewards[index[non_pivot_pos]] += rewards[non_pivot_pos] + # prompt_rewards[index[non_pivot_pos]] += (rewards[non_pivot_pos] - rewards[pivot_pos]) + counts[index[non_pivot_pos]] += 1 + + log.info(f"Counts: {counts}") + # Compute average pointwise rewards + avg_prompt_rewards = [0.0] * R + for i in range(len(prompt_rewards)): + if counts[i] > 0: + avg_prompt_rewards[i] = round(prompt_rewards[i] / counts[i], 4) + + batch_pointwise_rewards.append(avg_prompt_rewards) + + return batch_pointwise_rewards @override def process_rollouts( @@ -1014,8 +1085,9 @@ def process_rollouts( batch_text = [] batch_tokens = [] batch_pairwise_indices = [] + batch_pivot_pos = [] - batch_type = "random_pairs" + batch_type = "pivot" if vllm_outputs is None: vllm_outputs = [None] * len(prompt_batch.prompts) @@ -1034,15 +1106,12 @@ def process_rollouts( batch_text.append(rollouts_text) batch_tokens.append(rollouts_tokens) - prompt_pairwise_indices = [] if batch_type == "all_pairs": - vllm_inputs, prompt_pairwise_indices = self.all_pairs(prompt_text, i_batch_request_output, vllm_inputs, prompt_pairwise_indices) - elif batch_type == "reference": - vllm_inputs, prompt_pairwise_indices = self.pairs_with_reference(prompt_text, i_batch_request_output, vllm_inputs, prompt_pairwise_indices) + vllm_inputs, batch_pairwise_indices = self.all_pairs(prompt_text, i_batch_request_output, vllm_inputs, batch_pairwise_indices) + elif batch_type == "pivot": + vllm_inputs, batch_pairwise_indices, batch_pivot_pos = self.pairs_with_pivot(prompt_text, i_batch_request_output, vllm_inputs, batch_pairwise_indices, batch_pivot_pos) elif batch_type == "random_pairs": - vllm_inputs, prompt_pairwise_indices = self.random_pairs(prompt_text, i_batch_request_output, vllm_inputs, prompt_pairwise_indices) - - batch_pairwise_indices.append(prompt_pairwise_indices) + vllm_inputs, batch_pairwise_indices = self.random_pairs(prompt_text, i_batch_request_output, vllm_inputs, batch_pairwise_indices) batch_pairwise_judgments = generate_rewards_generative( vllm_inputs, @@ -1064,7 +1133,7 @@ def process_rollouts( self.judgment_extractor.aggregate(per_rollout_rewards) ) - batch_rewards = self.convert_pairwise_rewards_to_pointwise(batch_pairwise_rewards, batch_pairwise_indices, batch_text, batch_type) + batch_rewards = self.convert_pairwise_rewards_to_pointwise_new(batch_pairwise_rewards, batch_pairwise_indices, batch_pivot_pos, batch_text, batch_type) log.info(f"Batch Rewards: {batch_rewards}") From 5eee4eed4365a2484fc4b51ad7606213a2bab802 Mon Sep 17 00:00:00 2001 From: swarna Date: Wed, 20 Aug 2025 07:49:15 +0000 Subject: [PATCH 16/33] Fix --- src/fairseq2/setup/_po_finetune_units.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/fairseq2/setup/_po_finetune_units.py b/src/fairseq2/setup/_po_finetune_units.py index c38429966..c16ab5414 100644 --- a/src/fairseq2/setup/_po_finetune_units.py +++ b/src/fairseq2/setup/_po_finetune_units.py @@ -20,7 +20,6 @@ GrpoFinetuneUnitHandler, GSM8kVerifierHandler, J1PairwiseScoreExtractorHandler, - J1PairwiseScoreWithRefAnswerExtractorHandler, J1PointwiseExtractorHandler, JudgmentExtractorHandler, MathVerifyHandler, @@ -128,8 +127,5 @@ def _register_online_finetune_units(context: RuntimeContext) -> None: handler = J1PairwiseScoreExtractorHandler() registry.register(handler.name, handler) - handler = J1PairwiseScoreWithRefAnswerExtractorHandler() - registry.register(handler.name, handler) - handler = GeneralVerifierExtractorHandler() registry.register(handler.name, handler) From 5cdb6b9f699e70b448dd297e4888c7b299359b60 Mon Sep 17 00:00:00 2001 From: swarna Date: Wed, 20 Aug 2025 22:46:17 +0000 Subject: [PATCH 17/33] Making pair type configurable --- .../recipes/lm/_online_finetune/_rewards.py | 30 +++++++++++-------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py index 06827a2b5..7f89ee41e 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py @@ -862,12 +862,18 @@ def create(self, reward_model, reward_name, reward_config, gangs, context): "Generative judges require implementing and specifying a judgment extractor" ) + if reward_config.pair_type is None: + raise RuntimeError( + "Pairwise generative judges require specifying how the pairs should be created" + ) + return GenerativePairwiseVerifier( gangs, context, reward_model, reward_name, judgment_extractor=reward_config.judgment_extractor, + pair_type=reward_config.pair_type, answer_key=reward_config.answer_key, prompt_key=reward_config.prompt_key, tokenizer=reward_config.tokenizer, @@ -892,6 +898,7 @@ def __init__( reward_model, reward_name, judgment_extractor, + pair_type, answer_key, prompt_key, tokenizer, @@ -903,6 +910,7 @@ def __init__( self.reward_model = reward_model self.reward_name = reward_name self.judgment_extractor = judgment_extractor + self.pair_type = pair_type self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) judgment_extractor_registry = self._context.get_registry( @@ -997,7 +1005,7 @@ def convert_pairwise_rewards_to_pointwise( batch_pairwise_rewards, batch_pairwise_indices, batch_text, - batch_type, + pair_type, batch_pivot_pos, ): B, R = len(batch_text), len(batch_text[0]) # batch size, rollouts @@ -1005,11 +1013,11 @@ def convert_pairwise_rewards_to_pointwise( for i in range(B): # Extract the pairwise rewards for each input - if batch_type == "pivot": + if pair_type == "pivot": idx_start, idx_end = i * 2 * R, (i + 1) * 2 * R # 2R pairs - elif batch_type == "random_pairs": + elif pair_type == "random_pairs": idx_start, idx_end = i * R, (i + 1) * R # R pairs - elif batch_type == "all_pairs": + elif pair_type == "all_pairs": idx_start, idx_end = i * R * (R - 1), (i + 1) * R * ( R - 1 ) # R(R-1) pairs @@ -1020,7 +1028,7 @@ def convert_pairwise_rewards_to_pointwise( # If not pivot, create dummy pivots because both rewards will be considered prompt_pivot_pos = ( batch_pivot_pos[idx_start:idx_end] - if batch_type == "pivot" + if pair_type == "pivot" else [0] * (idx_end - idx_start + 1) ) @@ -1036,7 +1044,7 @@ def convert_pairwise_rewards_to_pointwise( counts[index[non_pivot_pos]] += 1 # If not pivot setup, consider rewards of the other (pivot) rollout as well - if batch_type != "pivot": + if pair_type != "pivot": prompt_rewards[index[non_pivot_pos]] += rewards[non_pivot_pos] counts[index[non_pivot_pos]] += 1 @@ -1062,8 +1070,6 @@ def process_rollouts( batch_pairwise_indices = [] batch_pivot_pos = [] - batch_type = "pivot" # all_pairs, pivot, random_pairs - if vllm_outputs is None: vllm_outputs = [None] * len(prompt_batch.prompts) @@ -1084,7 +1090,7 @@ def process_rollouts( batch_text.append(rollouts_text) batch_tokens.append(rollouts_tokens) - if batch_type == "all_pairs": + if self.pair_type == "all_pairs": vllm_inputs, batch_pairwise_indices = self.construct_all_pairs( prompt_text, i_batch_request_output, @@ -1092,7 +1098,7 @@ def process_rollouts( batch_pairwise_indices, reference_answers[i], ) - elif batch_type == "pivot": + elif self.pair_type == "pivot": ( vllm_inputs, batch_pairwise_indices, @@ -1105,7 +1111,7 @@ def process_rollouts( batch_pivot_pos, reference_answers[i], ) - elif batch_type == "random_pairs": + elif self.pair_type == "random_pairs": vllm_inputs, batch_pairwise_indices = self.construct_random_pairs( prompt_text, i_batch_request_output, @@ -1140,7 +1146,7 @@ def process_rollouts( batch_pairwise_rewards, batch_pairwise_indices, batch_text, - batch_type, + self.pair_type, batch_pivot_pos, ) From e14421d373e401ff78bdf8ef32839923939c94c6 Mon Sep 17 00:00:00 2001 From: swarna Date: Thu, 21 Aug 2025 00:25:58 +0000 Subject: [PATCH 18/33] Config change --- src/fairseq2/recipes/lm/_online_finetune/_rewards.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py index 7f89ee41e..eb5d6f168 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py @@ -40,6 +40,7 @@ class RewardModelConfig: prompt_key: str = "prompt" tokenizer: str | None = None judgment_extractor: str | None = None + pair_type: str | None = None @dataclass(kw_only=True) From 8831f362c4b48bcc663b501aefb4ee251ff622a9 Mon Sep 17 00:00:00 2001 From: chenxwh user Date: Sun, 24 Aug 2025 01:29:00 +0000 Subject: [PATCH 19/33] update prompt --- .../recipes/lm/_online_finetune/_generative_judge.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py b/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py index 95c5effdd..7f29de472 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py @@ -47,7 +47,7 @@ your_thinking_process your_score -Below are the user's question and the assistant's response: +Below are the user's question, reference answer and the assistant's response: [User Question] {instruction} @@ -111,13 +111,9 @@ """ PAIRWISE_WITH_SCORES_J1_PROMPT_WITH_REF_ANSWER = """ -You are given a user question, a reference answer, and two responses from two AI assistants. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. - -First, think about your evaluation process and provide your reasoning within and tags. This could include your evaluation criteria for a high-quality response to this specific user question, an analysis of the reference answer, a detailed comparison of the two responses, etc. Be explicit in your thought process, referencing your criteria and explaining how each response aligns with or deviates from them. - -Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. +You are given a user question and two responses from two AI assistants. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. -Finally, assign the assistant's response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision, with a higher score indicating a higher-quality response that better satisfies the criteria. Enclose the scores within the tags , and . +Think carefully about how to assess the quality of the responses and finally, assign each response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision, with a higher score indicating a higher-quality response that better satisfies the criteria. Enclose the scores within the tags , and . Format your output like this: your_thinking_process From 9fc9dbb128bf2cbfa0e0feeee652a51fb4e3b470 Mon Sep 17 00:00:00 2001 From: swarna Date: Wed, 27 Aug 2025 16:56:42 +0000 Subject: [PATCH 20/33] some more logging --- src/fairseq2/recipes/lm/_online_finetune/_grpo.py | 5 ++++- .../recipes/lm/_online_finetune/_rewards.py | 13 +++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_grpo.py b/src/fairseq2/recipes/lm/_online_finetune/_grpo.py index f5fbb9a7e..1cc995845 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_grpo.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_grpo.py @@ -206,9 +206,10 @@ def validate_reward( vllm_model=self._vllm_model, sampling_params=policy_sampling_params, ) + rollouts = strip_think_tokens(rollouts) + log.info("After stripping") if self._config.loss_config.log_rollouts: log_rollouts(prompt_batch, rollouts, "Valid") - rollouts = strip_think_tokens(rollouts) log.info(f"Sampling params: {len(rollouts[0].outputs)}") log.info(f"Rollouts: {len(rollouts[0].outputs)}") reward_output = self._reward.process_rollouts(rollouts, prompt_batch) @@ -269,6 +270,8 @@ def __call__( # log_rollouts(prompt_batch, rollouts, "Train") rollouts = strip_think_tokens(rollouts) + log.info('After stripping') + log_rollouts(prompt_batch, rollouts, "Train") reward_output = self._reward.process_rollouts(rollouts, prompt_batch) self._rollout_bag.save(rollouts, reward_output) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py index eb5d6f168..01fa521d4 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py @@ -971,6 +971,7 @@ def construct_pairs_with_pivot( vllm_inputs.append(vllm_input) return vllm_inputs, batch_pairwise_indices, batch_pivot_pos + def construct_random_pairs( self, @@ -1006,6 +1007,7 @@ def convert_pairwise_rewards_to_pointwise( batch_pairwise_rewards, batch_pairwise_indices, batch_text, + batch_tokens, pair_type, batch_pivot_pos, ): @@ -1050,12 +1052,22 @@ def convert_pairwise_rewards_to_pointwise( counts[index[non_pivot_pos]] += 1 log.info(f"Counts of each rollout: {counts}") + + log.info(f"Number of rollouts wrt batch tokens = {len(batch_tokens[i])}") + assert len(batch_tokens[i]) == R # Compute average pointwise rewards avg_prompt_rewards = [0.0] * R for j in range(len(prompt_rewards)): if counts[j] > 0: avg_prompt_rewards[j] = round(prompt_rewards[j] / counts[j], 4) + # num_tokens = len(batch_tokens[i][j]) + # log.info(f"Num tokens: {num_tokens}") + # correctness_reward = prompt_rewards[j] / counts[j] + # log.info(f"Correctness reward: {correctness_reward}") + # length_penalty = 0.001 * num_tokens + # avg_prompt_rewards[j] = round(correctness_reward - length_penalty, 4) + log.info(f"Overall reward: {avg_prompt_rewards[j]}") batch_pointwise_rewards.append(avg_prompt_rewards) @@ -1147,6 +1159,7 @@ def process_rollouts( batch_pairwise_rewards, batch_pairwise_indices, batch_text, + batch_tokens, self.pair_type, batch_pivot_pos, ) From b1ba0e246273d109e32e84b6ad398f112f6a70fd Mon Sep 17 00:00:00 2001 From: swarna Date: Wed, 27 Aug 2025 18:04:03 +0000 Subject: [PATCH 21/33] Fixing typo in comment --- src/fairseq2/recipes/lm/_online_finetune/_grpo.py | 2 +- src/fairseq2/recipes/lm/_online_finetune/_online_dpo.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_grpo.py b/src/fairseq2/recipes/lm/_online_finetune/_grpo.py index 1cc995845..8042bdefb 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_grpo.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_grpo.py @@ -194,7 +194,7 @@ def validate_reward( ) in self._config.loss_config.validation_vllm_sampling_params.items(): policy_sampling_params.__setattr__(k, v) - # For a pairwise RM, need to sample at least two judgments + # For a pairwise RM, need to sample at least two rollouts policy_sampling_params.n = ( 2 if self._reward.reward_name == "generative_pairwise_verifier" else 1 ) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_online_dpo.py b/src/fairseq2/recipes/lm/_online_finetune/_online_dpo.py index 84a562ac6..0ef85f963 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_online_dpo.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_online_dpo.py @@ -140,7 +140,7 @@ def validate_reward( ) in self._config.loss_config.validation_vllm_sampling_params.items(): policy_sampling_params.__setattr__(k, v) - # For a pairwise RM, need to sample at least two judgments + # For a pairwise RM, need to sample at least two rollouts policy_sampling_params.n = ( 2 if self._reward.reward_name == "generative_pairwise_verifier" else 1 ) From d47ef15860776f760e5cda6e780deb58cdbdea7e Mon Sep 17 00:00:00 2001 From: swarna Date: Tue, 2 Sep 2025 20:37:54 +0000 Subject: [PATCH 22/33] kwise judgment support --- src/fairseq2/recipes/lm/__init__.py | 12 + .../recipes/lm/_online_finetune/_common.py | 11 + .../lm/_online_finetune/_generative_judge.py | 169 ++++++++- .../recipes/lm/_online_finetune/_grpo.py | 23 +- .../recipes/lm/_online_finetune/_handler.py | 9 +- .../lm/_online_finetune/_remote_model.py | 9 +- .../recipes/lm/_online_finetune/_rewards.py | 327 +++++++++++++++++- src/fairseq2/setup/_po_finetune_units.py | 9 + 8 files changed, 524 insertions(+), 45 deletions(-) diff --git a/src/fairseq2/recipes/lm/__init__.py b/src/fairseq2/recipes/lm/__init__.py index 15557ff98..686dfbe3e 100644 --- a/src/fairseq2/recipes/lm/__init__.py +++ b/src/fairseq2/recipes/lm/__init__.py @@ -48,6 +48,12 @@ from fairseq2.recipes.lm._online_finetune._generative_judge import ( J1PairwiseScoreExtractorHandler as J1PairwiseScoreExtractorHandler, ) +from fairseq2.recipes.lm._online_finetune._generative_judge import ( + J1KwiseScoreExtractor as J1KwiseScoreExtractor, +) +from fairseq2.recipes.lm._online_finetune._generative_judge import ( + J1KwiseScoreExtractorHandler as J1KwiseScoreExtractorHandler, +) from fairseq2.recipes.lm._online_finetune._generative_judge import ( J1PointwiseExtractor as J1PointwiseExtractor, ) @@ -111,6 +117,12 @@ from fairseq2.recipes.lm._online_finetune._rewards import ( GenerativePointwiseVerifierHandler as GenerativePointwiseVerifierHandler, ) +from fairseq2.recipes.lm._online_finetune._rewards import ( + GenerativeKwiseVerifier as GenerativeKwiseVerifier, +) +from fairseq2.recipes.lm._online_finetune._rewards import ( + GenerativeKwiseVerifierHandler as GenerativeKwiseVerifierHandler, +) from fairseq2.recipes.lm._online_finetune._rewards import GSM8kVerifier as GSM8kVerifier from fairseq2.recipes.lm._online_finetune._rewards import ( GSM8kVerifierHandler as GSM8kVerifierHandler, diff --git a/src/fairseq2/recipes/lm/_online_finetune/_common.py b/src/fairseq2/recipes/lm/_online_finetune/_common.py index 8079dae7d..8f1908862 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_common.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_common.py @@ -439,6 +439,17 @@ def strip_think_tokens(rollouts: List[SequenceData]): return rollouts +def format_think_tags(rollouts: List[SequenceData]): + for sample in rollouts: + for rollout in sample.outputs: + rollout_text = rollout.text + rollout.text = rollout_text.replace( + "", "[Start of Assistant Thinking]" + ).replace("", "[End of Assistant Thinking]") + + return rollouts + + class StatefulRolloutBag: """A stateful container for managing and reusing model rollouts across multiple micro-batches. diff --git a/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py b/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py index 7f29de472..9d6542283 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py @@ -20,7 +20,7 @@ # """ POINTWISE_J1_PROMPT = """ -You are given a user question and a response from an AI assistant. Your task is to act as an impartial judge and evaluate how well the response fulfills the user's instructions. You will be shown multiple responses to the same prompt, but only one at a time. Evaluate each response independently. +You are given a user question and a response from an AI assistant. Your task is to act as an impartial judge and evaluate how well the response fulfills the user's instructions. Do not allow the length of the response to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. Think carefully about how to assess the quality of the response and finally assign the assistant's response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision. A higher score should indicate a higher-quality response. Enclose the score within and tags. @@ -110,6 +110,72 @@ [The End of Assistant B's Answer] """ +KWISE_WITH_SCORES_J1_PROMPT = """ +You are given a user question and {k} responses from {k} AI assistants. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. + +Think carefully about how to assess the quality of the responses and finally, assign each response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision, with a higher score indicating a higher-quality response that better satisfies the criteria. Enclose the scores within the tags , and so on. + +Format your output like this: + your_thinking_process + your_score_1 + your_score_2 + your_score_3 +... + +Below are the user's question and the two responses: + +[User Question] +{instruction} + +{responses} +""" + +KWISE_WITH_SCORES_J1_PROMPT_WITH_REF_ANSWER = """ +You are given a user question and {k} responses from {k} AI assistants. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. + +Think carefully about how to assess the quality of the responses and finally, assign each response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision, with a higher score indicating a higher-quality response that better satisfies the criteria. Enclose the scores within the tags , and so on. + +Format your output like this: + your_thinking_process + your_score_1 + your_score_2 + your_score_3 +... + +Below are the user's question and the two responses: + +[User Question] +{instruction} + +[Reference Answer] +{reference_answer} + +{responses} +""" + +# PAIRWISE_WITH_SCORES_J1_PROMPT = """ +# You are given a user question and two responses from two AI assistants. You are also given their thinking process. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Care any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. + +# Carefully analyze the assistants' thought process, assess the quality of the responses and finally, assign each response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision, with a higher score indicating a higher-quality response that better satisfies the criteria. Enclose the scores within the tags , and . + +# Format your output like this: +# your_thinking_process +# your_score_a your_score_b + +# Below are the user's question and the two responses: + +# [User Question] +# {instruction} + +# [The Start of Assistant A's Answer] +# {response_A} +# [The End of Assistant A's Answer] + +# [The Start of Assistant B's Answer] +# {response_B} +# [The End of Assistant B's Answer] +# """ + PAIRWISE_WITH_SCORES_J1_PROMPT_WITH_REF_ANSWER = """ You are given a user question and two responses from two AI assistants. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. @@ -147,18 +213,15 @@ class JudgmentExtractorHandler(ABC): @abstractmethod - def create(self, tokenizer): - ... + def create(self, tokenizer): ... @property @abstractmethod - def name(self) -> str: - ... + def name(self) -> str: ... @property @abstractmethod - def config_kls(self) -> type[object]: - ... + def config_kls(self) -> type[object]: ... """ @@ -177,12 +240,10 @@ class JudgmentExtractor(ABC): """ @abstractmethod - def prompt(self) -> str: - ... + def prompt(self) -> str: ... @abstractmethod - def format_prompt(self, prompt_text, **kwargs: Any) -> str: - ... + def format_prompt(self, prompt_text, **kwargs: Any) -> str: ... """ Format the prompt text and additional arguments into a string suitable for input to the reward model. @@ -195,8 +256,7 @@ def format_prompt(self, prompt_text, **kwargs: Any) -> str: """ @abstractmethod - def extract(self, generation) -> float | str: - ... + def extract(self, generation) -> float | str: ... """ Extract the final scalar reward score from the model's response. @@ -215,8 +275,7 @@ def extract(self, generation) -> float | str: """ @abstractmethod - def aggregate(self, judgments) -> float | str: - ... + def aggregate(self, judgments) -> float | str: ... """ Aggregate multiple responses (judgments) from the reward model into a single value. @@ -472,3 +531,83 @@ def aggregate(self, judgments): round(avg_score[0] / len(judgments), 4), round(avg_score[1] / len(judgments), 4), ) + + +class J1KwiseScoreExtractorHandler(JudgmentExtractorHandler): + def __init__(self): + pass + + @override + def create(self, tokenizer, k): + return J1KwiseScoreExtractor(tokenizer, k) + + @property + @override + def name(self): + return "j1_kwise_score_extractor" + + @property + @override + def config_kls(self): + return None + + +class J1KwiseScoreExtractor(JudgmentExtractor): + def __init__(self, tokenizer, k): + self.tokenizer = tokenizer + self.k = k + + @override + def prompt(self, reference_answer): + return ( + KWISE_WITH_SCORES_J1_PROMPT + if reference_answer is None + else KWISE_WITH_SCORES_J1_PROMPT_WITH_REF_ANSWER + ) + + @override + def format_prompt(self, prompt_text, rollouts, reference_answer): + prompt_template = self.prompt(reference_answer) + content = ( + prompt_template.format( + k=self.k, instruction=prompt_text, responses=rollouts + ) + if reference_answer is None + else prompt_template.format( + k=self.k, + instruction=prompt_text, + responses=rollouts, + reference_answer=reference_answer, + ) + ) + + wrapped_text = [{"role": "user", "content": content}] + chat_str = self.tokenizer.apply_chat_template( + wrapped_text, tokenize=False, add_generation_prompt=True + ) + return chat_str + + @override + def extract(self, generation): + scores = [] + for i in range(self.k): + score_matches = re.findall( + rf"\s*([0-9]+(?:\.[0-9])?)\s*(?:/10)?\s*", + generation, + ) + if score_matches: + scores.append(float(score_matches[-1].strip())) + else: + scores.append(0.0) + + return scores + + @override + def aggregate(self, judgments): + avg_score = [0.0] * self.k + for scores in judgments: + for i, score in enumerate(scores): + avg_score[i] += score + + avg_score = [round(avg_score[i] / len(judgments), 4) for i in range(self.k)] + return avg_score diff --git a/src/fairseq2/recipes/lm/_online_finetune/_grpo.py b/src/fairseq2/recipes/lm/_online_finetune/_grpo.py index 8042bdefb..40d01bf1d 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_grpo.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_grpo.py @@ -34,6 +34,7 @@ collate_with_target_mask, compute_reference_logps, compute_token_level_entropy, + format_think_tags, generate_rollouts, get_rollout_lengths, log_rollouts, @@ -195,9 +196,12 @@ def validate_reward( policy_sampling_params.__setattr__(k, v) # For a pairwise RM, need to sample at least two rollouts - policy_sampling_params.n = ( - 2 if self._reward.reward_name == "generative_pairwise_verifier" else 1 - ) + if self._reward.reward_name == "generative_pairwise_verifier": + policy_sampling_params.n = 2 + elif self._reward.reward_name == "generative_kwise_verifier": + policy_sampling_params.n = self._config.reward.config.k + else: + policy_sampling_params.n = 1 else: policy_sampling_params = None rollouts = generate_rollouts( @@ -206,7 +210,11 @@ def validate_reward( vllm_model=self._vllm_model, sampling_params=policy_sampling_params, ) - rollouts = strip_think_tokens(rollouts) + if self._config.reward.config.strip_thinking: + rollouts = strip_think_tokens(rollouts) + else: + rollouts = format_think_tags(rollouts) + log.info("After stripping") if self._config.loss_config.log_rollouts: log_rollouts(prompt_batch, rollouts, "Valid") @@ -269,8 +277,11 @@ def __call__( # if self._config.loss_config.log_rollouts: # log_rollouts(prompt_batch, rollouts, "Train") - rollouts = strip_think_tokens(rollouts) - log.info('After stripping') + if self._config.reward.config.strip_thinking: + rollouts = strip_think_tokens(rollouts) + else: + rollouts = format_think_tags(rollouts) + log.info("After stripping") log_rollouts(prompt_batch, rollouts, "Train") reward_output = self._reward.process_rollouts(rollouts, prompt_batch) self._rollout_bag.save(rollouts, reward_output) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_handler.py b/src/fairseq2/recipes/lm/_online_finetune/_handler.py index 0badf2b10..943528f51 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_handler.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_handler.py @@ -19,18 +19,15 @@ class OnlineFinetuneUnitHandler(ABC): @abstractmethod def create( self, model: Model, gangs: Gangs, recipe_config: object, vllm_actors: object - ) -> TrainUnit[SequenceBatch]: - ... + ) -> TrainUnit[SequenceBatch]: ... @property @abstractmethod - def name(self) -> str: - ... + def name(self) -> str: ... @property @abstractmethod - def config_kls(self) -> type[object]: - ... + def config_kls(self) -> type[object]: ... class UnknownOnlineFinetuneUnitError(Exception): diff --git a/src/fairseq2/recipes/lm/_online_finetune/_remote_model.py b/src/fairseq2/recipes/lm/_online_finetune/_remote_model.py index 22dd92aa7..607a3cbac 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_remote_model.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_remote_model.py @@ -564,18 +564,15 @@ class RemoteModelHandler(ABC): @abstractmethod def create( self, gangs: Gangs, unit_config: object - ) -> Union[RemoteVllmModel, RemoteHFModel]: - ... + ) -> Union[RemoteVllmModel, RemoteHFModel]: ... @property @abstractmethod - def name(self) -> str: - ... + def name(self) -> str: ... @property @abstractmethod - def config_kls(self) -> type[object]: - ... + def config_kls(self) -> type[object]: ... class RemoteRayModelHandler(RemoteModelHandler): diff --git a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py index 01fa521d4..fa22e6c4b 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py @@ -6,6 +6,8 @@ from __future__ import annotations +import itertools +import math import random import re from abc import ABC, abstractmethod @@ -41,6 +43,8 @@ class RewardModelConfig: tokenizer: str | None = None judgment_extractor: str | None = None pair_type: str | None = None + k: int | None = None + strip_thinking: bool | None = None @dataclass(kw_only=True) @@ -53,28 +57,23 @@ class VLLMOutputRewardHandler(ABC): @abstractmethod def create( self, reward_model: Any, gangs: Gangs, reward_config: object - ) -> VLLMOutputReward: - ... + ) -> VLLMOutputReward: ... @property @abstractmethod - def name(self) -> str: - ... + def name(self) -> str: ... @property @abstractmethod - def config_kls(self) -> type[object]: - ... + def config_kls(self) -> type[object]: ... class VLLMOutputReward(ABC): @abstractmethod - def process_rollouts(self, vllm_outputs: list[RequestOutput]): - ... + def process_rollouts(self, vllm_outputs: list[RequestOutput]): ... @abstractmethod - def prepare_preference_batch(self, prompt_batch: PromptBatch, rollouts): - ... + def prepare_preference_batch(self, prompt_batch: PromptBatch, rollouts): ... class GSM8kVerifierHandler(VLLMOutputRewardHandler): @@ -971,7 +970,6 @@ def construct_pairs_with_pivot( vllm_inputs.append(vllm_input) return vllm_inputs, batch_pairwise_indices, batch_pivot_pos - def construct_random_pairs( self, @@ -1052,7 +1050,7 @@ def convert_pairwise_rewards_to_pointwise( counts[index[non_pivot_pos]] += 1 log.info(f"Counts of each rollout: {counts}") - + log.info(f"Number of rollouts wrt batch tokens = {len(batch_tokens[i])}") assert len(batch_tokens[i]) == R @@ -1245,3 +1243,308 @@ def prepare_preference_batch( ) return batch, is_bad_batch, reward_output + + +class GenerativeKwiseVerifierHandler(VLLMOutputRewardHandler): + def __init__(self): + pass + + @override + def create(self, reward_model, reward_name, reward_config, gangs, context): + if reward_config.tokenizer is None: + raise RuntimeError("Generative judges require tokenizer") + + if reward_config.judgment_extractor is None: + raise RuntimeError( + "Generative judges require implementing and specifying a judgment extractor" + ) + + if reward_config.k is None: + raise RuntimeError( + "Kwise generative judges require specifying the size of the tuple k" + ) + + return GenerativeKwiseVerifier( + gangs, + context, + reward_model, + reward_name, + judgment_extractor=reward_config.judgment_extractor, + k=reward_config.k, + answer_key=reward_config.answer_key, + prompt_key=reward_config.prompt_key, + tokenizer=reward_config.tokenizer, + ) + + @property + @override + def name(self): + return "generative_kwise_verifier" + + @property + @override + def config_kls(self): + return None + + +class GenerativeKwiseVerifier(VLLMOutputReward): + def __init__( + self, + gangs, + context, + reward_model, + reward_name, + judgment_extractor, + k, + answer_key, + prompt_key, + tokenizer, + ): + self.answer_key = answer_key + self.prompt_key = prompt_key + self._gangs = gangs + self._context = context + self.reward_model = reward_model + self.reward_name = reward_name + self.judgment_extractor = judgment_extractor + self.k = k + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) + + judgment_extractor_registry = self._context.get_registry( + JudgmentExtractorHandler + ) + judgment_extractor_handler = judgment_extractor_registry.get(judgment_extractor) + self.judgment_extractor = judgment_extractor_handler.create( + self.tokenizer, self.k + ) + + def construct_all_k_tuples( + self, + prompt_text, + i_batch_request_output, + vllm_inputs, + batch_kwise_indices, + reference_answer, + R, + k, + ): + all_k_tuples = list(itertools.combinations(list(range(R)), k)) + for k_tuple in all_k_tuples: + k_list = list(k_tuple) + random.shuffle(k_list) + batch_kwise_indices.append(k_list) + response_string = "" + for assistant_id, idx in enumerate(k_list): + rollout = i_batch_request_output.outputs[idx].text + response_string += f"[Start of Assistant {assistant_id+1} Answer]\n{rollout}\n[End of Assistant {assistant_id+1} Answer]\n\n" + response_string = response_string.strip() + + vllm_input = self.judgment_extractor.format_prompt( + prompt_text, response_string, reference_answer + ) + vllm_inputs.append(vllm_input) + + return vllm_inputs, batch_kwise_indices + + def convert_kwise_rewards_to_pointwise( + self, + batch_kwise_rewards, + batch_kwise_indices, + batch_text, + batch_tokens, + k, + ): + B, R = len(batch_text), len(batch_text[0]) # batch size, rollouts + batch_pointwise_rewards = [] + + for prompt_idx in range(B): + # Extract the kwise rewards for each input + num = math.comb(R, k) + idx_start, idx_end = ( + prompt_idx * num, + (prompt_idx + 1) * num, + ) # R choose k tuples + + prompt_kwise_rewards = batch_kwise_rewards[idx_start:idx_end] + prompt_kwise_indices = batch_kwise_indices[idx_start:idx_end] + + # Sum the rewards for each rollout and count how many times each rollout appears in pairwise judgments + prompt_rewards = [0.0] * R + counts = [0] * R + + # For example, indices would be [0, 3, 4] which means rollout 0, 3 and 4 + # rewards would be [7, 8, 9] which means rollout 0 has reward 7 and so on + for indices, rewards in zip(prompt_kwise_indices, prompt_kwise_rewards): + for rollout_idx in range(k): + prompt_rewards[indices[rollout_idx]] += rewards[rollout_idx] + counts[indices[rollout_idx]] += 1 + + log.info(f"Counts of each rollout: {counts}") + + log.info( + f"Number of rollouts wrt batch tokens = {len(batch_tokens[prompt_idx])}" + ) + assert len(batch_tokens[prompt_idx]) == R + + # Compute average pointwise rewards + avg_prompt_rewards = [0.0] * R + for j in range(R): + if counts[j] > 0: + avg_prompt_rewards[j] = round(prompt_rewards[j] / counts[j], 4) + log.info(f"Overall reward: {avg_prompt_rewards[j]}") + + batch_pointwise_rewards.append(avg_prompt_rewards) + + return batch_pointwise_rewards + + @override + def process_rollouts( + self, vllm_outputs: list[RequestOutput], prompt_batch: PromptBatch + ): + vllm_inputs = [] + batch_text = [] + batch_tokens = [] + batch_kwise_indices = [] + + if vllm_outputs is None: + vllm_outputs = [None] * len(prompt_batch.prompts) + + text_prompts = prompt_batch.meta_info.get(self.prompt_key) + reference_answers = prompt_batch.meta_info.get(self.answer_key) + if reference_answers is None: + reference_answers = [None] * len(prompt_batch.prompts) + for i, (i_batch_request_output, prompt_text) in enumerate( + zip(vllm_outputs, text_prompts) + ): + rollouts_text = [ + rollout_output.text for rollout_output in i_batch_request_output.outputs + ] + rollouts_tokens = [ + rollout_output.token_ids + for rollout_output in i_batch_request_output.outputs + ] + batch_text.append(rollouts_text) + batch_tokens.append(rollouts_tokens) + + R = len(rollouts_text) + vllm_inputs, batch_kwise_indices = self.construct_all_k_tuples( + prompt_text, + i_batch_request_output, + vllm_inputs, + batch_kwise_indices, + reference_answers[i], + R, + self.k, + ) + + batch_kwise_judgments = generate_rewards_generative( + vllm_inputs, + dp_gang=self._gangs.dp, + vllm_model=self.reward_model, + ) + + log.info(f"Number of kwise comparisons: {len(batch_kwise_judgments)}") + log.info( + f"Number of judgments per kwise comparison: {len(batch_kwise_judgments[0].outputs)}" + ) + log.info(f"Sample judgment: {batch_kwise_judgments[0].outputs[0].text}") + + batch_kwise_rewards = [] + for per_rollout_judgments in batch_kwise_judgments: + per_rollout_rewards = [ + self.judgment_extractor.extract(judgment.text) + for judgment in per_rollout_judgments.outputs + ] + batch_kwise_rewards.append( + self.judgment_extractor.aggregate(per_rollout_rewards) + ) + + batch_rewards = self.convert_kwise_rewards_to_pointwise( + batch_kwise_rewards, + batch_kwise_indices, + batch_text, + batch_tokens, + self.k, + ) + + log.info(f"Batch Rewards: {batch_rewards}") + + return {"text": batch_text, "tokens": batch_tokens, "rewards": batch_rewards} + + def prepare_preference_batch( + self, prompt_batch: PromptBatch, rollouts + ) -> PreferenceBatch: + + reward_output = self.process_rollouts(rollouts, prompt_batch) + + chosen_batch = [] + rejected_batch = [] + prompt_lens = [] + dummy_batch_ids = [] # keep posiitons of dummy pairs here + + # choosing first rollouts with reward 1 as chosen and 0 as rejected (sort of random given that we sample rollouts randomly) + for i_batch, (i_batch_rewards, i_batch_tokens) in enumerate( + zip(reward_output["rewards"], reward_output["tokens"]) + ): + + chosen_rollout_position = i_batch_rewards.index(max(i_batch_rewards)) + rejected_rollout_position = i_batch_rewards.index(min(i_batch_rewards)) + + if chosen_rollout_position == rejected_rollout_position: + # cant form preference pair when we dont have such rollouts + # this will be dummy batch and we zero out loss + dummy_batch_ids.append(i_batch) + + chosen_rollout_tokens = list(i_batch_tokens[chosen_rollout_position]) + rejected_rollout_tokens = list(i_batch_tokens[rejected_rollout_position]) + prompt_tokens = prompt_batch.prompts[i_batch] + + chosen_tokens = prompt_tokens + chosen_rollout_tokens + chosen_batch.append(chosen_tokens) + + rejected_tokens = prompt_tokens + rejected_rollout_tokens + rejected_batch.append(rejected_tokens) + + prompt_lens.append(len(prompt_tokens)) + + filter_batch = lambda batch: [ + item for index, item in enumerate(batch) if index not in dummy_batch_ids + ] + + if len(dummy_batch_ids) == len(reward_output["tokens"]): + # entire batch does not have a valid preference pair + # we use it as dummy batch and zero the loss in the end + is_bad_batch = True + else: + # removing dummy pairs from the batch + chosen_batch = filter_batch(chosen_batch) + rejected_batch = filter_batch(rejected_batch) + prompt_lens = filter_batch(prompt_lens) + is_bad_batch = False + + prompt_lens = torch.tensor(prompt_lens) + + chosen_batch = [ + torch.tensor(sequence, device=self._gangs.dp.device) + for sequence in chosen_batch + ] + chosen_batch = collate_with_target_mask( + chosen_batch, prompt_lens, device=self._gangs.dp.device + ) + + rejected_batch = [ + torch.tensor(sequence, device=self._gangs.dp.device) + for sequence in rejected_batch + ] + rejected_batch = collate_with_target_mask( + rejected_batch, prompt_lens, device=self._gangs.dp.device + ) + + batch = PreferenceBatch( + chosen=chosen_batch, + rejected=rejected_batch, + reference_score_chosen=None, + reference_score_rejected=None, + ) + + return batch, is_bad_batch, reward_output diff --git a/src/fairseq2/setup/_po_finetune_units.py b/src/fairseq2/setup/_po_finetune_units.py index c16ab5414..86535afa3 100644 --- a/src/fairseq2/setup/_po_finetune_units.py +++ b/src/fairseq2/setup/_po_finetune_units.py @@ -16,10 +16,12 @@ DpoFinetuneUnitHandler, GeneralVerifierExtractorHandler, GenerativePairwiseVerifierHandler, + GenerativeKwiseVerifierHandler, GenerativePointwiseVerifierHandler, GrpoFinetuneUnitHandler, GSM8kVerifierHandler, J1PairwiseScoreExtractorHandler, + J1KwiseScoreExtractorHandler, J1PointwiseExtractorHandler, JudgmentExtractorHandler, MathVerifyHandler, @@ -107,6 +109,10 @@ def _register_online_finetune_units(context: RuntimeContext) -> None: # GenerativePairwiseVerifier handler = GenerativePairwiseVerifierHandler() registry.register(handler.name, handler) + + # GenerativeKwiseVerifier + handler = GenerativeKwiseVerifierHandler() + registry.register(handler.name, handler) registry = context.get_registry(RemoteModelHandler) @@ -126,6 +132,9 @@ def _register_online_finetune_units(context: RuntimeContext) -> None: handler = J1PairwiseScoreExtractorHandler() registry.register(handler.name, handler) + + handler = J1KwiseScoreExtractorHandler() + registry.register(handler.name, handler) handler = GeneralVerifierExtractorHandler() registry.register(handler.name, handler) From d4740707ca2bc117b81c0ed50eaf3edd451653ec Mon Sep 17 00:00:00 2001 From: swarna Date: Wed, 3 Sep 2025 23:57:35 +0000 Subject: [PATCH 23/33] Adding support for acemath --- src/fairseq2/recipes/lm/__init__.py | 9 + .../recipes/lm/_online_finetune/_common.py | 2 + .../lm/_online_finetune/_remote_model.py | 29 ++- .../recipes/lm/_online_finetune/_rewards.py | 181 ++++++++++++++++++ .../_online_finetune/third_party/ace_math.py | 28 +++ src/fairseq2/setup/_po_finetune_units.py | 10 + 6 files changed, 258 insertions(+), 1 deletion(-) create mode 100644 src/fairseq2/recipes/lm/_online_finetune/third_party/ace_math.py diff --git a/src/fairseq2/recipes/lm/__init__.py b/src/fairseq2/recipes/lm/__init__.py index 686dfbe3e..0b46460df 100644 --- a/src/fairseq2/recipes/lm/__init__.py +++ b/src/fairseq2/recipes/lm/__init__.py @@ -90,6 +90,9 @@ from fairseq2.recipes.lm._online_finetune._remote_model import ( NoEnvGeneralVerifierPipeline as NoEnvGeneralVerifierPipeline, ) +from fairseq2.recipes.lm._online_finetune._remote_model import ( + NoEnvAceMathRMPipeline as NoEnvAceMathRMPipeline, +) from fairseq2.recipes.lm._online_finetune._remote_model import ( RemoteModelHandler as RemoteModelHandler, ) @@ -105,6 +108,12 @@ from fairseq2.recipes.lm._online_finetune._rewards import ( SkyworkVerifierHandler as SkyworkVerifierHandler, ) +from fairseq2.recipes.lm._online_finetune._rewards import ( + AceMathVerifier as AceMathVerifier, +) +from fairseq2.recipes.lm._online_finetune._rewards import ( + AceMathVerifierHandler as AceMathVerifierHandler, +) from fairseq2.recipes.lm._online_finetune._rewards import ( GenerativePairwiseVerifier as GenerativePairwiseVerifier, ) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_common.py b/src/fairseq2/recipes/lm/_online_finetune/_common.py index 8f1908862..147f16876 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_common.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_common.py @@ -394,6 +394,8 @@ def log_rollouts(prompt_batch: PromptBatch, rollouts, split_name, num_rollouts=1 prompt = prompt_batch.meta_info.get("prompt_raw")[0] elif "raw_prompt" in prompt_batch.meta_info: prompt = prompt_batch.meta_info.get("raw_prompt")[0] + elif "problem" in prompt_batch.meta_info: + prompt = prompt_batch.meta_info.get("problem")[0] else: # raw text prompt doesn't exist for this dataset prompt = "DUMMY PROMPT" diff --git a/src/fairseq2/recipes/lm/_online_finetune/_remote_model.py b/src/fairseq2/recipes/lm/_online_finetune/_remote_model.py index 607a3cbac..c836e25e4 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_remote_model.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_remote_model.py @@ -31,6 +31,9 @@ from fairseq2.recipes.lm._online_finetune.third_party.general_verifier import ( GeneralVerifierPipeline, ) +from fairseq2.recipes.lm._online_finetune.third_party.ace_math import ( + AceMathRMPipeline, +) from fairseq2.utils.structured import StructureError, structure @@ -48,6 +51,7 @@ class VllmEngineArgs: tokenizer: str = "/datasets/pretrained-llms/Llama-3.1-8B-Instruct" task: str = "generate" tensor_parallel_size: int = 4 + max_model_len: int | None = None trust_remote_code: bool = False model_impl: str = "auto" enforce_eager: bool = True @@ -136,6 +140,26 @@ def is_ready(self): @property def name(self): return "general_verifier_pipeline" + +@ray.remote +class NoEnvAceMathRMPipeline(AceMathRMPipeline): + """ + This is for running Ace Math RM pipeline with HF backend. + """ + + def __init__(self, *args, **kwargs): + # stop ray from manipulating CUDA_VISIBLE_DEVICES + # at the top-level + del os.environ["CUDA_VISIBLE_DEVICES"] + super().__init__(*args, **kwargs) + self.ready = True # Set a flag or return a signal + + def is_ready(self): + return self.ready + + @property + def name(self): + return "ace_math_rm_pipeline" class WorkerExtension: @@ -309,6 +333,7 @@ def setup_vllm_worker(self, ray_actor_name, vllm_engine_args, gangs: Gangs): ).remote( model=vllm_engine_args.model, tokenizer=vllm_engine_args.tokenizer, + max_model_len=vllm_engine_args.max_model_len, enforce_eager=vllm_engine_args.enforce_eager, worker_extension_cls="fairseq2.recipes.lm._online_finetune._remote_model.WorkerExtension", tensor_parallel_size=vllm_engine_args.tensor_parallel_size, @@ -437,6 +462,8 @@ def reward_from_model(self, prompt_list, batch_size=64): ray_outputs = ray.get(outputs) ray_outputs_flat = [o for sublist in ray_outputs for o in sublist] rewards = [o.outputs.data.item() for o in ray_outputs_flat] + + log.info(f"Rewards = {rewards}") return rewards @@ -537,7 +564,7 @@ def rollout_from_model(self, prompt_list, sampling_params=None, string_input=Fal "RemoteHFModel.rollout_from_model is not implemented. " ) - def reward_from_model(self, prompt_list, batch_size=64): + def reward_from_model(self, prompt_list, batch_size=4): # NOTE: need to batch inputs to hf.encode model for current models that aren't supported by hf rewards = [] outputs = [] diff --git a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py index fa22e6c4b..7afc2a444 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py @@ -461,6 +461,187 @@ def prepare_preference_batch( ) return batch, is_bad_batch, reward_output + +class AceMathVerifierHandler(VLLMOutputRewardHandler): + def __init__(self): + pass + + @override + def create(self, reward_model, reward_name, reward_config, gangs, context): + if reward_config.tokenizer is not None: + tokenizer = reward_config.tokenizer + else: + tokenizer = "nvidia/AceMath-7B-RM" + + return AceMathVerifier( + gangs, + context, + reward_model, + reward_name=reward_name, + answer_key=reward_config.answer_key, + prompt_key=reward_config.prompt_key, + tokenizer=tokenizer, + ) + + @property + @override + def name(self): + return "acemath_verifier" + + @property + @override + def config_kls(self): + return None + +class AceMathVerifier(VLLMOutputReward): + def __init__( + self, + gangs, + context, + reward_model, + reward_name, + answer_key, + prompt_key, + tokenizer, + ): + self.answer_key = answer_key + self.prompt_key = prompt_key + self._gangs = gangs + self._context = context + self.reward_model = reward_model + self.reward_name = reward_name + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) + + def wrap_text(self, prompt_text, rollout_text): + wrapped_text = [ + {"role": "system", "content": "Please reason step by step, and check your final answer within \\boxed{}."}, + {"role": "user", "content": prompt_text}, + {"role": "assistant", "content": rollout_text} + ] + chat_str = self.tokenizer.apply_chat_template(wrapped_text, tokenize=False, add_generation_prompt=False) + if self.tokenizer.bos_token is not None and chat_str.startswith( + self.tokenizer.bos_token + ): + chat_str = chat_str[len(self.tokenizer.bos_token) :] + + return chat_str + + @override + def process_rollouts( + self, vllm_outputs: list[RequestOutput], prompt_batch: PromptBatch + ): + vllm_inputs = [] + batch_text = [] + batch_tokens = [] + + if vllm_outputs is None: + vllm_outputs = [None] * len(prompt_batch.prompts) + + text_prompts = prompt_batch.meta_info.get(self.prompt_key) + for i, (i_batch_request_output, prompt_text) in enumerate( + zip(vllm_outputs, text_prompts) + ): + + rollouts_text = [] + rollouts_tokens = [] + for rollout_output in i_batch_request_output.outputs: + rollout_text = rollout_output.text + vllm_input = self.wrap_text(prompt_text, rollout_text) + vllm_inputs.append(vllm_input) + rollouts_text.append(rollout_output.text) + rollouts_tokens.append(rollout_output.token_ids) + + batch_text.append(rollouts_text) + batch_tokens.append(rollouts_tokens) + + batch_rewards = generate_rewards( + vllm_inputs, dp_gang=self._gangs.dp, vllm_model=self.reward_model + ) + + log.info(f"Batch rewards = {batch_rewards}") + + # reshape batch_rewards to [Batch, Rollouts] + B, R = len(batch_text), len(batch_text[0]) # batch size, rollouts + batch_rewards = [batch_rewards[i * R : (i + 1) * R] for i in range(B)] + + return {"text": batch_text, "tokens": batch_tokens, "rewards": batch_rewards} + + def prepare_preference_batch( + self, prompt_batch: PromptBatch, rollouts + ) -> PreferenceBatch: + + reward_output = self.process_rollouts(rollouts, prompt_batch) + + chosen_batch = [] + rejected_batch = [] + prompt_lens = [] + dummy_batch_ids = [] # keep posiitons of dummy pairs here + + # choosing first rollouts with reward 1 as chosen and 0 as rejected (sort of random given that we sample rollouts randomly) + for i_batch, (i_batch_rewards, i_batch_tokens) in enumerate( + zip(reward_output["rewards"], reward_output["tokens"]) + ): + chosen_rollout_position = i_batch_rewards.index(max(i_batch_rewards)) + rejected_rollout_position = i_batch_rewards.index(min(i_batch_rewards)) + + if chosen_rollout_position == rejected_rollout_position: + # cant form preference pair when we dont have such rollouts + # this will be dummy batch and we zero out loss + dummy_batch_ids.append(i_batch) + + chosen_rollout_tokens = list(i_batch_tokens[chosen_rollout_position]) + rejected_rollout_tokens = list(i_batch_tokens[rejected_rollout_position]) + prompt_tokens = prompt_batch.prompts[i_batch] + + chosen_tokens = prompt_tokens + chosen_rollout_tokens + chosen_batch.append(chosen_tokens) + + rejected_tokens = prompt_tokens + rejected_rollout_tokens + rejected_batch.append(rejected_tokens) + + prompt_lens.append(len(prompt_tokens)) + + filter_batch = lambda batch: [ + item for index, item in enumerate(batch) if index not in dummy_batch_ids + ] + + if len(dummy_batch_ids) == len(reward_output["tokens"]): + # entire batch does not have a valid preference pair + # we use it as dummy batch and zero the loss in the end + is_bad_batch = True + else: + # removing dummy pairs from the batch + chosen_batch = filter_batch(chosen_batch) + rejected_batch = filter_batch(rejected_batch) + prompt_lens = filter_batch(prompt_lens) + is_bad_batch = False + + prompt_lens = torch.tensor(prompt_lens) + + chosen_batch = [ + torch.tensor(sequence, device=self._gangs.dp.device) + for sequence in chosen_batch + ] + chosen_batch = collate_with_target_mask( + chosen_batch, prompt_lens, device=self._gangs.dp.device + ) + + rejected_batch = [ + torch.tensor(sequence, device=self._gangs.dp.device) + for sequence in rejected_batch + ] + rejected_batch = collate_with_target_mask( + rejected_batch, prompt_lens, device=self._gangs.dp.device + ) + + batch = PreferenceBatch( + chosen=chosen_batch, + rejected=rejected_batch, + reference_score_chosen=None, + reference_score_rejected=None, + ) + + return batch, is_bad_batch, reward_output class AtheneVerifierHandler(VLLMOutputRewardHandler): diff --git a/src/fairseq2/recipes/lm/_online_finetune/third_party/ace_math.py b/src/fairseq2/recipes/lm/_online_finetune/third_party/ace_math.py new file mode 100644 index 000000000..ff48fd0c1 --- /dev/null +++ b/src/fairseq2/recipes/lm/_online_finetune/third_party/ace_math.py @@ -0,0 +1,28 @@ +import torch +from fairseq2.logging import log +from transformers import AutoModelForSequenceClassification, AutoTokenizer + +class AceMathRMPipeline: + def __init__(self, *args, **kwargs): + model_path = "/datasets/pretrained-llms/AceMath-7B-RM" + self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + self.model = AutoModelForSequenceClassification.from_pretrained( + model_path, num_labels=1, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map = "auto" + ).eval() + self.model.config.pad_token_id = self.tokenizer.pad_token_id + + def __call__(self, prompt_chunk): + inputs = self.tokenizer( + prompt_chunk, + return_tensors="pt", + padding=True, + add_special_tokens=False + ).to(self.model.device) + + outputs = self.model(**inputs)[0] + log.info(f"outputs = {outputs}") + rewards =[output[0] for output in outputs] + + log.info(f"Length of rewards = {len(rewards)}") + + return rewards diff --git a/src/fairseq2/setup/_po_finetune_units.py b/src/fairseq2/setup/_po_finetune_units.py index 86535afa3..83890a060 100644 --- a/src/fairseq2/setup/_po_finetune_units.py +++ b/src/fairseq2/setup/_po_finetune_units.py @@ -12,6 +12,7 @@ from fairseq2.recipes.lm import ( # GroupDpoFinetuneUnitHandler, AtheneVerifierHandler, SkyworkVerifierHandler, + AceMathVerifierHandler, CpoFinetuneUnitHandler, DpoFinetuneUnitHandler, GeneralVerifierExtractorHandler, @@ -27,6 +28,7 @@ MathVerifyHandler, NoEnvAtheneRewardPipeline, NoEnvGeneralVerifierPipeline, + NoEnvAceMathRMPipeline, OnlineDpoFinetuneUnitHandler, OnlineFinetuneUnitHandler, OrpoFinetuneUnitHandler, @@ -93,6 +95,10 @@ def _register_online_finetune_units(context: RuntimeContext) -> None: # SkyworkVerifier handler = SkyworkVerifierHandler() registry.register(handler.name, handler) + + # AceMath RM + handler = AceMathVerifierHandler() + registry.register(handler.name, handler) # AtheneVerifier handler = AtheneVerifierHandler() @@ -123,6 +129,10 @@ def _register_online_finetune_units(context: RuntimeContext) -> None: # NoEnvGeneralVerifierPipeline handler = NoEnvGeneralVerifierPipeline registry.register(handler.name, handler) + + # NoEnvAceMathRMPipeline + handler = NoEnvAceMathRMPipeline + registry.register(handler.name, handler) # Generative judgment extractors registry = context.get_registry(JudgmentExtractorHandler) From 1162d6022f4e3fb71b7dfe9021a445333a8286dd Mon Sep 17 00:00:00 2001 From: swarna Date: Thu, 4 Sep 2025 23:42:15 +0000 Subject: [PATCH 24/33] Skywork-RM from hf --- src/fairseq2/recipes/lm/__init__.py | 3 ++ .../lm/_online_finetune/_generative_judge.py | 21 ++++++---- .../recipes/lm/_online_finetune/_handler.py | 9 +++-- .../lm/_online_finetune/_remote_model.py | 40 +++++++++++++++---- .../recipes/lm/_online_finetune/_rewards.py | 34 +++++++++++----- .../_online_finetune/third_party/ace_math.py | 25 +++++++----- .../_online_finetune/third_party/skywork.py | 30 ++++++++++++++ 7 files changed, 123 insertions(+), 39 deletions(-) create mode 100644 src/fairseq2/recipes/lm/_online_finetune/third_party/skywork.py diff --git a/src/fairseq2/recipes/lm/__init__.py b/src/fairseq2/recipes/lm/__init__.py index 0b46460df..2f4347536 100644 --- a/src/fairseq2/recipes/lm/__init__.py +++ b/src/fairseq2/recipes/lm/__init__.py @@ -93,6 +93,9 @@ from fairseq2.recipes.lm._online_finetune._remote_model import ( NoEnvAceMathRMPipeline as NoEnvAceMathRMPipeline, ) +from fairseq2.recipes.lm._online_finetune._remote_model import ( + NoEnvSkyworkRMPipeline as NoEnvSkyworkRMPipeline, +) from fairseq2.recipes.lm._online_finetune._remote_model import ( RemoteModelHandler as RemoteModelHandler, ) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py b/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py index 9d6542283..9e82273d7 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py @@ -213,15 +213,18 @@ class JudgmentExtractorHandler(ABC): @abstractmethod - def create(self, tokenizer): ... + def create(self, tokenizer): + ... @property @abstractmethod - def name(self) -> str: ... + def name(self) -> str: + ... @property @abstractmethod - def config_kls(self) -> type[object]: ... + def config_kls(self) -> type[object]: + ... """ @@ -240,10 +243,12 @@ class JudgmentExtractor(ABC): """ @abstractmethod - def prompt(self) -> str: ... + def prompt(self) -> str: + ... @abstractmethod - def format_prompt(self, prompt_text, **kwargs: Any) -> str: ... + def format_prompt(self, prompt_text, **kwargs: Any) -> str: + ... """ Format the prompt text and additional arguments into a string suitable for input to the reward model. @@ -256,7 +261,8 @@ def format_prompt(self, prompt_text, **kwargs: Any) -> str: ... """ @abstractmethod - def extract(self, generation) -> float | str: ... + def extract(self, generation) -> float | str: + ... """ Extract the final scalar reward score from the model's response. @@ -275,7 +281,8 @@ def extract(self, generation) -> float | str: ... """ @abstractmethod - def aggregate(self, judgments) -> float | str: ... + def aggregate(self, judgments) -> float | str: + ... """ Aggregate multiple responses (judgments) from the reward model into a single value. diff --git a/src/fairseq2/recipes/lm/_online_finetune/_handler.py b/src/fairseq2/recipes/lm/_online_finetune/_handler.py index 943528f51..0badf2b10 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_handler.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_handler.py @@ -19,15 +19,18 @@ class OnlineFinetuneUnitHandler(ABC): @abstractmethod def create( self, model: Model, gangs: Gangs, recipe_config: object, vllm_actors: object - ) -> TrainUnit[SequenceBatch]: ... + ) -> TrainUnit[SequenceBatch]: + ... @property @abstractmethod - def name(self) -> str: ... + def name(self) -> str: + ... @property @abstractmethod - def config_kls(self) -> type[object]: ... + def config_kls(self) -> type[object]: + ... class UnknownOnlineFinetuneUnitError(Exception): diff --git a/src/fairseq2/recipes/lm/_online_finetune/_remote_model.py b/src/fairseq2/recipes/lm/_online_finetune/_remote_model.py index c836e25e4..8e10249d4 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_remote_model.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_remote_model.py @@ -27,13 +27,12 @@ from fairseq2.gang import Gangs from fairseq2.logging import log from fairseq2.nn._batch_layout import BatchLayout +from fairseq2.recipes.lm._online_finetune.third_party.ace_math import AceMathRMPipeline from fairseq2.recipes.lm._online_finetune.third_party.athene import AtheneRewardPipeline from fairseq2.recipes.lm._online_finetune.third_party.general_verifier import ( GeneralVerifierPipeline, ) -from fairseq2.recipes.lm._online_finetune.third_party.ace_math import ( - AceMathRMPipeline, -) +from fairseq2.recipes.lm._online_finetune.third_party.skywork import SkyworkRMPipeline from fairseq2.utils.structured import StructureError, structure @@ -140,7 +139,8 @@ def is_ready(self): @property def name(self): return "general_verifier_pipeline" - + + @ray.remote class NoEnvAceMathRMPipeline(AceMathRMPipeline): """ @@ -162,6 +162,27 @@ def name(self): return "ace_math_rm_pipeline" +@ray.remote +class NoEnvSkyworkRMPipeline(SkyworkRMPipeline): + """ + This is for running Ace Math RM pipeline with HF backend. + """ + + def __init__(self, *args, **kwargs): + # stop ray from manipulating CUDA_VISIBLE_DEVICES + # at the top-level + del os.environ["CUDA_VISIBLE_DEVICES"] + super().__init__(*args, **kwargs) + self.ready = True # Set a flag or return a signal + + def is_ready(self): + return self.ready + + @property + def name(self): + return "skywork_rm_pipeline" + + class WorkerExtension: """ The class for vLLM's worker to inherit from. @@ -462,7 +483,7 @@ def reward_from_model(self, prompt_list, batch_size=64): ray_outputs = ray.get(outputs) ray_outputs_flat = [o for sublist in ray_outputs for o in sublist] rewards = [o.outputs.data.item() for o in ray_outputs_flat] - + log.info(f"Rewards = {rewards}") return rewards @@ -591,15 +612,18 @@ class RemoteModelHandler(ABC): @abstractmethod def create( self, gangs: Gangs, unit_config: object - ) -> Union[RemoteVllmModel, RemoteHFModel]: ... + ) -> Union[RemoteVllmModel, RemoteHFModel]: + ... @property @abstractmethod - def name(self) -> str: ... + def name(self) -> str: + ... @property @abstractmethod - def config_kls(self) -> type[object]: ... + def config_kls(self) -> type[object]: + ... class RemoteRayModelHandler(RemoteModelHandler): diff --git a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py index 7afc2a444..9977a71dd 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py @@ -57,23 +57,28 @@ class VLLMOutputRewardHandler(ABC): @abstractmethod def create( self, reward_model: Any, gangs: Gangs, reward_config: object - ) -> VLLMOutputReward: ... + ) -> VLLMOutputReward: + ... @property @abstractmethod - def name(self) -> str: ... + def name(self) -> str: + ... @property @abstractmethod - def config_kls(self) -> type[object]: ... + def config_kls(self) -> type[object]: + ... class VLLMOutputReward(ABC): @abstractmethod - def process_rollouts(self, vllm_outputs: list[RequestOutput]): ... + def process_rollouts(self, vllm_outputs: list[RequestOutput]): + ... @abstractmethod - def prepare_preference_batch(self, prompt_batch: PromptBatch, rollouts): ... + def prepare_preference_batch(self, prompt_batch: PromptBatch, rollouts): + ... class GSM8kVerifierHandler(VLLMOutputRewardHandler): @@ -461,7 +466,8 @@ def prepare_preference_batch( ) return batch, is_bad_batch, reward_output - + + class AceMathVerifierHandler(VLLMOutputRewardHandler): def __init__(self): pass @@ -492,7 +498,8 @@ def name(self): @override def config_kls(self): return None - + + class AceMathVerifier(VLLMOutputReward): def __init__( self, @@ -514,11 +521,16 @@ def __init__( def wrap_text(self, prompt_text, rollout_text): wrapped_text = [ - {"role": "system", "content": "Please reason step by step, and check your final answer within \\boxed{}."}, + { + "role": "system", + "content": "Please reason step by step, and check your final answer within \\boxed{}.", + }, {"role": "user", "content": prompt_text}, - {"role": "assistant", "content": rollout_text} + {"role": "assistant", "content": rollout_text}, ] - chat_str = self.tokenizer.apply_chat_template(wrapped_text, tokenize=False, add_generation_prompt=False) + chat_str = self.tokenizer.apply_chat_template( + wrapped_text, tokenize=False, add_generation_prompt=False + ) if self.tokenizer.bos_token is not None and chat_str.startswith( self.tokenizer.bos_token ): @@ -557,7 +569,7 @@ def process_rollouts( batch_rewards = generate_rewards( vllm_inputs, dp_gang=self._gangs.dp, vllm_model=self.reward_model ) - + log.info(f"Batch rewards = {batch_rewards}") # reshape batch_rewards to [Batch, Rollouts] diff --git a/src/fairseq2/recipes/lm/_online_finetune/third_party/ace_math.py b/src/fairseq2/recipes/lm/_online_finetune/third_party/ace_math.py index ff48fd0c1..3fc33275d 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/third_party/ace_math.py +++ b/src/fairseq2/recipes/lm/_online_finetune/third_party/ace_math.py @@ -1,28 +1,33 @@ import torch -from fairseq2.logging import log from transformers import AutoModelForSequenceClassification, AutoTokenizer +from fairseq2.logging import log + + class AceMathRMPipeline: def __init__(self, *args, **kwargs): model_path = "/datasets/pretrained-llms/AceMath-7B-RM" - self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + self.tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True + ) self.model = AutoModelForSequenceClassification.from_pretrained( - model_path, num_labels=1, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map = "auto" + model_path, + num_labels=1, + torch_dtype=torch.bfloat16, + trust_remote_code=True, + device_map="auto", ).eval() self.model.config.pad_token_id = self.tokenizer.pad_token_id def __call__(self, prompt_chunk): inputs = self.tokenizer( - prompt_chunk, - return_tensors="pt", - padding=True, - add_special_tokens=False + prompt_chunk, return_tensors="pt", padding=True, add_special_tokens=False ).to(self.model.device) - + outputs = self.model(**inputs)[0] log.info(f"outputs = {outputs}") - rewards =[output[0] for output in outputs] - + rewards = [output[0] for output in outputs] + log.info(f"Length of rewards = {len(rewards)}") return rewards diff --git a/src/fairseq2/recipes/lm/_online_finetune/third_party/skywork.py b/src/fairseq2/recipes/lm/_online_finetune/third_party/skywork.py new file mode 100644 index 000000000..3a60da8e7 --- /dev/null +++ b/src/fairseq2/recipes/lm/_online_finetune/third_party/skywork.py @@ -0,0 +1,30 @@ +import torch +from transformers import AutoModelForSequenceClassification, AutoTokenizer + +from fairseq2.logging import log + + +class SkyworkRMPipeline: + def __init__(self, *args, **kwargs): + model_path = "/datasets/pretrained-llms/Skywork-Reward-V2-Llama-3.1-8B" + self.tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True + ) + self.model = AutoModelForSequenceClassification.from_pretrained( + model_path, + num_labels=1, + torch_dtype=torch.bfloat16, + trust_remote_code=True, + device_map="auto", + ).eval() + self.model.config.pad_token_id = self.tokenizer.pad_token_id + + def __call__(self, prompt_chunk): + inputs = self.tokenizer( + prompt_chunk, return_tensors="pt", padding=True, add_special_tokens=False + ).to(self.model.device) + + outputs = self.model(**inputs)[0] + rewards = [output[0] for output in outputs] + + return rewards From 4ea811d005caf5dbcca70effc9f1ff4799e964f0 Mon Sep 17 00:00:00 2001 From: chenxwh user Date: Sat, 6 Sep 2025 17:39:56 +0000 Subject: [PATCH 25/33] add parsed ref --- .../lm/_online_finetune/_generative_judge.py | 186 ++++++++++-------- 1 file changed, 104 insertions(+), 82 deletions(-) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py b/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py index 9e82273d7..8c4e0ca3a 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py @@ -1,32 +1,11 @@ -# POINTWISE_J1_PROMPT = """ -# You are given a user question and a response from an AI assistant. Your task is to act as an impartial judge and evaluate how well the response fulfills the user's instructions. You will be shown multiple responses to the same prompt, but only one at a time. Evaluate each response independently. - -# Think carefully about how to assess the quality of the response, and enclose your reasoning within and tags. Your reasoning should include your evaluation criteria, a clear understanding of what an ideal response would look like for this particular question, and a concrete example of such an ideal or reference answer if possible. Then compare the assistant's response to your ideal or reference answer, explaining how it aligns with or deviates from your expectations. Be specific and avoid vague or overly general judgments. Remain as objective as possible. - -# Finally, assign the assistant's response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision. A higher score should indicate a higher-quality response. Enclose the score within and tags. - -# Format your output like this: -# your_thinking_process -# your_score - -# Below are the user's question and the assistant's response: - -# [User Question] -# {instruction} - -# [The Start of the Assistant's Answer] -# {response} -# [The End of the Assistant's Answer] -# """ - POINTWISE_J1_PROMPT = """ -You are given a user question and a response from an AI assistant. Your task is to act as an impartial judge and evaluate how well the response fulfills the user's instructions. Do not allow the length of the response to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. +You are given a user question and a response from an AI assistant. Your task is to act as an impartial judge and evaluate how well the response fulfills the user's instructions. You will be shown multiple responses to the same prompt, but only one at a time. Evaluate each response independently. -Think carefully about how to assess the quality of the response and finally assign the assistant's response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision. A higher score should indicate a higher-quality response. Enclose the score within and tags. +Think carefully about how to assess the quality of the response and assign the assistant's response a score 1 if the response is correct, and 0 if not. Enclose the score within and tags. Format your output like this: your_thinking_process - your_score + 0 or 1 Below are the user's question and the assistant's response: @@ -38,14 +17,35 @@ [The End of the Assistant's Answer] """ + +# POINTWISE_J1_PROMPT = """ +# You are given a user question and a response from an AI assistant. Your task is to act as an impartial judge and evaluate how well the response fulfills the user's instructions. You will be shown multiple responses to the same prompt, but only one at a time. Evaluate each response independently. + +# Think carefully about how to assess the quality of the response and assign the assistant's response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision. A higher score should indicate a higher-quality response. Enclose the score within and tags. + +# Format your output like this: +# your_thinking_process +# your_score + +# Below are the user's question and the assistant's response: + +# [User Question] +# {instruction} + +# [The Start of the Assistant's Answer] +# {response} +# [The End of the Assistant's Answer] +# """ + + POINTWISE_J1_PROMPT_WITH_REF_ANSWER = """ You are given a user question, a reference answer and a response from an AI assistant. Your task is to act as an impartial judge and evaluate how well the response fulfills the user's instructions. You will be shown multiple responses to the same prompt, but only one at a time. Evaluate each response independently. -Think carefully about how to assess the quality of the response and finally assign the assistant's response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision. A higher score should indicate a higher-quality response. Enclose the score within and tags. +Think carefully about how to assess the quality of the response and assign the assistant's response a score 1 if the response is correct, and 0 if not. Enclose the score within and tags. Format your output like this: your_thinking_process - your_score + 0 or 1 Below are the user's question, reference answer and the assistant's response: @@ -60,14 +60,11 @@ [The End of the Assistant's Answer] """ -# PAIRWISE_WITH_SCORES_J1_PROMPT = """ -# You are given a user question and two responses from two AI assistants. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. - -# First, provide your reasoning within and tags. This should include your evaluation criteria for a high-quality response, a detailed comparison of the two responses, and when helpful, a reference answer as part of your evaluation. Be explicit in your thought process, referencing your criteria and explaining how each response aligns with or deviates from them. -# Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. +# PAIRWISE_WITH_SCORES_J1_PROMPT = """ +# You are given a user question and two responses from two AI assistants. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. -# Finally, assign the assistant's response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision, with a higher score indicating a higher-quality response that better satisfies the criteria. Enclose the scores within the tags , and . +# Think carefully about how to assess the quality of the responses and assign each response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision, with a higher score indicating a higher-quality response that better satisfies the criteria. Enclose the scores within the tags , and . # Format your output like this: # your_thinking_process @@ -87,14 +84,15 @@ # [The End of Assistant B's Answer] # """ + PAIRWISE_WITH_SCORES_J1_PROMPT = """ You are given a user question and two responses from two AI assistants. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. -Think carefully about how to assess the quality of the responses and finally, assign each response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision, with a higher score indicating a higher-quality response that better satisfies the criteria. Enclose the scores within the tags , and . +Think carefully about how to assess the quality of the responses and assign each response a score 1 if the response is correct, and 0 if not. Enclose the scores within the tags , and . Format your output like this: your_thinking_process - your_score_a your_score_b + 0 or 1 0 or 1 Below are the user's question and the two responses: @@ -110,63 +108,50 @@ [The End of Assistant B's Answer] """ -KWISE_WITH_SCORES_J1_PROMPT = """ -You are given a user question and {k} responses from {k} AI assistants. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. - -Think carefully about how to assess the quality of the responses and finally, assign each response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision, with a higher score indicating a higher-quality response that better satisfies the criteria. Enclose the scores within the tags , and so on. - -Format your output like this: - your_thinking_process - your_score_1 - your_score_2 - your_score_3 -... - -Below are the user's question and the two responses: - -[User Question] -{instruction} - -{responses} -""" +# PAIRWISE_WITH_SCORES_J1_PROMPT_WITH_REF_ANSWER = """ +# You are given a user question, two responses from two AI assistants, and a reference answer. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. -KWISE_WITH_SCORES_J1_PROMPT_WITH_REF_ANSWER = """ -You are given a user question and {k} responses from {k} AI assistants. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. +# Think carefully about how to assess the quality of the responses and utilize the reference answer for your judgement. Finally, assign each response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision, with a higher score indicating a higher-quality response that better satisfies the criteria. Enclose the scores within the tags , and . -Think carefully about how to assess the quality of the responses and finally, assign each response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision, with a higher score indicating a higher-quality response that better satisfies the criteria. Enclose the scores within the tags , and so on. +# Format your output like this: +# your_thinking_process +# your_score_a your_score_b -Format your output like this: - your_thinking_process - your_score_1 - your_score_2 - your_score_3 -... +# Below are the user's question, reference answer and the two responses: -Below are the user's question and the two responses: +# [User Question] +# {instruction} -[User Question] -{instruction} +# [Reference Answer] +# {reference_answer} -[Reference Answer] -{reference_answer} +# [The Start of Assistant A's Answer] +# {response_A} +# [The End of Assistant A's Answer] -{responses} -""" +# [The Start of Assistant B's Answer] +# {response_B} +# [The End of Assistant B's Answer] +# """ -# PAIRWISE_WITH_SCORES_J1_PROMPT = """ -# You are given a user question and two responses from two AI assistants. You are also given their thinking process. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Care any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. +# PAIRWISE_WITH_SCORES_J1_PROMPT_WITH_REF_ANSWER = """ +# You are given a user question, two responses from two AI assistants, and a reference answer. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. -# Carefully analyze the assistants' thought process, assess the quality of the responses and finally, assign each response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision, with a higher score indicating a higher-quality response that better satisfies the criteria. Enclose the scores within the tags , and . +# Think carefully about how to assess the quality of the responses and utilize the reference answer for your judgement. Finally, assign each response a score 1 if the response is correct, and 0 if not. Enclose the scores within the tags , and . # Format your output like this: # your_thinking_process -# your_score_a your_score_b +# 0 or 1 0 or 1 -# Below are the user's question and the two responses: + +# Below are the user's question, reference answer and the two responses: # [User Question] # {instruction} +# [Reference Answer] +# {reference_answer} + # [The Start of Assistant A's Answer] # {response_A} # [The End of Assistant A's Answer] @@ -176,23 +161,22 @@ # [The End of Assistant B's Answer] # """ + PAIRWISE_WITH_SCORES_J1_PROMPT_WITH_REF_ANSWER = """ -You are given a user question and two responses from two AI assistants. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. +You are given a user question, two responses from two AI assistants and the parsed version of the responses, and a reference answer. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. -Think carefully about how to assess the quality of the responses and finally, assign each response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision, with a higher score indicating a higher-quality response that better satisfies the criteria. Enclose the scores within the tags , and . +Think carefully about how to assess the quality of the responses and finally, utilize the reference answer for your judgement. Note that the parsed version of the responses are automatically extracted and may contain errors, therefore you should primarily rely on the original responses for your judgement. +Finally, assign each response a score 1 if the response is correct, and 0 if not. Enclose the scores within the tags , and . Format your output like this: your_thinking_process - your_score_a your_score_b + 0 or 1 0 or 1 -Below are the user's question, reference answer and the two responses: +Below are the user's question, two responses and the parsed versions of the responses, and the reference answer: [User Question] {instruction} -[Reference Answer] -{reference_answer} - [The Start of Assistant A's Answer] {response_A} [The End of Assistant A's Answer] @@ -200,6 +184,15 @@ [The Start of Assistant B's Answer] {response_B} [The End of Assistant B's Answer] + +[The Parsed Version of Assistant A's Answer] +{parsed_response_A} + +[The Parsed Version of Assistant B's Answer] +{parsed_response_B} + +[Reference Answer] +{reference_answer} """ import re @@ -475,6 +468,22 @@ def config_kls(self): class J1PairwiseScoreExtractor(JudgmentExtractor): def __init__(self, tokenizer): self.tokenizer = tokenizer + try: + from math_verify import parse + from math_verify.parser import ( + ExprExtractionConfig, + LatexExtractionConfig, + NormalizationConfig, + ) + except ImportError: + raise ImportError( + "install mathverify from https://github.com/huggingface/Math-Verify" + ) + + self.student_extraction_config = ( + LatexExtractionConfig(boxed_match_priority=0), + ) + self.parse = parse @override def prompt(self, reference_answer): @@ -483,6 +492,17 @@ def prompt(self, reference_answer): if reference_answer is None else PAIRWISE_WITH_SCORES_J1_PROMPT_WITH_REF_ANSWER ) + + def get_preferred_index(self, lst): + """ + math_verify parse returns a list of parsed answers, we want want the item at idex 1, which is a string + """ + if len(lst) > 1: + return lst[1] + elif len(lst) == 1: + return lst[0] + else: + return "None" @override def format_prompt( @@ -498,9 +518,11 @@ def format_prompt( if reference_answer is None else prompt_template.format( instruction=prompt_text, - reference_answer=reference_answer, response_A=rollout_A_text, response_B=rollout_B_text, + parsed_response_A=self.get_preferred_index(self.parse(rollout_A_text, self.student_extraction_config)), + parsed_response_B=self.get_preferred_index(self.parse(rollout_B_text, self.student_extraction_config)), + reference_answer=reference_answer, ) ) From 6703f1b978871c79ed25619eb6f38ee33b726f94 Mon Sep 17 00:00:00 2001 From: chenxwh user Date: Sun, 7 Sep 2025 20:36:21 +0000 Subject: [PATCH 26/33] update prompt template --- .../lm/_online_finetune/_generative_judge.py | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py b/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py index 8c4e0ca3a..c27a62a2f 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py @@ -195,6 +195,50 @@ {reference_answer} """ +KWISE_WITH_SCORES_J1_PROMPT = """ +You are given a user question and {k} responses from {k} AI assistants. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. + +Think carefully about how to assess the quality of the responses and finally, assign each response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision, with a higher score indicating a higher-quality response that better satisfies the criteria. Enclose the scores within the tags , and so on. + +Format your output like this: + your_thinking_process + your_score_1 + your_score_2 + your_score_3 +... + +Below are the user's question and the responses: + +[User Question] +{instruction} + +{responses} +""" + +KWISE_WITH_SCORES_J1_PROMPT_WITH_REF_ANSWER = """ +You are given a user question and {k} responses from {k} AI assistants. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. + +Think carefully about how to assess the quality of the responses and finally, utilize the reference answer for your judgement. +Finally, assign each response a score 1 if the response is correct, and 0 if not. Enclose the scores within the tags , and so on. + +Format your output like this: + your_thinking_process + 0 or 1 + 0 or 1 + 0 or 1 +... + +Below are the user's question, the reference answer, and the responses: + +[User Question] +{instruction} + +[Reference Answer] +{reference_answer} + +{responses} +""" + import re from abc import ABC, abstractmethod from typing import Any From fe84d9eabbaf960743eac4f59ab38548103fd5b2 Mon Sep 17 00:00:00 2001 From: chenxwh user Date: Wed, 10 Sep 2025 21:15:51 +0000 Subject: [PATCH 27/33] all comparisons in k-wise --- .../lm/_online_finetune/_generative_judge.py | 38 +++++++++++-- .../recipes/lm/_online_finetune/_rewards.py | 57 ++++++++++++++----- 2 files changed, 78 insertions(+), 17 deletions(-) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py b/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py index c27a62a2f..af950264c 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py @@ -216,9 +216,9 @@ """ KWISE_WITH_SCORES_J1_PROMPT_WITH_REF_ANSWER = """ -You are given a user question and {k} responses from {k} AI assistants. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. +You are given a user question, a reference answer, and {k} responses with the parsed versions from AI assistants. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. -Think carefully about how to assess the quality of the responses and finally, utilize the reference answer for your judgement. +Think carefully about how to assess the quality of the responses and finally, utilize the reference answer for your judgement. Note that the parsed version of the responses are automatically extracted and may contain errors, therefore you should primarily rely on the original responses for your judgement. Finally, assign each response a score 1 if the response is correct, and 0 if not. Enclose the scores within the tags , and so on. Format your output like this: @@ -228,7 +228,7 @@ 0 or 1 ... -Below are the user's question, the reference answer, and the responses: +Below are the user's question, reference answer, responses and the parsed versions of the responses: [User Question] {instruction} @@ -237,6 +237,8 @@ {reference_answer} {responses} + +{parsed_responses} """ import re @@ -629,6 +631,33 @@ class J1KwiseScoreExtractor(JudgmentExtractor): def __init__(self, tokenizer, k): self.tokenizer = tokenizer self.k = k + try: + from math_verify import parse + from math_verify.parser import ( + ExprExtractionConfig, + LatexExtractionConfig, + NormalizationConfig, + ) + except ImportError: + raise ImportError( + "install mathverify from https://github.com/huggingface/Math-Verify" + ) + + self.student_extraction_config = ( + LatexExtractionConfig(boxed_match_priority=0), + ) + self.parse = parse + + def get_preferred_index(self, lst): + """ + math_verify parse returns a list of parsed answers, we want want the item at idex 1, which is a string + """ + if len(lst) > 1: + return lst[1] + elif len(lst) == 1: + return lst[0] + else: + return "None" @override def prompt(self, reference_answer): @@ -649,7 +678,8 @@ def format_prompt(self, prompt_text, rollouts, reference_answer): else prompt_template.format( k=self.k, instruction=prompt_text, - responses=rollouts, + responses="".join([f"[Start of Assistant {assistant_id+1}'s Answer]\n{rollout}\n[End of Assistant {assistant_id+1}'s Answer]\n\n" for assistant_id, rollout in enumerate(rollouts)]), + parsed_responses="".join([f"[The Parsed Version of Assistant {assistant_id+1}'s Answer]\n{self.get_preferred_index(self.parse(rollout, self.student_extraction_config))}\n\n" for assistant_id, rollout in enumerate(rollouts)]), reference_answer=reference_answer, ) ) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py index 9977a71dd..d7566ffbf 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py @@ -1511,6 +1511,34 @@ def __init__( self.tokenizer, self.k ) + # def construct_all_k_tuples( + # self, + # prompt_text, + # i_batch_request_output, + # vllm_inputs, + # batch_kwise_indices, + # reference_answer, + # R, + # k, + # ): + # all_k_tuples = list(itertools.combinations(list(range(R)), k)) + # for k_tuple in all_k_tuples: + # k_list = list(k_tuple) + # random.shuffle(k_list) + # batch_kwise_indices.append(k_list) + # response_string = "" + # for assistant_id, idx in enumerate(k_list): + # rollout = i_batch_request_output.outputs[idx].text + # response_string += f"[Start of Assistant {assistant_id+1} Answer]\n{rollout}\n[End of Assistant {assistant_id+1} Answer]\n\n" + # response_string = response_string.strip() + + # vllm_input = self.judgment_extractor.format_prompt( + # prompt_text, response_string, reference_answer + # ) + # vllm_inputs.append(vllm_input) + + # return vllm_inputs, batch_kwise_indices + def construct_all_k_tuples( self, prompt_text, @@ -1523,22 +1551,24 @@ def construct_all_k_tuples( ): all_k_tuples = list(itertools.combinations(list(range(R)), k)) for k_tuple in all_k_tuples: - k_list = list(k_tuple) - random.shuffle(k_list) - batch_kwise_indices.append(k_list) - response_string = "" - for assistant_id, idx in enumerate(k_list): - rollout = i_batch_request_output.outputs[idx].text - response_string += f"[Start of Assistant {assistant_id+1} Answer]\n{rollout}\n[End of Assistant {assistant_id+1} Answer]\n\n" - response_string = response_string.strip() + for k_list in itertools.permutations(k_tuple): + k_list = list(k_list) + batch_kwise_indices.append(k_list) + # response_string = "" + # for assistant_id, idx in enumerate(k_list): + # rollout = i_batch_request_output.outputs[idx].text + # response_string += f"[Start of Assistant {assistant_id+1} Answer]\n{rollout}\n[End of Assistant {assistant_id+1} Answer]\n\n" + # response_string = response_string.strip() - vllm_input = self.judgment_extractor.format_prompt( - prompt_text, response_string, reference_answer - ) - vllm_inputs.append(vllm_input) + response_list = [i_batch_request_output.outputs[idx].text for idx in k_list] + vllm_input = self.judgment_extractor.format_prompt( + prompt_text, response_list, reference_answer + ) + vllm_inputs.append(vllm_input) return vllm_inputs, batch_kwise_indices + def convert_kwise_rewards_to_pointwise( self, batch_kwise_rewards, @@ -1552,7 +1582,8 @@ def convert_kwise_rewards_to_pointwise( for prompt_idx in range(B): # Extract the kwise rewards for each input - num = math.comb(R, k) + # num = math.comb(R, k) + num = math.perm(R, k) idx_start, idx_end = ( prompt_idx * num, (prompt_idx + 1) * num, From e7137ac17031a11dc4e55a339c032c099bfc5593 Mon Sep 17 00:00:00 2001 From: Jack Lanchantin Date: Tue, 30 Sep 2025 17:27:10 -0400 Subject: [PATCH 28/33] Jacklanchantin/qwen (#1260) --- .../lm/_online_finetune/_online_dpo.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_online_dpo.py b/src/fairseq2/recipes/lm/_online_finetune/_online_dpo.py index 0ef85f963..3547be2df 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_online_dpo.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_online_dpo.py @@ -54,6 +54,11 @@ update_avg_rollout_length, update_batch_metrics, update_dpo_loss, + update_grpo_batch_metrics, + compute_reference_logps, + collate_with_target_mask, + update_avg_loss_zeroer, + strip_think_tokens, update_logit_entropy, ) from fairseq2.recipes.lm._online_finetune._handler import OnlineFinetuneUnitHandler @@ -154,6 +159,8 @@ def validate_reward( ) if self._config.loss_config.log_rollouts: log_rollouts(prompt_batch, rollouts, "Valid") + + rollouts = strip_think_tokens(rollouts) reward_output = self._reward.process_rollouts(rollouts, prompt_batch) avg_reward = torch.tensor(reward_output["rewards"]).float().mean() @@ -200,6 +207,8 @@ def __call__( if self._config.loss_config.log_rollouts: log_rollouts(prompt_batch, rollouts, "Train") + rollouts = strip_think_tokens(rollouts) + batch: PreferenceBatch batch, is_bad_batch, reward_output = self._reward.prepare_preference_batch( prompt_batch, rollouts @@ -455,6 +464,24 @@ def create( context=self._context, ) + + # TODO: decide converter as part of the model handler + if "llama" in model.name: + from fairseq2.models.llama._hg import _convert_parameter + + model._convert_parameter = _convert_parameter + else: + from fairseq2.models.qwen._hg import _convert_parameter + + model._convert_parameter = _convert_parameter + + # sync models here before we start training + if config.vllm_sync.sync_model_every_n_steps > 0: + maybe_sync_model(gangs, model, vllm_model, -1, -1, force_sync=True) + if config.vllm_sync.sync_ref_model_every_n_steps > 0: + maybe_sync_model(gangs, model, reference_model, -1, -1, force_sync=True) + + return OnlineDpoFinetuneUnit( model, reference_model, vllm_model, vllm_actors, reward, gangs, config ) From dfb958a65e3792fafc2229b22e7f06851ff29ec9 Mon Sep 17 00:00:00 2001 From: swarna Date: Tue, 30 Sep 2025 22:23:27 +0000 Subject: [PATCH 29/33] octothinker assets --- src/fairseq2/assets/cards/models/llama.yaml | 32 ++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/src/fairseq2/assets/cards/models/llama.yaml b/src/fairseq2/assets/cards/models/llama.yaml index 7570312a4..99dccb6b0 100644 --- a/src/fairseq2/assets/cards/models/llama.yaml +++ b/src/fairseq2/assets/cards/models/llama.yaml @@ -168,4 +168,34 @@ model_arch: llama3_1_8b checkpoint: "hg://deepseek-ai/DeepSeek-R1-Distill-Llama-8B" tokenizer: "hg://deepseek-ai/DeepSeek-R1-Distill-Llama-8B" tokenizer_family: llama -use_v2_tokenizer: true \ No newline at end of file +use_v2_tokenizer: true + +--- + +name: octothinker_8b_hybrid +model_family: llama +model_arch: llama3_1_8b +checkpoint: /datasets/pretrained-llms/OctoThinker-8B-Hybrid-Base/ +tokenizer: /datasets/pretrained-llms/OctoThinker-8B-Hybrid-Base/ +tokenizer_family: llama +use_v2_tokenizer: true + +--- + +name: octothinker_8b_long +model_family: llama +model_arch: llama3_1_8b +checkpoint: /datasets/pretrained-llms/OctoThinker-8B-Long-Base/ +tokenizer: /datasets/pretrained-llms/OctoThinker-8B-Long-Base/ +tokenizer_family: llama +use_v2_tokenizer: true + +--- + +name: octothinker_8b_short +model_family: llama +model_arch: llama3_1_8b +checkpoint: /datasets/pretrained-llms/OctoThinker-8B-Short-Base/ +tokenizer: /datasets/pretrained-llms/OctoThinker-8B-Short-Base/ +tokenizer_family: llama +use_v2_tokenizer: true From a746129d0052f5ac8d97568ae7b8c1e828d9f47e Mon Sep 17 00:00:00 2001 From: swarna Date: Tue, 30 Sep 2025 22:26:48 +0000 Subject: [PATCH 30/33] Changes --- .../recipes/lm/_online_finetune/_common.py | 14 ++++- .../lm/_online_finetune/_generative_judge.py | 52 +++++++++++++--- .../recipes/lm/_online_finetune/_grpo.py | 7 +++ .../lm/_online_finetune/_remote_model.py | 2 +- .../recipes/lm/_online_finetune/_rewards.py | 62 ++++++++++++++----- src/fairseq2/setup/_po_finetune_units.py | 5 ++ 6 files changed, 117 insertions(+), 25 deletions(-) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_common.py b/src/fairseq2/recipes/lm/_online_finetune/_common.py index 147f16876..a054b2e2e 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_common.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_common.py @@ -425,9 +425,9 @@ def strip_think_tokens(rollouts: List[SequenceData]): if "" in rollout_text: think_present += 1 if rollout.finish_reason == "length": - count_stripped += 1 - if rollout.finish_reason == "stop": count_not_stripped += 1 + if rollout.finish_reason == "stop": + count_stripped += 1 total_count += 1 rollout.text = re.sub( r".*?", "", rollout_text, flags=re.DOTALL @@ -440,6 +440,12 @@ def strip_think_tokens(rollouts: List[SequenceData]): return rollouts +def get_failed_to_parse_answers(prompt_batch: PromptBatch): + if "answers" in prompt_batch: + failed_to_parse = sum(answer is None for rollouts in prompt_batch for answer in rollouts) + return failed_to_parse/prompt_batch.batch_size + else: + return 0.0 def format_think_tags(rollouts: List[SequenceData]): for sample in rollouts: @@ -545,6 +551,10 @@ def update_avg_reward(metric_bag: MetricBag, avg_reward): @torch.inference_mode() def update_std_reward(metric_bag: MetricBag, std_reward): metric_bag.get(Mean, "std_reward").update(std_reward, weight=1) + +@torch.inference_mode() +def update_failed_to_parse_answers(metric_bag: MetricBag, failed_to_parse_answers): + metric_bag.get(Mean, "failed_to_parse_answers").update(failed_to_parse_answers, weight=1) @torch.inference_mode() diff --git a/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py b/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py index 9e82273d7..a862b1c75 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py @@ -177,22 +177,20 @@ # """ PAIRWISE_WITH_SCORES_J1_PROMPT_WITH_REF_ANSWER = """ -You are given a user question and two responses from two AI assistants. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. +You are given a user question, two responses from two AI assistants and the parsed version of the responses, and a reference answer. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. -Think carefully about how to assess the quality of the responses and finally, assign each response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision, with a higher score indicating a higher-quality response that better satisfies the criteria. Enclose the scores within the tags , and . +Think carefully about how to assess the quality of the responses and finally, utilize the reference answer for your judgement. Note that the parsed version of the responses are automatically extracted and may contain errors, therefore you should primarily rely on the original responses for your judgement. +Finally, assign each response a score 1 if the response is correct, and 0 if not. Enclose the scores within the tags , and . Format your output like this: your_thinking_process - your_score_a your_score_b + 0 or 1 0 or 1 -Below are the user's question, reference answer and the two responses: +Below are the user's question, two responses and the parsed versions of the responses, and the reference answer: [User Question] {instruction} -[Reference Answer] -{reference_answer} - [The Start of Assistant A's Answer] {response_A} [The End of Assistant A's Answer] @@ -200,6 +198,15 @@ [The Start of Assistant B's Answer] {response_B} [The End of Assistant B's Answer] + +[The Parsed Version of Assistant A's Answer] +{parsed_response_A} + +[The Parsed Version of Assistant B's Answer] +{parsed_response_B} + +[Reference Answer] +{reference_answer} """ import re @@ -475,6 +482,22 @@ def config_kls(self): class J1PairwiseScoreExtractor(JudgmentExtractor): def __init__(self, tokenizer): self.tokenizer = tokenizer + try: + from math_verify import parse + from math_verify.parser import ( + ExprExtractionConfig, + LatexExtractionConfig, + NormalizationConfig, + ) + except ImportError: + raise ImportError( + "install mathverify from https://github.com/huggingface/Math-Verify" + ) + + self.student_extraction_config = ( + LatexExtractionConfig(boxed_match_priority=0), + ) + self.parse = parse @override def prompt(self, reference_answer): @@ -483,6 +506,17 @@ def prompt(self, reference_answer): if reference_answer is None else PAIRWISE_WITH_SCORES_J1_PROMPT_WITH_REF_ANSWER ) + + def get_preferred_index(self, lst): + """ + math_verify parse returns a list of parsed answers, we want want the item at idex 1, which is a string + """ + if len(lst) > 1: + return lst[1] + elif len(lst) == 1: + return lst[0] + else: + return "None" @override def format_prompt( @@ -498,9 +532,11 @@ def format_prompt( if reference_answer is None else prompt_template.format( instruction=prompt_text, - reference_answer=reference_answer, response_A=rollout_A_text, response_B=rollout_B_text, + parsed_response_A=self.get_preferred_index(self.parse(rollout_A_text, self.student_extraction_config)), + parsed_response_B=self.get_preferred_index(self.parse(rollout_B_text, self.student_extraction_config)), + reference_answer=reference_answer, ) ) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_grpo.py b/src/fairseq2/recipes/lm/_online_finetune/_grpo.py index 40d01bf1d..e79b6d508 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_grpo.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_grpo.py @@ -38,6 +38,7 @@ generate_rollouts, get_rollout_lengths, log_rollouts, + get_failed_to_parse_answers, strip_think_tokens, update_avg_reward, update_avg_reward_len_norm, @@ -47,6 +48,7 @@ update_grpo_loss, update_logit_entropy, update_std_reward, + update_failed_to_parse_answers ) from fairseq2.recipes.lm._online_finetune._handler import OnlineFinetuneUnitHandler from fairseq2.recipes.lm._online_finetune._remote_model import ( @@ -224,6 +226,7 @@ def validate_reward( log.info(f"Rewards: {reward_output['rewards']}") avg_reward = torch.tensor(reward_output["rewards"]).float().mean() std_reward = torch.tensor(reward_output["rewards"]).float().std() + failed_to_parse_answers = get_failed_to_parse_answers(prompt_batch) rollout_lengths = get_rollout_lengths(rollouts) avg_rollout_length = torch.tensor(rollout_lengths).float().mean() @@ -235,6 +238,7 @@ def validate_reward( update_avg_reward(metric_bag, avg_reward) update_batch_metrics(metric_bag, prompt_batch, train=False) update_std_reward(metric_bag, std_reward) + update_failed_to_parse_answers(metric_bag, failed_to_parse_answers) # returning dummy loss since trainer expects it return torch.tensor(0.0, device=self._gangs.dp.device), prompt_batch.batch_size @@ -367,6 +371,9 @@ def __call__( update_std_reward(metric_bag, std_reward) update_avg_reward(metric_bag, avg_reward) + + failed_to_parse_answers = get_failed_to_parse_answers(prompt_batch) + update_failed_to_parse_answers(metric_bag, failed_to_parse_answers) loss = grpo_loss diff --git a/src/fairseq2/recipes/lm/_online_finetune/_remote_model.py b/src/fairseq2/recipes/lm/_online_finetune/_remote_model.py index 8e10249d4..a4060670f 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_remote_model.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_remote_model.py @@ -585,7 +585,7 @@ def rollout_from_model(self, prompt_list, sampling_params=None, string_input=Fal "RemoteHFModel.rollout_from_model is not implemented. " ) - def reward_from_model(self, prompt_list, batch_size=4): + def reward_from_model(self, prompt_list, batch_size=2): # NOTE: need to batch inputs to hf.encode model for current models that aren't supported by hf rewards = [] outputs = [] diff --git a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py index 9977a71dd..ef023cb21 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py @@ -254,6 +254,7 @@ def process_rollouts( batch_text = [] batch_tokens = [] batch_rewards = [] + batch_answers = [] reference_answers = prompt_batch.meta_info.get(self.answer_key) @@ -262,6 +263,7 @@ def process_rollouts( rollouts_tokens = [] i_reference_answer = reference_answers[i] rollouts_rewards = [] + rollouts_answers = [] for rollout_output in i_batch_request_output.outputs: rollouts_text.append(rollout_output.text) rollouts_tokens.append(rollout_output.token_ids) @@ -269,11 +271,13 @@ def process_rollouts( rollout_output.text, i_reference_answer ) rollouts_rewards.append(predicted_reward) + rollouts_answers.append(predicted_answer) batch_text.append(rollouts_text) batch_tokens.append(rollouts_tokens) batch_rewards.append(rollouts_rewards) + batch_answers.append(rollouts_answers) - return {"text": batch_text, "tokens": batch_tokens, "rewards": batch_rewards} + return {"text": batch_text, "tokens": batch_tokens, "rewards": batch_rewards, "answers": batch_answers} def prepare_preference_batch( self, prompt_batch: PromptBatch, rollouts @@ -1511,6 +1515,34 @@ def __init__( self.tokenizer, self.k ) + # def construct_all_k_tuples( + # self, + # prompt_text, + # i_batch_request_output, + # vllm_inputs, + # batch_kwise_indices, + # reference_answer, + # R, + # k, + # ): + # all_k_tuples = list(itertools.combinations(list(range(R)), k)) + # for k_tuple in all_k_tuples: + # k_list = list(k_tuple) + # random.shuffle(k_list) + # batch_kwise_indices.append(k_list) + # response_string = "" + # for assistant_id, idx in enumerate(k_list): + # rollout = i_batch_request_output.outputs[idx].text + # response_string += f"[Start of Assistant {assistant_id+1} Answer]\n{rollout}\n[End of Assistant {assistant_id+1} Answer]\n\n" + # response_string = response_string.strip() + + # vllm_input = self.judgment_extractor.format_prompt( + # prompt_text, response_string, reference_answer + # ) + # vllm_inputs.append(vllm_input) + + # return vllm_inputs, batch_kwise_indices + def construct_all_k_tuples( self, prompt_text, @@ -1523,20 +1555,21 @@ def construct_all_k_tuples( ): all_k_tuples = list(itertools.combinations(list(range(R)), k)) for k_tuple in all_k_tuples: - k_list = list(k_tuple) - random.shuffle(k_list) - batch_kwise_indices.append(k_list) - response_string = "" - for assistant_id, idx in enumerate(k_list): - rollout = i_batch_request_output.outputs[idx].text - response_string += f"[Start of Assistant {assistant_id+1} Answer]\n{rollout}\n[End of Assistant {assistant_id+1} Answer]\n\n" - response_string = response_string.strip() + for k_list in itertools.permutations(k_tuple): + k_list = list(k_list) + batch_kwise_indices.append(k_list) + response_string = "" + for assistant_id, idx in enumerate(k_list): + rollout = i_batch_request_output.outputs[idx].text + response_string += f"[Start of Assistant {assistant_id+1} Answer]\n{rollout}\n[End of Assistant {assistant_id+1} Answer]\n\n" + response_string = response_string.strip() - vllm_input = self.judgment_extractor.format_prompt( - prompt_text, response_string, reference_answer - ) - vllm_inputs.append(vllm_input) + # response_list = [i_batch_request_output.outputs[idx].text for idx in k_list] + vllm_input = self.judgment_extractor.format_prompt( + prompt_text, response_string, reference_answer + ) + vllm_inputs.append(vllm_input) return vllm_inputs, batch_kwise_indices def convert_kwise_rewards_to_pointwise( @@ -1552,7 +1585,8 @@ def convert_kwise_rewards_to_pointwise( for prompt_idx in range(B): # Extract the kwise rewards for each input - num = math.comb(R, k) + # num = math.comb(R, k) + num = math.perm(R, k) idx_start, idx_end = ( prompt_idx * num, (prompt_idx + 1) * num, diff --git a/src/fairseq2/setup/_po_finetune_units.py b/src/fairseq2/setup/_po_finetune_units.py index 83890a060..b52b880b3 100644 --- a/src/fairseq2/setup/_po_finetune_units.py +++ b/src/fairseq2/setup/_po_finetune_units.py @@ -29,6 +29,7 @@ NoEnvAtheneRewardPipeline, NoEnvGeneralVerifierPipeline, NoEnvAceMathRMPipeline, + NoEnvSkyworkRMPipeline, OnlineDpoFinetuneUnitHandler, OnlineFinetuneUnitHandler, OrpoFinetuneUnitHandler, @@ -133,6 +134,10 @@ def _register_online_finetune_units(context: RuntimeContext) -> None: # NoEnvAceMathRMPipeline handler = NoEnvAceMathRMPipeline registry.register(handler.name, handler) + + # NoEnvAceMathRMPipeline + handler = NoEnvSkyworkRMPipeline + registry.register(handler.name, handler) # Generative judgment extractors registry = context.get_registry(JudgmentExtractorHandler) From 9355ce6906db0646f7a171eab4df8ef3d4ad6cec Mon Sep 17 00:00:00 2001 From: swarnadeep user Date: Thu, 16 Oct 2025 15:41:29 +0000 Subject: [PATCH 31/33] Minor changes --- .../recipes/lm/_online_finetune/_common.py | 19 ++++++++++--------- .../recipes/lm/_online_finetune/_grpo.py | 13 +++++++------ 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_common.py b/src/fairseq2/recipes/lm/_online_finetune/_common.py index a054b2e2e..20e4448f0 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_common.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_common.py @@ -440,20 +440,21 @@ def strip_think_tokens(rollouts: List[SequenceData]): return rollouts -def get_failed_to_parse_answers(prompt_batch: PromptBatch): - if "answers" in prompt_batch: - failed_to_parse = sum(answer is None for rollouts in prompt_batch for answer in rollouts) - return failed_to_parse/prompt_batch.batch_size +def get_failed_to_parse_answers(reward_output: dict, batch_size: int): + if "answers" in reward_output: + log.info(f"Answers: {reward_output['answers']}") + failed_to_parse = sum(answer is None for rollouts in reward_output["answers"] for answer in rollouts) + return failed_to_parse/batch_size else: return 0.0 - -def format_think_tags(rollouts: List[SequenceData]): + +def strip_for_octothinker(rollouts: List[SequenceData]): for sample in rollouts: for rollout in sample.outputs: rollout_text = rollout.text - rollout.text = rollout_text.replace( - "", "[Start of Assistant Thinking]" - ).replace("", "[End of Assistant Thinking]") + if "\nUser:" in rollout_text: + rollout_text = rollout_text[:rollout_text.find("\nUser:")] + rollout.text = rollout_text return rollouts diff --git a/src/fairseq2/recipes/lm/_online_finetune/_grpo.py b/src/fairseq2/recipes/lm/_online_finetune/_grpo.py index e79b6d508..87a5dd3fb 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_grpo.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_grpo.py @@ -34,7 +34,7 @@ collate_with_target_mask, compute_reference_logps, compute_token_level_entropy, - format_think_tags, + strip_for_octothinker, generate_rollouts, get_rollout_lengths, log_rollouts, @@ -215,7 +215,7 @@ def validate_reward( if self._config.reward.config.strip_thinking: rollouts = strip_think_tokens(rollouts) else: - rollouts = format_think_tags(rollouts) + rollouts = strip_for_octothinker(rollouts) log.info("After stripping") if self._config.loss_config.log_rollouts: @@ -226,7 +226,7 @@ def validate_reward( log.info(f"Rewards: {reward_output['rewards']}") avg_reward = torch.tensor(reward_output["rewards"]).float().mean() std_reward = torch.tensor(reward_output["rewards"]).float().std() - failed_to_parse_answers = get_failed_to_parse_answers(prompt_batch) + failed_to_parse_answers = get_failed_to_parse_answers(reward_output, prompt_batch.batch_size) rollout_lengths = get_rollout_lengths(rollouts) avg_rollout_length = torch.tensor(rollout_lengths).float().mean() @@ -284,9 +284,10 @@ def __call__( if self._config.reward.config.strip_thinking: rollouts = strip_think_tokens(rollouts) else: - rollouts = format_think_tags(rollouts) + rollouts = strip_for_octothinker(rollouts) log.info("After stripping") - log_rollouts(prompt_batch, rollouts, "Train") + if self._config.loss_config.log_rollouts: + log_rollouts(prompt_batch, rollouts, "Train") reward_output = self._reward.process_rollouts(rollouts, prompt_batch) self._rollout_bag.save(rollouts, reward_output) @@ -372,7 +373,7 @@ def __call__( update_std_reward(metric_bag, std_reward) update_avg_reward(metric_bag, avg_reward) - failed_to_parse_answers = get_failed_to_parse_answers(prompt_batch) + failed_to_parse_answers = get_failed_to_parse_answers(reward_output, prompt_batch.batch_size) update_failed_to_parse_answers(metric_bag, failed_to_parse_answers) loss = grpo_loss From 474a537295127db2be88fb6393ac121a0d2b4a24 Mon Sep 17 00:00:00 2001 From: swarnadeep user Date: Wed, 22 Oct 2025 21:43:21 +0000 Subject: [PATCH 32/33] Logging judge input --- src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py b/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py index 952b94727..235ac0d8f 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py @@ -427,6 +427,7 @@ def format_prompt(self, prompt_text, rollout_text, reference_answer): chat_str = self.tokenizer.apply_chat_template( wrapped_text, tokenize=False, add_generation_prompt=True ) + log.info(f"Judge input = {chat_str}") return chat_str @override From 821166419b19cd847915b950460759dfd5e71d7c Mon Sep 17 00:00:00 2001 From: swarnadeep user Date: Wed, 29 Oct 2025 01:47:02 +0000 Subject: [PATCH 33/33] Tracking a second reward (for debugging) --- .../recipes/lm/_online_finetune/_common.py | 8 +++++ .../lm/_online_finetune/_generative_judge.py | 2 +- .../recipes/lm/_online_finetune/_grpo.py | 35 ++++++++++++++++++- .../lm/_online_finetune/_remote_model.py | 1 + src/fairseq2/setup/_metrics.py | 2 ++ 5 files changed, 46 insertions(+), 2 deletions(-) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_common.py b/src/fairseq2/recipes/lm/_online_finetune/_common.py index 20e4448f0..399efe369 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_common.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_common.py @@ -547,6 +547,14 @@ def update_num_dummy_batches( @torch.inference_mode() def update_avg_reward(metric_bag: MetricBag, avg_reward): metric_bag.get(Mean, "avg_reward").update(avg_reward, weight=1) + +@torch.inference_mode() +def update_avg_second_reward(metric_bag: MetricBag, avg_reward): + metric_bag.get(Mean, "avg_second_reward").update(avg_reward, weight=1) + +@torch.inference_mode() +def update_reward_matches(metric_bag: MetricBag, reward_matches): + metric_bag.get(Mean, "reward_matches").update(reward_matches, weight=1) @torch.inference_mode() diff --git a/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py b/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py index 235ac0d8f..8c7dabdc6 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py @@ -427,7 +427,7 @@ def format_prompt(self, prompt_text, rollout_text, reference_answer): chat_str = self.tokenizer.apply_chat_template( wrapped_text, tokenize=False, add_generation_prompt=True ) - log.info(f"Judge input = {chat_str}") + # log.info(f"Judge input = {chat_str}") return chat_str @override diff --git a/src/fairseq2/recipes/lm/_online_finetune/_grpo.py b/src/fairseq2/recipes/lm/_online_finetune/_grpo.py index 87a5dd3fb..d5493843c 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_grpo.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_grpo.py @@ -41,6 +41,8 @@ get_failed_to_parse_answers, strip_think_tokens, update_avg_reward, + update_avg_second_reward, + update_reward_matches, update_avg_reward_len_norm, update_avg_rollout_length, update_batch_metrics, @@ -141,6 +143,7 @@ class GrpoFinetuneUnit(TrainUnit[SequenceBatch]): _config: GrpoFinetuneConfig _model_update_group: PyNcclCommunicator _reward: VLLMOutputReward + _second_reward: VLLMOutputReward _display_name: str _rollout_bag: StatefulRolloutBag @@ -151,6 +154,7 @@ def __init__( vllm_model: RemoteVllmModel, vllm_actors: List[Union[RemoteVllmModel, RemoteHFModel]], reward, + second_reward, gangs: Gangs, config: GrpoFinetuneConfig, ) -> None: @@ -162,6 +166,7 @@ def __init__( self._vllm_model = vllm_model self._gangs = gangs self._reward = reward + self._second_reward = second_reward self._rollout_bag = StatefulRolloutBag( max_bag_steps=int( config.loss_config.group_size / config.loss_config.forward_group_size @@ -227,6 +232,15 @@ def validate_reward( avg_reward = torch.tensor(reward_output["rewards"]).float().mean() std_reward = torch.tensor(reward_output["rewards"]).float().std() failed_to_parse_answers = get_failed_to_parse_answers(reward_output, prompt_batch.batch_size) + + second_reward_output = self._second_reward.process_rollouts(rollouts, prompt_batch) + log.info(f"Second Rewards: {second_reward_output['rewards']}") + avg_second_reward = torch.tensor(second_reward_output["rewards"]).float().mean() + update_avg_second_reward(metric_bag, avg_second_reward) + + reward_matches = (torch.tensor(reward_output["rewards"]) == torch.tensor(second_reward_output["rewards"])).all(dim=1).float().mean() + log.info(f"Reward matches: {reward_matches}") + update_reward_matches(metric_bag, reward_matches) rollout_lengths = get_rollout_lengths(rollouts) avg_rollout_length = torch.tensor(rollout_lengths).float().mean() @@ -492,6 +506,9 @@ class GrpoFinetuneConfig: vllm_reward_model_actor_name: str | None = None """Optional name of the Ray vLLM actor used as a reward model.""" + + vllm_second_reward_model_actor_name: str | None = None + """Optional name of the Ray vLLM actor used as a reward model.""" vllm_reference_model_actor_name: str | None = None """Optional name of the Ray vLLM actor used as a reference model.""" @@ -500,6 +517,10 @@ class GrpoFinetuneConfig: default_factory=lambda: RewardSection(name="gsm8k_verifier") ) """Configuration for the reward function that evaluates generated rollouts.""" + + second_reward: RewardSection = field( + default_factory=lambda: RewardSection(name="gsm8k_verifier") + ) vllm_sync: VllmSyncSection = field(default_factory=lambda: VllmSyncSection()) @@ -539,6 +560,8 @@ def create( vllm_model.sampling_params.n = config.loss_config.group_size vllm_reward_model = vllm_actors.get(config.vllm_reward_model_actor_name, None) + vllm_second_reward_model = vllm_actors.get(config.vllm_second_reward_model_actor_name, None) + reward_registry = self._context.get_registry(VLLMOutputRewardHandler) reward_name = config.reward.name reward_handler = reward_registry.get(reward_name) @@ -549,6 +572,16 @@ def create( gangs=gangs, context=self._context, ) + + second_reward_name = config.second_reward.name + second_reward_handler = reward_registry.get(second_reward_name) + second_reward = second_reward_handler.create( + reward_model=vllm_second_reward_model, + reward_name=second_reward_name, + reward_config=config.second_reward.config, + gangs=gangs, + context=self._context, + ) # sync models here before we start training if config.vllm_sync.sync_model_every_n_steps > 0: @@ -559,7 +592,7 @@ def create( log.info("GRPO setup complete.") return GrpoFinetuneUnit( - model, reference_model, vllm_model, vllm_actors, reward, gangs, config + model, reference_model, vllm_model, vllm_actors, reward, second_reward, gangs, config ) @property diff --git a/src/fairseq2/recipes/lm/_online_finetune/_remote_model.py b/src/fairseq2/recipes/lm/_online_finetune/_remote_model.py index a4060670f..560c885f7 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_remote_model.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_remote_model.py @@ -93,6 +93,7 @@ def __init__(self, *args, **kwargs): # at the top-level del os.environ["CUDA_VISIBLE_DEVICES"] # os.environ["VLLM_USE_V1"] = "1" + os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" super().__init__(*args, **kwargs) self.ready = True # Set a flag or return a signal diff --git a/src/fairseq2/setup/_metrics.py b/src/fairseq2/setup/_metrics.py index 0a38a92f2..071d78044 100644 --- a/src/fairseq2/setup/_metrics.py +++ b/src/fairseq2/setup/_metrics.py @@ -79,6 +79,8 @@ def register(name: str, *args: Any, **kwargs: Any) -> None: register("simpo_loss", "SimPO Loss", 0, format_as_float) register("grpo_loss", "GRPO Loss", 0, format_as_float) register("avg_reward", "Reward", 1, format_as_float) + register("avg_second_reward", "Second Reward", 1, format_as_float) + register("reward_matches", "Reward Matches", 1, format_as_float) register("std_reward", "StdDev Reward", 1, format_as_float) register("avg_reward_len_norm","Length Normalized Reward", 1, format_as_float) register("chosen_logps", "Chosen Sequence Log Probabilities", 50, format_as_float)