
Commit d474070

Author: swarna
Commit message: Adding support for acemath
1 parent d47ef15 commit d474070

File tree

6 files changed: +258 −1 lines changed

src/fairseq2/recipes/lm/__init__.py

Lines changed: 9 additions & 0 deletions

@@ -90,6 +90,9 @@
 from fairseq2.recipes.lm._online_finetune._remote_model import (
     NoEnvGeneralVerifierPipeline as NoEnvGeneralVerifierPipeline,
 )
+from fairseq2.recipes.lm._online_finetune._remote_model import (
+    NoEnvAceMathRMPipeline as NoEnvAceMathRMPipeline,
+)
 from fairseq2.recipes.lm._online_finetune._remote_model import (
     RemoteModelHandler as RemoteModelHandler,
 )
@@ -105,6 +108,12 @@
 from fairseq2.recipes.lm._online_finetune._rewards import (
     SkyworkVerifierHandler as SkyworkVerifierHandler,
 )
+from fairseq2.recipes.lm._online_finetune._rewards import (
+    AceMathVerifier as AceMathVerifier,
+)
+from fairseq2.recipes.lm._online_finetune._rewards import (
+    AceMathVerifierHandler as AceMathVerifierHandler,
+)
 from fairseq2.recipes.lm._online_finetune._rewards import (
     GenerativePairwiseVerifier as GenerativePairwiseVerifier,
 )

src/fairseq2/recipes/lm/_online_finetune/_common.py

Lines changed: 2 additions & 0 deletions

@@ -394,6 +394,8 @@ def log_rollouts(prompt_batch: PromptBatch, rollouts, split_name, num_rollouts=1
         prompt = prompt_batch.meta_info.get("prompt_raw")[0]
     elif "raw_prompt" in prompt_batch.meta_info:
         prompt = prompt_batch.meta_info.get("raw_prompt")[0]
+    elif "problem" in prompt_batch.meta_info:
+        prompt = prompt_batch.meta_info.get("problem")[0]
     else:
         # raw text prompt doesn't exist for this dataset
         prompt = "DUMMY PROMPT"

src/fairseq2/recipes/lm/_online_finetune/_remote_model.py

Lines changed: 28 additions & 1 deletion

@@ -31,6 +31,9 @@
 from fairseq2.recipes.lm._online_finetune.third_party.general_verifier import (
     GeneralVerifierPipeline,
 )
+from fairseq2.recipes.lm._online_finetune.third_party.ace_math import (
+    AceMathRMPipeline,
+)
 from fairseq2.utils.structured import StructureError, structure
@@ -48,6 +51,7 @@ class VllmEngineArgs:
     tokenizer: str = "/datasets/pretrained-llms/Llama-3.1-8B-Instruct"
     task: str = "generate"
     tensor_parallel_size: int = 4
+    max_model_len: int | None = None
     trust_remote_code: bool = False
     model_impl: str = "auto"
     enforce_eager: bool = True
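The hunk above adds a max_model_len field to the vLLM engine arguments; it is forwarded to the remote engine in setup_vllm_worker further down. A minimal sketch of setting it when building the args (the 4096 cap and the checkpoint path are illustrative, not taken from this commit):

# Illustrative only: caps the context window of the remote vLLM engine.
# Leaving max_model_len at None keeps vLLM's default, derived from the model config.
from fairseq2.recipes.lm._online_finetune._remote_model import VllmEngineArgs

engine_args = VllmEngineArgs(
    model="/datasets/pretrained-llms/Llama-3.1-8B-Instruct",  # assumed checkpoint path
    tokenizer="/datasets/pretrained-llms/Llama-3.1-8B-Instruct",
    tensor_parallel_size=4,
    max_model_len=4096,  # new field introduced in this commit
)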
@@ -136,6 +140,26 @@ def is_ready(self):
     @property
     def name(self):
         return "general_verifier_pipeline"
+
+
+@ray.remote
+class NoEnvAceMathRMPipeline(AceMathRMPipeline):
+    """
+    Runs the AceMath RM pipeline with the HF backend.
+    """
+
+    def __init__(self, *args, **kwargs):
+        # stop ray from manipulating CUDA_VISIBLE_DEVICES at the top level
+        del os.environ["CUDA_VISIBLE_DEVICES"]
+        super().__init__(*args, **kwargs)
+        self.ready = True  # readiness flag polled via is_ready()
+
+    def is_ready(self):
+        return self.ready
+
+    @property
+    def name(self):
+        return "ace_math_rm_pipeline"


 class WorkerExtension:
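NoEnvAceMathRMPipeline wraps the HF-backed AceMathRMPipeline (added below in third_party/ace_math.py) so it can be scheduled as a Ray actor next to the vLLM workers. A minimal sketch of spinning it up and waiting for the model to load; the actor options and cluster setup are assumptions, not part of this commit:

# Hedged sketch: assumes a Ray cluster is already initialized by the recipe and
# that one GPU is a reasonable placement hint for the 7B reward model.
import ray

from fairseq2.recipes.lm._online_finetune._remote_model import NoEnvAceMathRMPipeline

rm_actor = NoEnvAceMathRMPipeline.options(
    name="ace_math_rm_pipeline",  # illustrative actor name
    num_gpus=1,
).remote()

# Blocks until __init__ has finished loading the reward model inside the actor.
ray.get(rm_actor.is_ready.remote())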
@@ -309,6 +333,7 @@ def setup_vllm_worker(self, ray_actor_name, vllm_engine_args, gangs: Gangs):
         ).remote(
             model=vllm_engine_args.model,
             tokenizer=vllm_engine_args.tokenizer,
+            max_model_len=vllm_engine_args.max_model_len,
             enforce_eager=vllm_engine_args.enforce_eager,
             worker_extension_cls="fairseq2.recipes.lm._online_finetune._remote_model.WorkerExtension",
             tensor_parallel_size=vllm_engine_args.tensor_parallel_size,
@@ -437,6 +462,8 @@ def reward_from_model(self, prompt_list, batch_size=64):
         ray_outputs = ray.get(outputs)
         ray_outputs_flat = [o for sublist in ray_outputs for o in sublist]
         rewards = [o.outputs.data.item() for o in ray_outputs_flat]
+
+        log.info(f"Rewards = {rewards}")

         return rewards

@@ -537,7 +564,7 @@ def rollout_from_model(self, prompt_list, sampling_params=None, string_input=Fal
             "RemoteHFModel.rollout_from_model is not implemented. "
         )

-    def reward_from_model(self, prompt_list, batch_size=64):
+    def reward_from_model(self, prompt_list, batch_size=4):
         # NOTE: need to batch inputs to hf.encode model for current models that aren't supported by hf
         rewards = []
         outputs = []

src/fairseq2/recipes/lm/_online_finetune/_rewards.py

Lines changed: 181 additions & 0 deletions

@@ -461,6 +461,187 @@ def prepare_preference_batch(
         )

         return batch, is_bad_batch, reward_output
+
+class AceMathVerifierHandler(VLLMOutputRewardHandler):
+    def __init__(self):
+        pass
+
+    @override
+    def create(self, reward_model, reward_name, reward_config, gangs, context):
+        if reward_config.tokenizer is not None:
+            tokenizer = reward_config.tokenizer
+        else:
+            tokenizer = "nvidia/AceMath-7B-RM"
+
+        return AceMathVerifier(
+            gangs,
+            context,
+            reward_model,
+            reward_name=reward_name,
+            answer_key=reward_config.answer_key,
+            prompt_key=reward_config.prompt_key,
+            tokenizer=tokenizer,
+        )
+
+    @property
+    @override
+    def name(self):
+        return "acemath_verifier"
+
+    @property
+    @override
+    def config_kls(self):
+        return None
+
+class AceMathVerifier(VLLMOutputReward):
+    def __init__(
+        self,
+        gangs,
+        context,
+        reward_model,
+        reward_name,
+        answer_key,
+        prompt_key,
+        tokenizer,
+    ):
+        self.answer_key = answer_key
+        self.prompt_key = prompt_key
+        self._gangs = gangs
+        self._context = context
+        self.reward_model = reward_model
+        self.reward_name = reward_name
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+
+    def wrap_text(self, prompt_text, rollout_text):
+        wrapped_text = [
+            {"role": "system", "content": "Please reason step by step, and check your final answer within \\boxed{}."},
+            {"role": "user", "content": prompt_text},
+            {"role": "assistant", "content": rollout_text},
+        ]
+        chat_str = self.tokenizer.apply_chat_template(wrapped_text, tokenize=False, add_generation_prompt=False)
+        if self.tokenizer.bos_token is not None and chat_str.startswith(
+            self.tokenizer.bos_token
+        ):
+            chat_str = chat_str[len(self.tokenizer.bos_token) :]
+
+        return chat_str
+
+    @override
+    def process_rollouts(
+        self, vllm_outputs: list[RequestOutput], prompt_batch: PromptBatch
+    ):
+        vllm_inputs = []
+        batch_text = []
+        batch_tokens = []
+
+        if vllm_outputs is None:
+            vllm_outputs = [None] * len(prompt_batch.prompts)
+
+        text_prompts = prompt_batch.meta_info.get(self.prompt_key)
+        for i, (i_batch_request_output, prompt_text) in enumerate(
+            zip(vllm_outputs, text_prompts)
+        ):
+            rollouts_text = []
+            rollouts_tokens = []
+            for rollout_output in i_batch_request_output.outputs:
+                rollout_text = rollout_output.text
+                vllm_input = self.wrap_text(prompt_text, rollout_text)
+                vllm_inputs.append(vllm_input)
+                rollouts_text.append(rollout_output.text)
+                rollouts_tokens.append(rollout_output.token_ids)
+
+            batch_text.append(rollouts_text)
+            batch_tokens.append(rollouts_tokens)
+
+        batch_rewards = generate_rewards(
+            vllm_inputs, dp_gang=self._gangs.dp, vllm_model=self.reward_model
+        )
+
+        log.info(f"Batch rewards = {batch_rewards}")
+
+        # reshape batch_rewards to [Batch, Rollouts]
+        B, R = len(batch_text), len(batch_text[0])  # batch size, rollouts
+        batch_rewards = [batch_rewards[i * R : (i + 1) * R] for i in range(B)]
+
+        return {"text": batch_text, "tokens": batch_tokens, "rewards": batch_rewards}
+
+    def prepare_preference_batch(
+        self, prompt_batch: PromptBatch, rollouts
+    ) -> PreferenceBatch:
+
+        reward_output = self.process_rollouts(rollouts, prompt_batch)
+
+        chosen_batch = []
+        rejected_batch = []
+        prompt_lens = []
+        dummy_batch_ids = []  # keep positions of dummy pairs here
+
+        # choose the highest-reward rollout as chosen and the lowest as rejected
+        # (ties break to the first occurrence, effectively random given sampled rollouts)
+        for i_batch, (i_batch_rewards, i_batch_tokens) in enumerate(
+            zip(reward_output["rewards"], reward_output["tokens"])
+        ):
+            chosen_rollout_position = i_batch_rewards.index(max(i_batch_rewards))
+            rejected_rollout_position = i_batch_rewards.index(min(i_batch_rewards))
+
+            if chosen_rollout_position == rejected_rollout_position:
+                # can't form a preference pair when chosen and rejected coincide;
+                # this will be a dummy batch and we zero out the loss
+                dummy_batch_ids.append(i_batch)
+
+            chosen_rollout_tokens = list(i_batch_tokens[chosen_rollout_position])
+            rejected_rollout_tokens = list(i_batch_tokens[rejected_rollout_position])
+            prompt_tokens = prompt_batch.prompts[i_batch]
+
+            chosen_tokens = prompt_tokens + chosen_rollout_tokens
+            chosen_batch.append(chosen_tokens)
+
+            rejected_tokens = prompt_tokens + rejected_rollout_tokens
+            rejected_batch.append(rejected_tokens)
+
+            prompt_lens.append(len(prompt_tokens))
+
+        filter_batch = lambda batch: [
+            item for index, item in enumerate(batch) if index not in dummy_batch_ids
+        ]
+
+        if len(dummy_batch_ids) == len(reward_output["tokens"]):
+            # the entire batch has no valid preference pair;
+            # we use it as a dummy batch and zero the loss at the end
+            is_bad_batch = True
+        else:
+            # removing dummy pairs from the batch
+            chosen_batch = filter_batch(chosen_batch)
+            rejected_batch = filter_batch(rejected_batch)
+            prompt_lens = filter_batch(prompt_lens)
+            is_bad_batch = False
+
+        prompt_lens = torch.tensor(prompt_lens)
+
+        chosen_batch = [
+            torch.tensor(sequence, device=self._gangs.dp.device)
+            for sequence in chosen_batch
+        ]
+        chosen_batch = collate_with_target_mask(
+            chosen_batch, prompt_lens, device=self._gangs.dp.device
+        )
+
+        rejected_batch = [
+            torch.tensor(sequence, device=self._gangs.dp.device)
+            for sequence in rejected_batch
+        ]
+        rejected_batch = collate_with_target_mask(
+            rejected_batch, prompt_lens, device=self._gangs.dp.device
+        )
+
+        batch = PreferenceBatch(
+            chosen=chosen_batch,
+            rejected=rejected_batch,
+            reference_score_chosen=None,
+            reference_score_rejected=None,
+        )
+
+        return batch, is_bad_batch, reward_output


 class AtheneVerifierHandler(VLLMOutputRewardHandler):
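To make the bookkeeping in process_rollouts and prepare_preference_batch concrete: the reward model returns one flat score per (prompt, rollout) pair in prompt-major order, which is reshaped into a [Batch, Rollouts] grid before the chosen/rejected indices are picked per row. A tiny self-contained illustration with made-up scores:

# Toy illustration of the reward reshape and pair selection above (scores are made up).
flat_rewards = [0.1, 0.9, 0.4, 0.7, 0.2, 0.8]  # one score per (prompt, rollout)

B, R = 2, 3  # 2 prompts, 3 rollouts each
batch_rewards = [flat_rewards[i * R : (i + 1) * R] for i in range(B)]
assert batch_rewards == [[0.1, 0.9, 0.4], [0.7, 0.2, 0.8]]

for row in batch_rewards:
    chosen = row.index(max(row))    # highest-reward rollout
    rejected = row.index(min(row))  # lowest-reward rollout
    # if chosen == rejected (all rewards equal), the pair is marked as a dummy
    print(chosen, rejected)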
src/fairseq2/recipes/lm/_online_finetune/third_party/ace_math.py (new file)

Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
+import torch
+from fairseq2.logging import log
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+class AceMathRMPipeline:
+    def __init__(self, *args, **kwargs):
+        model_path = "/datasets/pretrained-llms/AceMath-7B-RM"
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        self.model = AutoModelForSequenceClassification.from_pretrained(
+            model_path, num_labels=1, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto"
+        ).eval()
+        self.model.config.pad_token_id = self.tokenizer.pad_token_id
+
+    def __call__(self, prompt_chunk):
+        inputs = self.tokenizer(
+            prompt_chunk,
+            return_tensors="pt",
+            padding=True,
+            add_special_tokens=False,
+        ).to(self.model.device)
+
+        outputs = self.model(**inputs)[0]
+        log.info(f"outputs = {outputs}")
+        rewards = [output[0] for output in outputs]
+
+        log.info(f"Length of rewards = {len(rewards)}")
+
+        return rewards
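For orientation, a hedged sketch of calling this pipeline directly; it assumes the hard-coded checkpoint path above exists locally and a GPU is available, and the sample problem is made up. Inputs are expected to be fully chat-formatted strings, which in this commit are produced by AceMathVerifier.wrap_text in _rewards.py:

# Hypothetical local scoring run.
from fairseq2.recipes.lm._online_finetune.third_party.ace_math import AceMathRMPipeline

pipeline = AceMathRMPipeline()

chat = pipeline.tokenizer.apply_chat_template(
    [
        {"role": "system", "content": "Please reason step by step, and check your final answer within \\boxed{}."},
        {"role": "user", "content": "What is 2 + 3?"},
        {"role": "assistant", "content": "2 + 3 = \\boxed{5}."},
    ],
    tokenize=False,
    add_generation_prompt=False,
)

rewards = pipeline([chat])  # one scalar score tensor per input string
print(float(rewards[0]))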

src/fairseq2/setup/_po_finetune_units.py

Lines changed: 10 additions & 0 deletions

@@ -12,6 +12,7 @@
 from fairseq2.recipes.lm import (  # GroupDpoFinetuneUnitHandler,
     AtheneVerifierHandler,
     SkyworkVerifierHandler,
+    AceMathVerifierHandler,
     CpoFinetuneUnitHandler,
     DpoFinetuneUnitHandler,
     GeneralVerifierExtractorHandler,
@@ -27,6 +28,7 @@
     MathVerifyHandler,
     NoEnvAtheneRewardPipeline,
     NoEnvGeneralVerifierPipeline,
+    NoEnvAceMathRMPipeline,
     OnlineDpoFinetuneUnitHandler,
     OnlineFinetuneUnitHandler,
     OrpoFinetuneUnitHandler,
@@ -93,6 +95,10 @@ def _register_online_finetune_units(context: RuntimeContext) -> None:
     # SkyworkVerifier
     handler = SkyworkVerifierHandler()
     registry.register(handler.name, handler)
+
+    # AceMath RM
+    handler = AceMathVerifierHandler()
+    registry.register(handler.name, handler)

     # AtheneVerifier
     handler = AtheneVerifierHandler()
@@ -123,6 +129,10 @@ def _register_online_finetune_units(context: RuntimeContext) -> None:
     # NoEnvGeneralVerifierPipeline
     handler = NoEnvGeneralVerifierPipeline
     registry.register(handler.name, handler)
+
+    # NoEnvAceMathRMPipeline
+    handler = NoEnvAceMathRMPipeline
+    registry.register(handler.name, handler)

     # Generative judgment extractors
     registry = context.get_registry(JudgmentExtractorHandler)
