Commit d77d409

Birdylx and liuxiao.28 authored
[bugfix] stop gradient for chord phi function (#6952)
Co-authored-by: liuxiao.28 <[email protected]>
1 parent 913883f commit d77d409

1 file changed, +1 -1 lines changed

swift/trainers/rlhf_trainer/utils.py

Lines changed: 1 addition & 1 deletion
@@ -1017,7 +1017,7 @@ def compute_chord_loss(trainer, grpo_loss: torch.Tensor) -> torch.Tensor:
     chord_sft_loss = per_token_loss_func(outputs, labels)
 
     if trainer.args.chord_enable_phi_function:
-        per_token_probs = torch.exp(-chord_sft_loss)
+        per_token_probs = torch.exp(-chord_sft_loss.detach())
         phi = per_token_probs * (1 - per_token_probs)
         chord_sft_loss *= phi
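
One way to read this change: with p = exp(-loss) and phi = p * (1 - p), the weighted loss is loss * phi. If phi is not detached, the product rule adds a second term, loss * d(phi)/d(loss), to the gradient. The sketch below is not code from the repository. It uses plain Python scalars instead of torch autograd, and the helper names (phi_weight, grad_with_detach, grad_without_detach, finite_difference) are hypothetical, chosen only to illustrate the difference between the two variants.

```python
import math


def phi_weight(loss: float) -> float:
    """phi = p * (1 - p), where p = exp(-loss) is the token probability."""
    p = math.exp(-loss)
    return p * (1.0 - p)


def grad_with_detach(loss: float) -> float:
    # phi is treated as a constant weight, so d(loss * phi)/d(loss) = phi.
    return phi_weight(loss)


def grad_without_detach(loss: float) -> float:
    # phi also depends on loss: d(phi)/d(loss) = -p * (1 - 2p).
    # The product rule therefore adds loss * d(phi)/d(loss) to the gradient.
    p = math.exp(-loss)
    return phi_weight(loss) + loss * (-p * (1.0 - 2.0 * p))


def finite_difference(loss: float, eps: float = 1e-6) -> float:
    # Central-difference derivative of loss * phi(loss), used to check
    # the closed form in grad_without_detach.
    def weighted(x: float) -> float:
        return x * phi_weight(x)

    return (weighted(loss + eps) - weighted(loss - eps)) / (2.0 * eps)


if __name__ == "__main__":
    for loss in (0.1, 0.7, 2.0):
        print(loss, grad_with_detach(loss), grad_without_detach(loss))
```

Under these assumptions, the detached gradient is always non-negative, so phi only rescales each token's SFT gradient. Without the detach, the extra term can outweigh phi for high-loss tokens (for example at loss = 2.0) and flip the sign of the gradient with respect to the per-token loss.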
