
Commit 8a507ad

Fix grpo recipe (#2893)
1 parent 4249bd2 commit 8a507ad

File tree

1 file changed: 99 additions, 1 deletion

torchtune/dev/rl/rewards.py

Lines changed: 99 additions & 1 deletion
@@ -7,10 +7,15 @@
 import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 
+from torchtune.modules.transforms.tokenizers import (
+    HuggingFaceModelTokenizer,
+    ModelTokenizer,
+)
+
 
 @dataclass
 class RewardOutput:
@@ -216,3 +221,96 @@ def __call__(
             },
             successes=successes,
         )
+
+
+def at_least_one_space_between_think_tags(
+    cot: str, answer: str, potential_answer: str
+) -> tuple[float, float]:
+    """Did the model at least try to think?"""
+    if len(cot) > 0:
+        return 1.0, 1.0  # (reward, success)
+    else:
+        return 0.0, 0.0
+
+
+def math_response_correct(
+    cot: str, answer: str, potential_answer: str
+) -> tuple[float, float]:
+    """Did it get the right answer?"""
+    import math_verify
+
+    if potential_answer is None:
+        return 0.0, 0.0  # (reward, success)
+    gold = math_verify.parse(answer)
+    attempt = math_verify.parse(potential_answer)
+
+    if math_verify.verify(gold, attempt):
+        return 100.0, 1.0
+    if answer in potential_answer:
+        return 50.0, 0.0
+    if len(potential_answer) > 0:
+        return 1.0, 0.0
+    return 0.0, 0.0
+
+
+def extract_tags(text: str) -> tuple[str, str]:
+    """
+    Parse XML-like <think> and <answer> tags from text. Returns a tuple
+    (cot, potential_answer) with the stripped content of each tag, "" if absent.
+    """
+    think_pattern = r"<think>(.*?)</think>"
+    answer_pattern = r"<answer>(.*?)</answer>"
+    think_match = re.search(think_pattern, text, re.DOTALL)
+    answer_match = re.search(answer_pattern, text, re.DOTALL)
+    cot = think_match.group(1).strip() if think_match else ""
+    potential_answer = answer_match.group(1).strip() if answer_match else ""
+    return cot, potential_answer
+
+
+def batched_rewards(
+    tokenizer: Union[ModelTokenizer, HuggingFaceModelTokenizer],
+    completions: torch.Tensor,
+    answers: list[str],
+    device: torch.device,
+) -> tuple[torch.Tensor, torch.Tensor, dict]:
+
+    reward_funcs = [
+        at_least_one_space_between_think_tags,
+        math_response_correct,
+    ]
+
+    num_reward_funcs = len(reward_funcs)
+
+    batch_size, grpo_size, _ = completions.shape
+
+    # TODO: should this be bfloat16?
+
+    rewards_tensor = torch.zeros(
+        batch_size, grpo_size, num_reward_funcs, dtype=torch.float32, device=device
+    )
+
+    successes_tensor = torch.zeros(
+        batch_size, grpo_size, num_reward_funcs, dtype=torch.float32, device=device
+    )
+
+    metadata = {"func_names": [f.__name__ for f in reward_funcs]}
+
+    for b in range(batch_size):
+
+        for g in range(grpo_size):
+
+            answer = answers[b][g]
+
+            text_completion = tokenizer.decode(completions[b, g].tolist())
+
+            cot, potential_answer = extract_tags(f"<think>{text_completion}")
+
+            for rw_idx, reward_func in enumerate(reward_funcs):
+
+                reward, success = reward_func(cot, answer, potential_answer)
+
+                rewards_tensor[b, g, rw_idx] += reward
+
+                successes_tensor[b, g, rw_idx] += success
+
+    return rewards_tensor, successes_tensor, metadata
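
As a reading aid (not part of the commit), below is a minimal, self-contained sketch of how the new batched_rewards helper could be exercised on its own. The ToyTokenizer stub, the encode helper, and the hard-coded completions and answers are illustrative assumptions; in the GRPO recipe the tokenizer is the real model tokenizer and completions come from generation. The sketch also assumes the math_verify package is installed, since math_response_correct imports it lazily.

# Hypothetical driver for batched_rewards, for illustration only (not part of this commit).
import torch

from torchtune.dev.rl.rewards import batched_rewards


class ToyTokenizer:
    """Stand-in tokenizer for illustration: token IDs are just character codes."""

    def decode(self, token_ids: list[int]) -> str:
        return "".join(chr(t) for t in token_ids)


def encode(text: str, seq_len: int) -> list[int]:
    # Right-pad with spaces so every completion in the batch has the same length.
    ids = [ord(c) for c in text][:seq_len]
    return ids + [ord(" ")] * (seq_len - len(ids))


device = torch.device("cpu")

# batch_size=2 prompts, grpo_size=2 completions per prompt. batched_rewards
# prepends "<think>" before parsing, so the completions here start mid-thought.
completion_texts = [
    ["2 + 2 = 4</think><answer>4</answer>", "</think><answer>5</answer>"],
    ["6 * 7 = 42</think><answer>42</answer>", "a guess</think><answer>41</answer>"],
]
answers = [["4", "4"], ["42", "42"]]

seq_len = max(len(t) for row in completion_texts for t in row)
completions = torch.tensor(
    [[encode(t, seq_len) for t in row] for row in completion_texts], device=device
)

rewards, successes, metadata = batched_rewards(ToyTokenizer(), completions, answers, device)
print(metadata["func_names"])  # ['at_least_one_space_between_think_tags', 'math_response_correct']
print(rewards.shape)           # torch.Size([2, 2, 2]): [batch, group, reward_fn]
print(rewards[0, 0])           # expected tensor([1., 100.]): thought present + correct answer

One detail worth noting from the diff itself: although the annotation reads answers: list[str], the function indexes answers[b][g], so it expects one answer per (batch, group) element; the sketch above follows that indexing.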
