1 change: 1 addition & 0 deletions configs/eval/tofu.yaml
@@ -4,6 +4,7 @@
 defaults: # include all defined metrics files
   - tofu_metrics: # When you import a metric here, its configuration automatically populates the
     # metric key below, enabled by the @package directive at the top of each configuration file.
+    - forget_Truth_Ratio
     - forget_quality
     - forget_Q_A_Prob
     - forget_Q_A_ROUGE
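For context, each name listed under tofu_metrics resolves to its own config file whose contents populate the matching metric key, per the @package directive mentioned in the comment above. A minimal sketch of what the new configs/eval/tofu_metrics/forget_Truth_Ratio.yaml might contain, assuming it mirrors the existing truth-ratio metric config; only the aggregator value is grounded in the Python change below, everything else is illustrative:

# @package directive goes here, as in the other metric config files
# Hypothetical sketch: every key except `aggregator` is an assumption.
handler: truth_ratio
aggregator: prob_mean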
14 changes: 13 additions & 1 deletion src/evals/metrics/memorization.py
@@ -119,10 +119,16 @@ def closer_to_1_better(arr):
 def true_better(arr):
     return np.mean(np.maximum(0, 1 - arr))
 
+# Extent of knowledge (as used in OpenUnlearning paper's meta-evaluation) uses tr=true/(true+false)
+def prob_mean(arr):
+    return np.mean(arr)
+
 if kwargs["aggregator"] == "closer_to_1_better":
     aggregator = closer_to_1_better
 elif kwargs["aggregator"] == "true_better":
     aggregator = true_better
+elif kwargs["aggregator"] == "prob_mean":
+    aggregator = prob_mean
 else:
     raise ValueError(f"Invalid truth ratio aggregator: {kwargs['aggregator']}")
@@ -153,7 +159,13 @@ def true_better(arr):
 correct_prob = np.exp(-correct_avg_losses)
 wrong_prob = np.exp(-wrong_avg_losses)
 
-truth_ratios = wrong_prob / (correct_prob + 1e-10)
+if kwargs["aggregator"] != "prob_mean":
+    # Original definition from TOFU: wrong / correct
+    truth_ratios = wrong_prob / (correct_prob + 1e-10)
+else:
+    # New definition from OpenUnlearning: correct / (correct + wrong)
+    truth_ratios = correct_prob / (correct_prob + wrong_prob + 1e-10)
 
 value_by_index = dict(
     zip(correct_indices, [{"score": val} for val in truth_ratios])
 )
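To see concretely what changes, here is a small standalone sketch of the two truth-ratio definitions and the two aggregator bodies visible in the diff (numpy only; the loss values are made up, and closer_to_1_better is omitted because its body is not shown here):

import numpy as np

# Toy per-sample average losses for the correct answers and the wrong (perturbed) answers.
correct_avg_losses = np.array([0.5, 1.0, 2.0])
wrong_avg_losses = np.array([1.5, 1.0, 0.5])

correct_prob = np.exp(-correct_avg_losses)  # higher when the model assigns lower loss to the correct answer
wrong_prob = np.exp(-wrong_avg_losses)

# Original TOFU definition: wrong / correct. Values below 1 mean the correct
# answer is more likely than the wrong one; the ratio is unbounded above.
tofu_ratios = wrong_prob / (correct_prob + 1e-10)

# OpenUnlearning definition: correct / (correct + wrong). Always in (0, 1),
# which is why a plain mean (prob_mean) is a sensible aggregate for it.
ou_ratios = correct_prob / (correct_prob + wrong_prob + 1e-10)

# Aggregator bodies copied from the diff above.
def true_better(arr):
    return np.mean(np.maximum(0, 1 - arr))

def prob_mean(arr):
    return np.mean(arr)

print("TOFU ratios:", tofu_ratios, "-> true_better:", true_better(tofu_ratios))
print("OpenUnlearning ratios:", ou_ratios, "-> prob_mean:", prob_mean(ou_ratios))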