PDU, Llama 2 13B, MUSE finetuning #121
@@ -0,0 +1,47 @@
# Constrained Entropic Unlearning: A Primal-Dual Framework for Large Language Models

We propose a new formulation of LLM unlearning as a constrained optimization problem: forgetting is enforced via a novel logit-margin flattening loss that explicitly drives the output distribution toward uniformity on a designated forget set, while retention is preserved through a hard constraint on a separate retain set. We solve the constrained problem with a scalable primal-dual algorithm that exposes the trade-off between forgetting and retention through the dynamics of the dual variable.

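Constrained problems of this kind are typically handled by alternating a primal (model) update on the Lagrangian with a projected ascent step on the dual variable. The sketch below is only a generic illustration of that recipe using the names that appear in this PR's configs (`retain_loss_eps`, `dual_step_size`, and `pref`/`alpha` from `run.sh`, which the Setup section below describes as the initial retain-loss preference); the actual losses and update schedule are defined in the paper and in `src/train.py` and may differ in detail.

```python
import torch


def primal_dual_step(model, optimizer, forget_batch, retain_batch,
                     forget_loss_fn, retain_loss_fn,
                     lam, retain_loss_eps, dual_step_size):
    """One illustrative primal-dual step for: min_theta forget_loss
    subject to retain_loss <= retain_loss_eps (Lagrangian relaxation)."""
    optimizer.zero_grad()
    forget_loss = forget_loss_fn(model, forget_batch)   # pushes forget-set logits toward a flat (uniform) distribution
    retain_loss = retain_loss_fn(model, retain_batch)   # standard LM loss on the retain set
    lagrangian = forget_loss + lam * (retain_loss - retain_loss_eps)
    lagrangian.backward()
    optimizer.step()                                     # primal update of the model parameters

    # Dual ascent on the retain constraint, projected onto lam >= 0:
    # lam grows while the constraint is violated and shrinks once it is satisfied.
    with torch.no_grad():
        lam = max(0.0, lam + dual_step_size * (retain_loss.item() - retain_loss_eps))
    return lam
```

The dynamics of `lam` then directly expose the forgetting/retention trade-off mentioned above: a persistently large dual value signals that the retain constraint is hard to satisfy at the current forgetting strength.
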
# Setup

Experimental setup

- **Hyperparameters & Search Space:** Please see the corresponding [paper](https://arxiv.org/abs/2506.05314) for details of the hyperparameters. Importantly, to obtain good results with our method, it is vital that the hyperparameter `retain_loss_eps` is set to an appropriate value. To choose one, look at the retain loss of the pretrained model and pick a value appropriately larger than this starting value (see the sketch after this list).

  Note that our method's loss is a quadratic function of a difference in logit space, so its value can be large. It is therefore natural to set the initial retain-loss preference parameter to 50 or 100.
- **Computational Setup:** Please see the Supplementary Material in the paper.

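As a concrete illustration of the selection rule above, the following sketch estimates the retain loss of an unmodified target model. It assumes the public TOFU dataset (`locuslab/TOFU`, `retain90` config with `question`/`answer` fields) and the plain `Question:`/`Answer:` format used by this repo's non-chat template; the repo's own data pipeline may tokenize and mask labels differently, so treat the number only as a rough baseline for `retain_loss_eps`.

```python
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "open-unlearning/tofu_Llama-3.2-1B-Instruct_full"  # one of the target models from run.sh
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device).eval()

# A small sample of the retain split is enough for a rough baseline.
retain = load_dataset("locuslab/TOFU", "retain90", split="train").select(range(64))

losses = []
with torch.no_grad():
    for ex in retain:
        text = f"Question: {ex['question']}\nAnswer: {ex['answer']}\n\n"
        ids = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).input_ids.to(device)
        out = model(input_ids=ids, labels=ids)  # labels=input_ids -> mean next-token cross-entropy
        losses.append(out.loss.item())

baseline = sum(losses) / len(losses)
print(f"pretrained retain loss ~ {baseline:.3f}")
# Pick retain_loss_eps somewhat above this baseline (the exact slack is a tuning choice).
```
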
# Results

Please see the `run.sh` script, which contains all commands needed to reproduce the final results.

All unlearned models are available at https://huggingface.co/tamarsonha.

# Citation

If you use this work, please cite:

```bibtex
@article{entesari2025constrained,
  title={Constrained Entropic Unlearning: A Primal-Dual Framework for Large Language Models},
  author={Entesari, Taha and Hatami, Arman and Khaziev, Rinat and Ramakrishna, Anil and Fazlyab, Mahyar},
  journal={arXiv preprint arXiv:2506.05314},
  year={2025}
}
```

@@ -0,0 +1,130 @@
#!/bin/bash

########################################################################################################################
########################################### Final best parameters #####################################################
########################################################################################################################
# for an 8 GPU system:
num_processes=8

############################################## TOFU #####################################################
per_device_train_batch_size=4
learning_rate=0.00001
dual_warmup_epochs=5

pref=100
dual_step_size=5
retain_loss_eps=0.3

retain_percentages=(90 95 99)
models=(Llama-3.2-1B-Instruct Llama-3.2-3B-Instruct Llama-3.1-8B-Instruct gemma-7b-it)

for model in "${models[@]}"; do
  for retain_percentage in "${retain_percentages[@]}"; do

    if [ "$retain_percentage" = "90" ]; then
      forget_split=forget10
      retain_split=retain90
    elif [ "$retain_percentage" = "95" ]; then
      forget_split=forget05
      retain_split=retain95
    elif [ "$retain_percentage" = "99" ]; then
      forget_split=forget01
      retain_split=retain99
    else
      echo "Invalid retain percentage. Please set it to 90, 95, or 99."
      exit 1
    fi

| if [ "$model" = "Llama-3.2-1B-Instruct" ]; then | ||
| pretrained_model_name_or_path=open-unlearning/tofu_Llama-3.2-1B-Instruct_full | ||
| num_train_epochs=10 | ||
| elif [ "$model" = "Llama-3.2-3B-Instruct" ]; then | ||
| pretrained_model_name_or_path=open-unlearning/tofu_Llama-3.2-3B-Instruct_full | ||
| num_train_epochs=10 | ||
| elif [ "$model" = "Llama-3.1-8B-Instruct" ]; then | ||
| pretrained_model_name_or_path=open-unlearning/tofu_Llama-3.1-8B-Instruct_full | ||
| num_train_epochs=30 | ||
| elif [ "$model" = "gemma-7b-it" ]; then | ||
| pretrained_model_name_or_path=tamarsonha/TOFU-target-gemma-7b-it | ||
| num_train_epochs=20 | ||
| else | ||
| echo "Invalid model name. Please set it to Llama-3.2-1B-Instruct, Llama-3.2-3B-Instruct, Llama-3.1-8B-Instruct, or gemma-7b-it." | ||
| exit 1 | ||
| fi | ||
|
|
||
| task_name=PDU-TOFU$retain_split-E$num_train_epochs-lr$learning_rate-P1-$pref-Primal$retain_loss_eps-Step$dual_step_size-Warmup$dual_warmup_epochs-model_$model | ||
| accelerate launch --config_file configs/accelerate/default_config.yaml --num_processes=$num_processes \ | ||
| src/train.py --config-name=unlearn.yaml experiment=unlearn/tofu/default \ | ||
| forget_split=$forget_split retain_split=$retain_split\ | ||
| trainer=PDU\ | ||
| trainer.args.num_train_epochs=$num_train_epochs\ | ||
| trainer.args.eval_on_start=false trainer.args.do_eval=false\ | ||
| trainer.args.per_device_train_batch_size=$per_device_train_batch_size\ | ||
| trainer.args.learning_rate=$learning_rate\ | ||
| trainer.method_args.gamma=1. trainer.method_args.alpha=$pref\ | ||
| trainer.method_args.primal_dual=true trainer.method_args.retain_loss_eps=$retain_loss_eps\ | ||
| trainer.method_args.dual_step_size=$dual_step_size\ | ||
| trainer.method_args.dual_update_upon="step" trainer.method_args.dual_warmup_epochs=$dual_warmup_epochs\ | ||
| task_name=$task_name\ | ||
| model=$model model.model_args.pretrained_model_name_or_path=$pretrained_model_name_or_path | ||
| done | ||
| done | ||

######################################################## MUSE #########################################################
dual_step_size=1
num_train_epochs=10
dual_warmup_epochs=3
data_splits=("News" "Books")
learning_rate=0.00001
dual_update_upon="step"

models=(Llama-2-7b-hf Llama-2-13b-hf)
pref=50

for model in "${models[@]}"; do
  for data_split in "${data_splits[@]}"; do

    if [ "$model" = "Llama-2-7b-hf" ]; then
      pretrained_model_name_or_path=muse-bench/MUSE-${data_split}_target
      epsNews=(1.5)
      epsBooks=(0.1)
    elif [ "$model" = "Llama-2-13b-hf" ]; then
      pretrained_model_name_or_path=tamarsonha/MUSE-${data_split}-target-Llama-2-13b-hf
      epsNews=(0.8)
      epsBooks=(0.6)
    else
      echo "Invalid model name. Please set it to Llama-2-7b-hf or Llama-2-13b-hf."
      exit 1
    fi

    if [ "$data_split" == "News" ]; then
      eps_array=("${epsNews[@]}")
    else
      eps_array=("${epsBooks[@]}")
    fi

    for retain_loss_eps in "${eps_array[@]}"; do
      task_name=PDU-Muse$data_split-E$num_train_epochs-lr$learning_rate-P1-$pref-Primal$retain_loss_eps-Step$dual_step_size-Warmup$dual_warmup_epochs-model$model
      accelerate launch --config_file configs/accelerate/default_config.yaml --num_processes=$num_processes \
        src/train.py --config-name=unlearn.yaml experiment=unlearn/muse/default \
        data_split=$data_split \
        trainer=PDU \
        trainer.args.num_train_epochs=$num_train_epochs \
        trainer.args.eval_on_start=false trainer.args.do_eval=false \
        trainer.args.per_device_train_batch_size=$per_device_train_batch_size \
        trainer.args.learning_rate=$learning_rate \
        trainer.method_args.gamma=1. trainer.method_args.alpha=$pref \
        trainer.method_args.primal_dual=true trainer.method_args.retain_loss_eps=$retain_loss_eps \
        trainer.method_args.dual_step_size=$dual_step_size \
        trainer.method_args.dual_update_upon=$dual_update_upon trainer.method_args.dual_warmup_epochs=$dual_warmup_epochs \
        task_name=$task_name \
        model=$model model.model_args.pretrained_model_name_or_path=$pretrained_model_name_or_path
    done
  done
done

@@ -0,0 +1,8 @@
MUSE_train:
  handler: PretrainingDataset
  args:
    hf_args:
      path: "tamarsonha/MUSE-News-Train"
      split: "full"
    text_key: "text"
    max_length: 2048
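The `MUSE_train` entry above simply points the `PretrainingDataset` handler at a raw-text Hugging Face dataset. A quick, optional sanity check that the configured path, split, and text field line up (assuming the dataset is public and laid out as the config suggests):

```python
from datasets import load_dataset

# Mirrors the hf_args / text_key in the MUSE_train config above.
ds = load_dataset("tamarsonha/MUSE-News-Train", split="full")
print(len(ds), ds.column_names)   # expect a "text" column
print(ds[0]["text"][:200])        # peek at the first document
```
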
@@ -0,0 +1,30 @@
# @package _global_

defaults:
  - override /model: Llama-2-13b-hf
  - override /trainer: finetune
  - override /data/[email protected]: MUSE_train
  - override /eval: muse
  - override /data: finetune

mode: finetune
data_split: News
data_sub_set: full # full or retain

data:
  train:
    MUSE_train:
      args:
        hf_args:
          path: tamarsonha/MUSE-${data_split}-Train
          split: ${data_sub_set}
          # you can find fine-tuned models on https://huggingface.co/tamarsonha

trainer:
  args:
    learning_rate: 1e-5
    weight_decay: 0.01
    warmup_epochs: 1.0 # custom parameter
    num_train_epochs: 10

task_name: muse_news_full
@@ -0,0 +1,12 @@
model_args:
  pretrained_model_name_or_path: "meta-llama/Llama-2-13b-hf"
  attn_implementation: 'flash_attention_2'
  torch_dtype: bfloat16
tokenizer_args:
  pretrained_model_name_or_path: "meta-llama/Llama-2-13b-hf"
template_args: # Used in creating prompts for the dataset. See src/data/utils.py#preprocess_chat_instance.
  apply_chat_template: False
  user_start_tag: "Question: "
  user_end_tag: "\n"
  asst_start_tag: "Answer: "
  asst_end_tag: "\n\n"
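Since `apply_chat_template` is disabled here, prompts are built from the plain-text tags above rather than the Llama-2 chat format. A hypothetical QA pair would therefore be rendered roughly as follows (a sketch of the intended format; the authoritative logic is `src/data/utils.py#preprocess_chat_instance`):

```python
# Tags copied from the template_args above.
user_start_tag, user_end_tag = "Question: ", "\n"
asst_start_tag, asst_end_tag = "Answer: ", "\n\n"

# Hypothetical example pair, for illustration only.
question = "Who wrote the novel?"
answer = "It was written by the author in question."

prompt = f"{user_start_tag}{question}{user_end_tag}{asst_start_tag}"
full_text = prompt + answer + asst_end_tag
print(repr(full_text))
# -> 'Question: Who wrote the novel?\nAnswer: It was written by the author in question.\n\n'
```
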
@@ -0,0 +1,14 @@
defaults:
  - GradDiff

handler: PDU
method_args:
  retain_loss_eps: ???
  primal_dual: True
  dual_step_size: 1.0
  dual_update_upon: "step" # "step" or "epoch"
  dual_warmup_epochs: 5
  gamma: 1.0
  alpha: 1.0
  loss_names: ["forget_loss", "retain_loss"]
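For intuition on how `dual_warmup_epochs` and `dual_update_upon` above might interact with the update sketched in the README, a common pattern is to freeze the dual variable during the warmup epochs and only then apply ascent, either every optimizer step or once per epoch. This is a hypothetical gating sketch, not a description of the PDU handler's actual code:

```python
def maybe_update_dual(lam, retain_loss, eps, step_size,
                      epoch, dual_warmup_epochs, is_epoch_end,
                      dual_update_upon="step"):
    """Hypothetical gating of the dual update (see the PDU handler for the real logic)."""
    if epoch < dual_warmup_epochs:
        return lam                          # warmup: keep the dual variable fixed
    if dual_update_upon == "epoch" and not is_epoch_end:
        return lam                          # only update at epoch boundaries
    return max(0.0, lam + step_size * (retain_loss - eps))
```
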
Collaborator: don't think we need to add new arguments in default training args. You can add new arguments using …

Contributor (Author): I am aware of the …

Collaborator: Can we revert this? If people set …
@@ -21,4 +21,4 @@ args:
   eval_on_start: True
   eval_strategy: epoch
   num_train_epochs: 10
-  seed: 0
+  seed: 0
Collaborator: seems to me this should be a metric. @Dornavineeth thoughts?
@@ -270,3 +270,4 @@ simple_evaluate_args:
   system_instruction: null
   apply_chat_template: false
 ```
Collaborator: depends on resolution of other comments