|
7 | 7 | import re
|
8 | 8 | from abc import ABC, abstractmethod
|
9 | 9 | from dataclasses import dataclass, field
|
10 |
| -from typing import Optional |
| 10 | +from typing import Optional, Union |
11 | 11 |
|
12 | 12 | import torch
|
13 | 13 |
|
| 14 | +from torchtune.modules.transforms.tokenizers import ( |
| 15 | + HuggingFaceModelTokenizer, |
| 16 | + ModelTokenizer, |
| 17 | +) |
| 18 | + |
14 | 19 |
|
15 | 20 | @dataclass
|
16 | 21 | class RewardOutput:
|
@@ -216,3 +221,96 @@ def __call__(
|
216 | 221 | },
|
217 | 222 | successes=successes,
|
218 | 223 | )
|
| 224 | + |
| 225 | + |
def at_least_one_space_between_think_tags(
    cot: str, answer: str, potential_answer: str
) -> tuple[float, float]:
    """Reward the model for producing *any* chain-of-thought at all.

    Args:
        cot: text extracted from between the ``<think>`` tags.
        answer: gold answer (unused; kept for the shared reward-func signature).
        potential_answer: model's extracted answer (unused; same reason).

    Returns:
        ``(reward, success)`` — ``(1.0, 1.0)`` when ``cot`` is non-empty,
        ``(0.0, 0.0)`` otherwise.
    """
    # Non-empty chain-of-thought means the model at least tried to think.
    thought = len(cot) > 0
    return (1.0, 1.0) if thought else (0.0, 0.0)
| 234 | + |
| 235 | + |
def math_response_correct(
    cot: str, answer: str, potential_answer: str
) -> tuple[float, float]:
    """Score a math answer with graded partial credit.

    Args:
        cot: extracted chain-of-thought (unused; kept for the shared signature).
        answer: gold answer string.
        potential_answer: answer extracted from the model's completion, or
            ``None`` when nothing could be extracted.

    Returns:
        ``(reward, success)``: 100.0/1.0 for a verified match, 50.0/0.0 when
        the gold answer appears as a substring, 1.0/0.0 for any non-empty
        attempt, 0.0/0.0 otherwise.
    """
    # Imported lazily so the module loads even without math_verify installed.
    import math_verify

    if potential_answer is None:
        return 0.0, 0.0  # (reward, success)

    parsed_gold = math_verify.parse(answer)
    parsed_attempt = math_verify.parse(potential_answer)
    verified = math_verify.verify(parsed_gold, parsed_attempt)

    if verified:
        return 100.0, 1.0
    if answer in potential_answer:
        # Partial credit: right answer buried in extra text.
        return 50.0, 0.0
    if len(potential_answer) > 0:
        # Tiny reward for producing any answer at all.
        return 1.0, 0.0
    return 0.0, 0.0
| 254 | + |
| 255 | + |
def extract_tags(text: str) -> tuple[str, str]:
    """Extract ``<think>`` and ``<answer>`` tag contents from *text*.

    Only the first occurrence of each tag is used; matching is non-greedy
    and spans newlines (``re.DOTALL``).

    Args:
        text: completion text that may contain ``<think>...</think>`` and
            ``<answer>...</answer>`` sections.

    Returns:
        A ``(cot, potential_answer)`` tuple of stripped tag contents; either
        element is ``""`` when the corresponding tag is absent.
    """
    # NOTE: docstring previously claimed a dict return; the function has
    # always returned a tuple — callers unpack `cot, potential_answer`.
    think_pattern = r"<think>(.*?)</think>"
    answer_pattern = r"<answer>(.*?)</answer>"
    think_match = re.search(think_pattern, text, re.DOTALL)
    answer_match = re.search(answer_pattern, text, re.DOTALL)
    cot = think_match.group(1).strip() if think_match else ""
    potential_answer = answer_match.group(1).strip() if answer_match else ""
    return cot, potential_answer
| 268 | + |
| 269 | + |
def batched_rewards(
    tokenizer: Union[ModelTokenizer, HuggingFaceModelTokenizer],
    completions: torch.Tensor,
    answers: list[list[str]],
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, dict]:
    """Compute per-function rewards for a batch of GRPO completions.

    Args:
        tokenizer: tokenizer used to decode token-id completions back to text.
        completions: tensor of shape ``(batch_size, grpo_size, seq_len)``.
        answers: gold answers, indexed as ``answers[b][g]`` — one per
            completion in each GRPO group. (Was annotated ``list[str]``,
            which contradicted the nested indexing below.)
        device: device on which to allocate the output tensors.

    Returns:
        ``(rewards, successes, metadata)`` where the tensors have shape
        ``(batch_size, grpo_size, num_reward_funcs)`` and ``metadata`` maps
        ``"func_names"`` to the reward functions' names.
    """
    reward_funcs = [
        at_least_one_space_between_think_tags,
        math_response_correct,
    ]
    num_reward_funcs = len(reward_funcs)
    batch_size, grpo_size, _ = completions.shape

    # TODO: should this be bfloat16?
    rewards_tensor = torch.zeros(
        batch_size, grpo_size, num_reward_funcs, dtype=torch.float32, device=device
    )
    successes_tensor = torch.zeros(
        batch_size, grpo_size, num_reward_funcs, dtype=torch.float32, device=device
    )
    metadata = {"func_names": [f.__name__ for f in reward_funcs]}

    for b in range(batch_size):
        for g in range(grpo_size):
            answer = answers[b][g]
            text_completion = tokenizer.decode(completions[b, g].tolist())
            # "<think>" is prepended because completions start mid-tag —
            # presumably the prompt ends with "<think>"; verify against caller.
            cot, potential_answer = extract_tags(f"<think>{text_completion}")
            for rw_idx, reward_func in enumerate(reward_funcs):
                reward, success = reward_func(cot, answer, potential_answer)
                # Plain assignment: each (b, g, rw_idx) cell is written once
                # into a zero-initialized tensor, so "+=" was redundant.
                rewards_tensor[b, g, rw_idx] = reward
                successes_tensor[b, g, rw_idx] = success

    return rewards_tensor, successes_tensor, metadata
0 commit comments