
Commit d47ef15

kwise judgment support

Author: swarna
1 parent b1ba0e2 commit d47ef15

File tree

8 files changed: +524 -45 lines changed

src/fairseq2/recipes/lm/__init__.py

Lines changed: 12 additions & 0 deletions

@@ -48,6 +48,12 @@
 from fairseq2.recipes.lm._online_finetune._generative_judge import (
     J1PairwiseScoreExtractorHandler as J1PairwiseScoreExtractorHandler,
 )
+from fairseq2.recipes.lm._online_finetune._generative_judge import (
+    J1KwiseScoreExtractor as J1KwiseScoreExtractor,
+)
+from fairseq2.recipes.lm._online_finetune._generative_judge import (
+    J1KwiseScoreExtractorHandler as J1KwiseScoreExtractorHandler,
+)
 from fairseq2.recipes.lm._online_finetune._generative_judge import (
     J1PointwiseExtractor as J1PointwiseExtractor,
 )
@@ -111,6 +117,12 @@
 from fairseq2.recipes.lm._online_finetune._rewards import (
     GenerativePointwiseVerifierHandler as GenerativePointwiseVerifierHandler,
 )
+from fairseq2.recipes.lm._online_finetune._rewards import (
+    GenerativeKwiseVerifier as GenerativeKwiseVerifier,
+)
+from fairseq2.recipes.lm._online_finetune._rewards import (
+    GenerativeKwiseVerifierHandler as GenerativeKwiseVerifierHandler,
+)
 from fairseq2.recipes.lm._online_finetune._rewards import GSM8kVerifier as GSM8kVerifier
 from fairseq2.recipes.lm._online_finetune._rewards import (
     GSM8kVerifierHandler as GSM8kVerifierHandler,
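
These re-exports surface the new k-wise classes at the package root. A minimal sketch of what that enables downstream (the import path follows this diff exactly; the lookup-by-name usage is an assumption, not shown in this commit):

from fairseq2.recipes.lm import J1KwiseScoreExtractorHandler

handler = J1KwiseScoreExtractorHandler()
# The handler's registry name, as defined later in this commit.
assert handler.name == "j1_kwise_score_extractor"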

src/fairseq2/recipes/lm/_online_finetune/_common.py

Lines changed: 11 additions & 0 deletions

@@ -439,6 +439,17 @@ def strip_think_tokens(rollouts: List[SequenceData]):
     return rollouts


+def format_think_tags(rollouts: List[SequenceData]):
+    for sample in rollouts:
+        for rollout in sample.outputs:
+            rollout_text = rollout.text
+            rollout.text = rollout_text.replace(
+                "<think>", "[Start of Assistant Thinking]"
+            ).replace("</think>", "[End of Assistant Thinking]")
+
+    return rollouts
+
+
 class StatefulRolloutBag:
     """A stateful container for managing and reusing model rollouts across multiple micro-batches.
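
A quick illustration of the new helper's behavior (a sketch: the SimpleNamespace objects below stand in for the vLLM-style rollout containers, whose exact shape is assumed from the loop above):

from types import SimpleNamespace

# Stand-ins for one sampled completion and its parent sample.
rollout = SimpleNamespace(text="<think>outline the steps</think> The answer is 42.")
sample = SimpleNamespace(outputs=[rollout])

format_think_tags([sample])
print(rollout.text)
# [Start of Assistant Thinking]outline the steps[End of Assistant Thinking] The answer is 42.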

src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py

Lines changed: 154 additions & 15 deletions
@@ -20,7 +20,7 @@
 # """

 POINTWISE_J1_PROMPT = """
-You are given a user question and a response from an AI assistant. Your task is to act as an impartial judge and evaluate how well the response fulfills the user's instructions. You will be shown multiple responses to the same prompt, but only one at a time. Evaluate each response independently.
+You are given a user question and a response from an AI assistant. Your task is to act as an impartial judge and evaluate how well the response fulfills the user's instructions. Do not allow the length of the response to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible.

 Think carefully about how to assess the quality of the response and finally assign the assistant's response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision. A higher score should indicate a higher-quality response. Enclose the score within <score> and </score> tags.
@@ -110,6 +110,72 @@
 [The End of Assistant B's Answer]
 """

+KWISE_WITH_SCORES_J1_PROMPT = """
+You are given a user question and {k} responses from {k} AI assistants. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible.
+
+Think carefully about how to assess the quality of the responses and finally, assign each response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision, with a higher score indicating a higher-quality response that better satisfies the criteria. Enclose the scores within the tags <score_assistant_1> </score_assistant_1>, <score_assistant_2> </score_assistant_2> and so on.
+
+Format your output like this:
+<think> your_thinking_process </think>
+<score_assistant_1> your_score_1 </score_assistant_1>
+<score_assistant_2> your_score_2 </score_assistant_2>
+<score_assistant_3> your_score_3 </score_assistant_3>
+...
+
+Below are the user's question and the {k} responses:
+
+[User Question]
+{instruction}
+
+{responses}
+"""
+
+KWISE_WITH_SCORES_J1_PROMPT_WITH_REF_ANSWER = """
+You are given a user question and {k} responses from {k} AI assistants. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible.
+
+Think carefully about how to assess the quality of the responses and finally, assign each response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision, with a higher score indicating a higher-quality response that better satisfies the criteria. Enclose the scores within the tags <score_assistant_1> </score_assistant_1>, <score_assistant_2> </score_assistant_2> and so on.
+
+Format your output like this:
+<think> your_thinking_process </think>
+<score_assistant_1> your_score_1 </score_assistant_1>
+<score_assistant_2> your_score_2 </score_assistant_2>
+<score_assistant_3> your_score_3 </score_assistant_3>
+...
+
+Below are the user's question, the reference answer, and the {k} responses:
+
+[User Question]
+{instruction}
+
+[Reference Answer]
+{reference_answer}
+
+{responses}
+"""
+
+# PAIRWISE_WITH_SCORES_J1_PROMPT = """
+# You are given a user question and two responses from two AI assistants. You are also given their thinking process. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible.
+
+# Carefully analyze the assistants' thought process, assess the quality of the responses and finally, assign each response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision, with a higher score indicating a higher-quality response that better satisfies the criteria. Enclose the scores within the tags <score_A> </score_A>, and <score_B> </score_B>.
+
+# Format your output like this:
+# <think> your_thinking_process </think>
+# <score_A> your_score_a </score_A> <score_B> your_score_b </score_B>
+
+# Below are the user's question and the two responses:
+
+# [User Question]
+# {instruction}
+
+# [The Start of Assistant A's Answer]
+# {response_A}
+# [The End of Assistant A's Answer]
+
+# [The Start of Assistant B's Answer]
+# {response_B}
+# [The End of Assistant B's Answer]
+# """
+
 PAIRWISE_WITH_SCORES_J1_PROMPT_WITH_REF_ANSWER = """
 You are given a user question and two responses from two AI assistants. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible.
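
The new templates leave a single {responses} slot, so the k candidate answers have to be flattened into one string before formatting. This commit does not show that helper; the sketch below is one plausible shape (the render_responses name and the bracketed markers, mirroring the pairwise prompts, are assumptions):

def render_responses(responses):
    # Hypothetical helper: wrap each candidate in numbered start/end markers
    # so the judge can address them as assistant_1 ... assistant_k.
    blocks = []
    for i, text in enumerate(responses, start=1):
        blocks.append(
            f"[The Start of Assistant {i}'s Answer]\n{text}\n[The End of Assistant {i}'s Answer]"
        )
    return "\n\n".join(blocks)

prompt = KWISE_WITH_SCORES_J1_PROMPT.format(
    k=2,
    instruction="Summarize Hamlet in one sentence.",
    responses=render_responses(
        ["A prince avenges his murdered father.", "A play set in Denmark."]
    ),
)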
@@ -147,18 +213,15 @@

 class JudgmentExtractorHandler(ABC):
     @abstractmethod
-    def create(self, tokenizer):
-        ...
+    def create(self, tokenizer): ...

     @property
     @abstractmethod
-    def name(self) -> str:
-        ...
+    def name(self) -> str: ...

     @property
     @abstractmethod
-    def config_kls(self) -> type[object]:
-        ...
+    def config_kls(self) -> type[object]: ...


 """
@@ -177,12 +240,10 @@ class JudgmentExtractor(ABC):
     """

     @abstractmethod
-    def prompt(self) -> str:
-        ...
+    def prompt(self) -> str: ...

     @abstractmethod
-    def format_prompt(self, prompt_text, **kwargs: Any) -> str:
-        ...
+    def format_prompt(self, prompt_text, **kwargs: Any) -> str: ...

     """
     Format the prompt text and additional arguments into a string suitable for input to the reward model.
@@ -195,8 +256,7 @@ def format_prompt(self, prompt_text, **kwargs: Any) -> str:
     """

     @abstractmethod
-    def extract(self, generation) -> float | str:
-        ...
+    def extract(self, generation) -> float | str: ...

     """
     Extract the final scalar reward score from the model's response.
@@ -215,8 +275,7 @@ def extract(self, generation) -> float | str:
     """

     @abstractmethod
-    def aggregate(self, judgments) -> float | str:
-        ...
+    def aggregate(self, judgments) -> float | str: ...

     """
     Aggregate multiple responses (judgments) from the reward model into a single value.
@@ -472,3 +531,83 @@ def aggregate(self, judgments):
             round(avg_score[0] / len(judgments), 4),
             round(avg_score[1] / len(judgments), 4),
         )
+
+
+class J1KwiseScoreExtractorHandler(JudgmentExtractorHandler):
+    def __init__(self):
+        pass
+
+    @override
+    def create(self, tokenizer, k):
+        return J1KwiseScoreExtractor(tokenizer, k)
+
+    @property
+    @override
+    def name(self):
+        return "j1_kwise_score_extractor"
+
+    @property
+    @override
+    def config_kls(self):
+        return None
+
+
+class J1KwiseScoreExtractor(JudgmentExtractor):
+    def __init__(self, tokenizer, k):
+        self.tokenizer = tokenizer
+        self.k = k
+
+    @override
+    def prompt(self, reference_answer):
+        return (
+            KWISE_WITH_SCORES_J1_PROMPT
+            if reference_answer is None
+            else KWISE_WITH_SCORES_J1_PROMPT_WITH_REF_ANSWER
+        )
+
+    @override
+    def format_prompt(self, prompt_text, rollouts, reference_answer):
+        prompt_template = self.prompt(reference_answer)
+        content = (
+            prompt_template.format(
+                k=self.k, instruction=prompt_text, responses=rollouts
+            )
+            if reference_answer is None
+            else prompt_template.format(
+                k=self.k,
+                instruction=prompt_text,
+                responses=rollouts,
+                reference_answer=reference_answer,
+            )
+        )
+
+        wrapped_text = [{"role": "user", "content": content}]
+        chat_str = self.tokenizer.apply_chat_template(
+            wrapped_text, tokenize=False, add_generation_prompt=True
+        )
+        return chat_str
+
+    @override
+    def extract(self, generation):
+        scores = []
+        for i in range(self.k):
+            score_matches = re.findall(
+                rf"<score_assistant_{i+1}>\s*([0-9]+(?:\.[0-9])?)\s*(?:/10)?\s*</score_assistant_{i+1}>",
+                generation,
+            )
+            if score_matches:
+                scores.append(float(score_matches[-1].strip()))
+            else:
+                scores.append(0.0)
+
+        return scores
+
+    @override
+    def aggregate(self, judgments):
+        avg_score = [0.0] * self.k
+        for scores in judgments:
+            for i, score in enumerate(scores):
+                avg_score[i] += score
+
+        avg_score = [round(avg_score[i] / len(judgments), 4) for i in range(self.k)]
+        return avg_score

src/fairseq2/recipes/lm/_online_finetune/_grpo.py

Lines changed: 17 additions & 6 deletions
@@ -34,6 +34,7 @@
     collate_with_target_mask,
     compute_reference_logps,
     compute_token_level_entropy,
+    format_think_tags,
     generate_rollouts,
     get_rollout_lengths,
     log_rollouts,
@@ -195,9 +196,12 @@ def validate_reward(
                 policy_sampling_params.__setattr__(k, v)

             # For a pairwise RM, need to sample at least two rollouts
-            policy_sampling_params.n = (
-                2 if self._reward.reward_name == "generative_pairwise_verifier" else 1
-            )
+            if self._reward.reward_name == "generative_pairwise_verifier":
+                policy_sampling_params.n = 2
+            elif self._reward.reward_name == "generative_kwise_verifier":
+                policy_sampling_params.n = self._config.reward.config.k
+            else:
+                policy_sampling_params.n = 1
         else:
             policy_sampling_params = None
         rollouts = generate_rollouts(
@@ -206,7 +210,11 @@
             vllm_model=self._vllm_model,
             sampling_params=policy_sampling_params,
         )
-        rollouts = strip_think_tokens(rollouts)
+        if self._config.reward.config.strip_thinking:
+            rollouts = strip_think_tokens(rollouts)
+        else:
+            rollouts = format_think_tags(rollouts)
+
         log.info("After stripping")
         if self._config.loss_config.log_rollouts:
             log_rollouts(prompt_batch, rollouts, "Valid")
@@ -269,8 +277,11 @@ def __call__(
         # if self._config.loss_config.log_rollouts:
         #     log_rollouts(prompt_batch, rollouts, "Train")

-        rollouts = strip_think_tokens(rollouts)
-        log.info('After stripping')
+        if self._config.reward.config.strip_thinking:
+            rollouts = strip_think_tokens(rollouts)
+        else:
+            rollouts = format_think_tags(rollouts)
+        log.info("After stripping")
         log_rollouts(prompt_batch, rollouts, "Train")
         reward_output = self._reward.process_rollouts(rollouts, prompt_batch)
         self._rollout_bag.save(rollouts, reward_output)
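
The GRPO unit now reads two new reward-config fields, k and strip_thinking. A hedged sketch of that dependency (only the two field names come from the diff; the dataclass shape, defaults, and helper below are illustrative assumptions):

from dataclasses import dataclass

@dataclass
class KwiseRewardConfig:
    k: int = 4                   # rollouts judged jointly per prompt
    strip_thinking: bool = True  # strip <think> spans vs. reformat them in place

def rollouts_per_prompt(reward_name: str, config: KwiseRewardConfig) -> int:
    # Mirrors the branch in validate_reward: pairwise judges need exactly
    # two rollouts, k-wise judges need k, anything else samples one.
    if reward_name == "generative_pairwise_verifier":
        return 2
    if reward_name == "generative_kwise_verifier":
        return config.k
    return 1

assert rollouts_per_prompt("generative_kwise_verifier", KwiseRewardConfig()) == 4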

src/fairseq2/recipes/lm/_online_finetune/_handler.py

Lines changed: 3 additions & 6 deletions

@@ -19,18 +19,15 @@ class OnlineFinetuneUnitHandler(ABC):
     @abstractmethod
     def create(
         self, model: Model, gangs: Gangs, recipe_config: object, vllm_actors: object
-    ) -> TrainUnit[SequenceBatch]:
-        ...
+    ) -> TrainUnit[SequenceBatch]: ...

     @property
     @abstractmethod
-    def name(self) -> str:
-        ...
+    def name(self) -> str: ...

     @property
     @abstractmethod
-    def config_kls(self) -> type[object]:
-        ...
+    def config_kls(self) -> type[object]: ...


 class UnknownOnlineFinetuneUnitError(Exception):

src/fairseq2/recipes/lm/_online_finetune/_remote_model.py

Lines changed: 3 additions & 6 deletions

@@ -564,18 +564,15 @@ class RemoteModelHandler(ABC):
     @abstractmethod
     def create(
         self, gangs: Gangs, unit_config: object
-    ) -> Union[RemoteVllmModel, RemoteHFModel]:
-        ...
+    ) -> Union[RemoteVllmModel, RemoteHFModel]: ...

     @property
     @abstractmethod
-    def name(self) -> str:
-        ...
+    def name(self) -> str: ...

     @property
     @abstractmethod
-    def config_kls(self) -> type[object]:
-        ...
+    def config_kls(self) -> type[object]: ...


 class RemoteRayModelHandler(RemoteModelHandler):
