
Commit 1162d60

swarna committed: Skywork-RM from hf

1 parent d474070 commit 1162d60

File tree

7 files changed: +123 -39 lines

src/fairseq2/recipes/lm/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -93,6 +93,9 @@
 from fairseq2.recipes.lm._online_finetune._remote_model import (
     NoEnvAceMathRMPipeline as NoEnvAceMathRMPipeline,
 )
+from fairseq2.recipes.lm._online_finetune._remote_model import (
+    NoEnvSkyworkRMPipeline as NoEnvSkyworkRMPipeline,
+)
 from fairseq2.recipes.lm._online_finetune._remote_model import (
     RemoteModelHandler as RemoteModelHandler,
 )
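The `X as X` re-export idiom marks the actor class as part of the recipe package's public surface. A one-line sketch of the import this enables (hypothetical downstream call site, not part of the commit):

# Hypothetical consumer of the re-export added above.
from fairseq2.recipes.lm import NoEnvSkyworkRMPipeline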

src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py

Lines changed: 14 additions & 7 deletions

@@ -213,15 +213,18 @@

 class JudgmentExtractorHandler(ABC):
     @abstractmethod
-    def create(self, tokenizer): ...
+    def create(self, tokenizer):
+        ...

     @property
     @abstractmethod
-    def name(self) -> str: ...
+    def name(self) -> str:
+        ...

     @property
     @abstractmethod
-    def config_kls(self) -> type[object]: ...
+    def config_kls(self) -> type[object]:
+        ...


 """
@@ -240,10 +243,12 @@ class JudgmentExtractor(ABC):
     """

     @abstractmethod
-    def prompt(self) -> str: ...
+    def prompt(self) -> str:
+        ...

     @abstractmethod
-    def format_prompt(self, prompt_text, **kwargs: Any) -> str: ...
+    def format_prompt(self, prompt_text, **kwargs: Any) -> str:
+        ...

     """
     Format the prompt text and additional arguments into a string suitable for input to the reward model.
@@ -256,7 +261,8 @@ def format_prompt(self, prompt_text, **kwargs: Any) -> str: ...
     """

     @abstractmethod
-    def extract(self, generation) -> float | str: ...
+    def extract(self, generation) -> float | str:
+        ...

     """
     Extract the final scalar reward score from the model's response.
@@ -275,7 +281,8 @@ def extract(self, generation) -> float | str: ...
     """

     @abstractmethod
-    def aggregate(self, judgments) -> float | str: ...
+    def aggregate(self, judgments) -> float | str:
+        ...

     """
     Aggregate multiple responses (judgments) from the reward model into a single value.
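For context on the contract these stubs define, here is a minimal hypothetical implementation, assuming the four stubs shown above are the full interface; the class name, prompt wording, and regex are illustrative assumptions, not code from this commit:

import re

from fairseq2.recipes.lm._online_finetune._generative_judge import JudgmentExtractor


class ScalarScoreExtractor(JudgmentExtractor):
    def prompt(self) -> str:
        # Instruction template sent to the judge model.
        return "Rate the response from 1 to 10. Reply as 'Score: <n>'.\n\n{response}"

    def format_prompt(self, prompt_text, **kwargs) -> str:
        # Fill the template with the text the judge should grade.
        return self.prompt().format(response=prompt_text)

    def extract(self, generation) -> float | str:
        # Pull the scalar out of the judge's reply; fall back to the raw text.
        m = re.search(r"Score:\s*(\d+(?:\.\d+)?)", generation)
        return float(m.group(1)) if m else generation

    def aggregate(self, judgments) -> float | str:
        # Average several sampled judgments into a single reward.
        scores = [j for j in judgments if isinstance(j, float)]
        return sum(scores) / len(scores) if scores else ""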

src/fairseq2/recipes/lm/_online_finetune/_handler.py

Lines changed: 6 additions & 3 deletions

@@ -19,15 +19,18 @@ class OnlineFinetuneUnitHandler(ABC):
     @abstractmethod
     def create(
         self, model: Model, gangs: Gangs, recipe_config: object, vllm_actors: object
-    ) -> TrainUnit[SequenceBatch]: ...
+    ) -> TrainUnit[SequenceBatch]:
+        ...

     @property
     @abstractmethod
-    def name(self) -> str: ...
+    def name(self) -> str:
+        ...

     @property
     @abstractmethod
-    def config_kls(self) -> type[object]: ...
+    def config_kls(self) -> type[object]:
+        ...


 class UnknownOnlineFinetuneUnitError(Exception):

src/fairseq2/recipes/lm/_online_finetune/_remote_model.py

Lines changed: 32 additions & 8 deletions

@@ -27,13 +27,12 @@
 from fairseq2.gang import Gangs
 from fairseq2.logging import log
 from fairseq2.nn._batch_layout import BatchLayout
+from fairseq2.recipes.lm._online_finetune.third_party.ace_math import AceMathRMPipeline
 from fairseq2.recipes.lm._online_finetune.third_party.athene import AtheneRewardPipeline
 from fairseq2.recipes.lm._online_finetune.third_party.general_verifier import (
     GeneralVerifierPipeline,
 )
-from fairseq2.recipes.lm._online_finetune.third_party.ace_math import (
-    AceMathRMPipeline,
-)
+from fairseq2.recipes.lm._online_finetune.third_party.skywork import SkyworkRMPipeline
 from fairseq2.utils.structured import StructureError, structure


@@ -140,7 +139,8 @@ def is_ready(self):
     @property
     def name(self):
         return "general_verifier_pipeline"
-
+
+
 @ray.remote
 class NoEnvAceMathRMPipeline(AceMathRMPipeline):
     """
@@ -162,6 +162,27 @@ def name(self):
         return "ace_math_rm_pipeline"


+@ray.remote
+class NoEnvSkyworkRMPipeline(SkyworkRMPipeline):
+    """
+    This is for running the Skywork RM pipeline with the HF backend.
+    """
+
+    def __init__(self, *args, **kwargs):
+        # stop ray from manipulating CUDA_VISIBLE_DEVICES
+        # at the top-level
+        del os.environ["CUDA_VISIBLE_DEVICES"]
+        super().__init__(*args, **kwargs)
+        self.ready = True  # Set a flag or return a signal
+
+    def is_ready(self):
+        return self.ready
+
+    @property
+    def name(self):
+        return "skywork_rm_pipeline"
+
+
 class WorkerExtension:
     """
     The class for vLLM's worker to inherit from.
@@ -462,7 +483,7 @@ def reward_from_model(self, prompt_list, batch_size=64):
         ray_outputs = ray.get(outputs)
         ray_outputs_flat = [o for sublist in ray_outputs for o in sublist]
         rewards = [o.outputs.data.item() for o in ray_outputs_flat]
-
+
         log.info(f"Rewards = {rewards}")

         return rewards
@@ -591,15 +612,18 @@ class RemoteModelHandler(ABC):
     @abstractmethod
     def create(
         self, gangs: Gangs, unit_config: object
-    ) -> Union[RemoteVllmModel, RemoteHFModel]: ...
+    ) -> Union[RemoteVllmModel, RemoteHFModel]:
+        ...

     @property
     @abstractmethod
-    def name(self) -> str: ...
+    def name(self) -> str:
+        ...

     @property
     @abstractmethod
-    def config_kls(self) -> type[object]: ...
+    def config_kls(self) -> type[object]:
+        ...


 class RemoteRayModelHandler(RemoteModelHandler):
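A minimal launch sketch for the new actor (hypothetical driver code, not part of the commit; assumes a running Ray cluster with at least one GPU, and that Ray exposes the inherited __call__ as an actor method):

import ray

from fairseq2.recipes.lm import NoEnvSkyworkRMPipeline

ray.init()

# Instantiate the @ray.remote class; __init__ (and the HF model load)
# runs inside the actor process, which deletes CUDA_VISIBLE_DEVICES
# before loading the model.
actor = NoEnvSkyworkRMPipeline.options(num_gpus=1).remote()

# Block until the checkpoint has finished loading.
assert ray.get(actor.is_ready.remote())

# Score a chunk of pre-formatted chat strings, one reward per string.
rewards = ray.get(actor.__call__.remote(["<formatted chat transcript>"]))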

src/fairseq2/recipes/lm/_online_finetune/_rewards.py

Lines changed: 23 additions & 11 deletions

@@ -57,23 +57,28 @@ class VLLMOutputRewardHandler(ABC):
     @abstractmethod
     def create(
         self, reward_model: Any, gangs: Gangs, reward_config: object
-    ) -> VLLMOutputReward: ...
+    ) -> VLLMOutputReward:
+        ...

     @property
     @abstractmethod
-    def name(self) -> str: ...
+    def name(self) -> str:
+        ...

     @property
     @abstractmethod
-    def config_kls(self) -> type[object]: ...
+    def config_kls(self) -> type[object]:
+        ...


 class VLLMOutputReward(ABC):
     @abstractmethod
-    def process_rollouts(self, vllm_outputs: list[RequestOutput]): ...
+    def process_rollouts(self, vllm_outputs: list[RequestOutput]):
+        ...

     @abstractmethod
-    def prepare_preference_batch(self, prompt_batch: PromptBatch, rollouts): ...
+    def prepare_preference_batch(self, prompt_batch: PromptBatch, rollouts):
+        ...


 class GSM8kVerifierHandler(VLLMOutputRewardHandler):
@@ -461,7 +466,8 @@ def prepare_preference_batch(
         )

         return batch, is_bad_batch, reward_output
-
+
+
 class AceMathVerifierHandler(VLLMOutputRewardHandler):
     def __init__(self):
         pass
@@ -492,7 +498,8 @@ def name(self):
     @override
     def config_kls(self):
         return None
-
+
+
 class AceMathVerifier(VLLMOutputReward):
     def __init__(
         self,
@@ -514,11 +521,16 @@ def __init__(

     def wrap_text(self, prompt_text, rollout_text):
         wrapped_text = [
-            {"role": "system", "content": "Please reason step by step, and check your final answer within \\boxed{}."},
+            {
+                "role": "system",
+                "content": "Please reason step by step, and check your final answer within \\boxed{}.",
+            },
             {"role": "user", "content": prompt_text},
-            {"role": "assistant", "content": rollout_text}
+            {"role": "assistant", "content": rollout_text},
         ]
-        chat_str = self.tokenizer.apply_chat_template(wrapped_text, tokenize=False, add_generation_prompt=False)
+        chat_str = self.tokenizer.apply_chat_template(
+            wrapped_text, tokenize=False, add_generation_prompt=False
+        )
         if self.tokenizer.bos_token is not None and chat_str.startswith(
             self.tokenizer.bos_token
         ):
@@ -557,7 +569,7 @@ def process_rollouts(
         batch_rewards = generate_rewards(
             vllm_inputs, dp_gang=self._gangs.dp, vllm_model=self.reward_model
         )
-
+
         log.info(f"Batch rewards = {batch_rewards}")

         # reshape batch_rewards to [Batch, Rollouts]
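To make the wrapping concrete, a small hypothetical trace of wrap_text; the example conversation is illustrative, the exact serialized string depends on the tokenizer's chat template, and the BOS-stripping body is assumed from the check shown in the hunk above:

from transformers import AutoTokenizer

# Tokenizer path taken from AceMathRMPipeline in third_party/ace_math.py.
tokenizer = AutoTokenizer.from_pretrained(
    "/datasets/pretrained-llms/AceMath-7B-RM", trust_remote_code=True
)

messages = [
    {
        "role": "system",
        "content": "Please reason step by step, and check your final answer within \\boxed{}.",
    },
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "2 + 2 = \\boxed{4}"},
]

# add_generation_prompt=False: the assistant turn being scored is already
# present, so no new assistant header is appended.
chat_str = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=False
)

# Drop a leading BOS token if the template emitted one (assumed body of
# the check above); the RM pipeline later tokenizes the string itself
# with add_special_tokens=False.
if tokenizer.bos_token is not None and chat_str.startswith(tokenizer.bos_token):
    chat_str = chat_str[len(tokenizer.bos_token):]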
src/fairseq2/recipes/lm/_online_finetune/third_party/ace_math.py

Lines changed: 15 additions & 10 deletions

@@ -1,28 +1,33 @@
 import torch
-from fairseq2.logging import log
 from transformers import AutoModelForSequenceClassification, AutoTokenizer

+from fairseq2.logging import log
+
+
 class AceMathRMPipeline:
     def __init__(self, *args, **kwargs):
         model_path = "/datasets/pretrained-llms/AceMath-7B-RM"
-        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_path, trust_remote_code=True
+        )
         self.model = AutoModelForSequenceClassification.from_pretrained(
-            model_path, num_labels=1, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map = "auto"
+            model_path,
+            num_labels=1,
+            torch_dtype=torch.bfloat16,
+            trust_remote_code=True,
+            device_map="auto",
         ).eval()
         self.model.config.pad_token_id = self.tokenizer.pad_token_id

     def __call__(self, prompt_chunk):
         inputs = self.tokenizer(
-            prompt_chunk,
-            return_tensors="pt",
-            padding=True,
-            add_special_tokens=False
+            prompt_chunk, return_tensors="pt", padding=True, add_special_tokens=False
         ).to(self.model.device)
-
+
         outputs = self.model(**inputs)[0]
         log.info(f"outputs = {outputs}")
-        rewards =[output[0] for output in outputs]
-
+        rewards = [output[0] for output in outputs]
+
         log.info(f"Length of rewards = {len(rewards)}")

         return rewards
src/fairseq2/recipes/lm/_online_finetune/third_party/skywork.py

Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@
+import torch
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+from fairseq2.logging import log
+
+
+class SkyworkRMPipeline:
+    def __init__(self, *args, **kwargs):
+        model_path = "/datasets/pretrained-llms/Skywork-Reward-V2-Llama-3.1-8B"
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_path, trust_remote_code=True
+        )
+        self.model = AutoModelForSequenceClassification.from_pretrained(
+            model_path,
+            num_labels=1,
+            torch_dtype=torch.bfloat16,
+            trust_remote_code=True,
+            device_map="auto",
+        ).eval()
+        self.model.config.pad_token_id = self.tokenizer.pad_token_id
+
+    def __call__(self, prompt_chunk):
+        inputs = self.tokenizer(
+            prompt_chunk, return_tensors="pt", padding=True, add_special_tokens=False
+        ).to(self.model.device)
+
+        outputs = self.model(**inputs)[0]
+        rewards = [output[0] for output in outputs]
+
+        return rewards
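A short hypothetical smoke test for the new pipeline (conversation is illustrative; assumes the checkpoint directory above exists locally and a GPU is available):

# Hypothetical local smoke test, not part of the commit.
from fairseq2.recipes.lm._online_finetune.third_party.skywork import (
    SkyworkRMPipeline,
)

pipeline = SkyworkRMPipeline()

# The pipeline expects pre-formatted chat strings; the reward head
# (num_labels=1) emits one scalar logit per sequence.
conversation = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "2 + 2 = 4."},
]
prompt = pipeline.tokenizer.apply_chat_template(conversation, tokenize=False)

rewards = pipeline([prompt])
print([float(r) for r in rewards])  # bfloat16 logits -> Python floats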
