From f2be85f809111555ad6fc584dbae6845999c96ac Mon Sep 17 00:00:00 2001
From: Hui-design <2225705604@qq.com>
Date: Sun, 23 Feb 2025 17:22:24 +0800
Subject: [PATCH] tch fix_bug

---
 src/open_r1/trainer/grpo_trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/open_r1/trainer/grpo_trainer.py b/src/open_r1/trainer/grpo_trainer.py
index b797cb8..5bf2380 100644
--- a/src/open_r1/trainer/grpo_trainer.py
+++ b/src/open_r1/trainer/grpo_trainer.py
@@ -415,10 +415,10 @@ def get_per_token_logps(model, input_ids, **kwargs):
 
         with torch.inference_mode():
             if self.ref_model is not None:
-                ref_per_token_logps = get_per_token_logps(self.ref_model, prompt_completion_ids)
+                ref_per_token_logps = get_per_token_logps(self.ref_model, prompt_completion_ids, **prompt_inputs)  # tch fix_bug
             else:
                 with self.accelerator.unwrap_model(model).disable_adapter():
-                    ref_per_token_logps = get_per_token_logps(model, prompt_completion_ids)
+                    ref_per_token_logps = get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)  # tch fix_bug
         ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :]
 
         # Compute the KL divergence between the model and the reference model