@@ -461,6 +461,187 @@ def prepare_preference_batch(
         )
 
         return batch, is_bad_batch, reward_output
+
+class AceMathVerifierHandler(VLLMOutputRewardHandler):
+    def __init__(self):
+        pass
+
+    @override
+    def create(self, reward_model, reward_name, reward_config, gangs, context):
+        if reward_config.tokenizer is not None:
+            tokenizer = reward_config.tokenizer
+        else:
+            tokenizer = "nvidia/AceMath-7B-RM"
+
+        return AceMathVerifier(
+            gangs,
+            context,
+            reward_model,
+            reward_name=reward_name,
+            answer_key=reward_config.answer_key,
+            prompt_key=reward_config.prompt_key,
+            tokenizer=tokenizer,
+        )
+
+    @property
+    @override
+    def name(self):
+        return "acemath_verifier"
+
+    @property
+    @override
+    def config_kls(self):
+        return None
+
+class AceMathVerifier(VLLMOutputReward):
+    def __init__(
+        self,
+        gangs,
+        context,
+        reward_model,
+        reward_name,
+        answer_key,
+        prompt_key,
+        tokenizer,
+    ):
+        self.answer_key = answer_key
+        self.prompt_key = prompt_key
+        self._gangs = gangs
+        self._context = context
+        self.reward_model = reward_model
+        self.reward_name = reward_name
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+
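+    # Renders a (prompt, rollout) pair as a system/user/assistant conversation using
+    # the reward model tokenizer's chat template; a leading BOS token, if present,
+    # is stripped so it is not duplicated when the reward model tokenizes the string
+    # again.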
+    def wrap_text(self, prompt_text, rollout_text):
+        wrapped_text = [
+            {"role": "system", "content": "Please reason step by step, and check your final answer within \\boxed{}."},
+            {"role": "user", "content": prompt_text},
+            {"role": "assistant", "content": rollout_text},
+        ]
+        chat_str = self.tokenizer.apply_chat_template(
+            wrapped_text, tokenize=False, add_generation_prompt=False
+        )
+        if self.tokenizer.bos_token is not None and chat_str.startswith(
+            self.tokenizer.bos_token
+        ):
+            chat_str = chat_str[len(self.tokenizer.bos_token) :]
+
+        return chat_str
+
+    @override
+    def process_rollouts(
+        self, vllm_outputs: list[RequestOutput], prompt_batch: PromptBatch
+    ):
+        vllm_inputs = []
+        batch_text = []
+        batch_tokens = []
+
+        if vllm_outputs is None:
+            vllm_outputs = [None] * len(prompt_batch.prompts)
+
+        text_prompts = prompt_batch.meta_info.get(self.prompt_key)
+        for i, (i_batch_request_output, prompt_text) in enumerate(
+            zip(vllm_outputs, text_prompts)
+        ):
+            rollouts_text = []
+            rollouts_tokens = []
+            for rollout_output in i_batch_request_output.outputs:
+                rollout_text = rollout_output.text
+                vllm_input = self.wrap_text(prompt_text, rollout_text)
+                vllm_inputs.append(vllm_input)
+                rollouts_text.append(rollout_output.text)
+                rollouts_tokens.append(rollout_output.token_ids)
+
+            batch_text.append(rollouts_text)
+            batch_tokens.append(rollouts_tokens)
+
+        batch_rewards = generate_rewards(
+            vllm_inputs, dp_gang=self._gangs.dp, vllm_model=self.reward_model
+        )
+
+        log.info(f"Batch rewards = {batch_rewards}")
+
+        # reshape batch_rewards to [Batch, Rollouts]
+        B, R = len(batch_text), len(batch_text[0])  # batch size, rollouts
+        batch_rewards = [batch_rewards[i * R : (i + 1) * R] for i in range(B)]
+
+        return {"text": batch_text, "tokens": batch_tokens, "rewards": batch_rewards}
+
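+    # Builds a PreferenceBatch of (chosen, rejected) token sequences from the scored
+    # rollouts; prompts where chosen and rejected coincide are dropped, and if no
+    # prompt yields a valid pair the whole batch is flagged so its loss can be zeroed.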
+    def prepare_preference_batch(
+        self, prompt_batch: PromptBatch, rollouts
+    ) -> PreferenceBatch:
+        reward_output = self.process_rollouts(rollouts, prompt_batch)
+
+        chosen_batch = []
+        rejected_batch = []
+        prompt_lens = []
+        dummy_batch_ids = []  # keep positions of dummy pairs here
+
+        # pick the rollout with the highest reward as chosen and the lowest as
+        # rejected; index() returns the first occurrence on ties
+        for i_batch, (i_batch_rewards, i_batch_tokens) in enumerate(
+            zip(reward_output["rewards"], reward_output["tokens"])
+        ):
+            chosen_rollout_position = i_batch_rewards.index(max(i_batch_rewards))
+            rejected_rollout_position = i_batch_rewards.index(min(i_batch_rewards))
+
+            if chosen_rollout_position == rejected_rollout_position:
+                # can't form a preference pair without distinct rollouts;
+                # this becomes a dummy pair and its loss is zeroed out
+                dummy_batch_ids.append(i_batch)
+
+            chosen_rollout_tokens = list(i_batch_tokens[chosen_rollout_position])
+            rejected_rollout_tokens = list(i_batch_tokens[rejected_rollout_position])
+            prompt_tokens = prompt_batch.prompts[i_batch]
+
+            chosen_tokens = prompt_tokens + chosen_rollout_tokens
+            chosen_batch.append(chosen_tokens)
+
+            rejected_tokens = prompt_tokens + rejected_rollout_tokens
+            rejected_batch.append(rejected_tokens)
+
+            prompt_lens.append(len(prompt_tokens))
+
+        filter_batch = lambda batch: [
+            item for index, item in enumerate(batch) if index not in dummy_batch_ids
+        ]
+
+        if len(dummy_batch_ids) == len(reward_output["tokens"]):
+            # the entire batch has no valid preference pair;
+            # keep it as a dummy batch and zero the loss in the end
+            is_bad_batch = True
+        else:
+            # remove dummy pairs from the batch
+            chosen_batch = filter_batch(chosen_batch)
+            rejected_batch = filter_batch(rejected_batch)
+            prompt_lens = filter_batch(prompt_lens)
+            is_bad_batch = False
+
+        prompt_lens = torch.tensor(prompt_lens)
+
+        chosen_batch = [
+            torch.tensor(sequence, device=self._gangs.dp.device)
+            for sequence in chosen_batch
+        ]
+        chosen_batch = collate_with_target_mask(
+            chosen_batch, prompt_lens, device=self._gangs.dp.device
+        )
+
+        rejected_batch = [
+            torch.tensor(sequence, device=self._gangs.dp.device)
+            for sequence in rejected_batch
+        ]
+        rejected_batch = collate_with_target_mask(
+            rejected_batch, prompt_lens, device=self._gangs.dp.device
+        )
+
+        batch = PreferenceBatch(
+            chosen=chosen_batch,
+            rejected=rejected_batch,
+            reference_score_chosen=None,
+            reference_score_rejected=None,
+        )
+
+        return batch, is_bad_batch, reward_output
 
 
 class AtheneVerifierHandler(VLLMOutputRewardHandler):