diff --git a/mmgpt/train/instruction_finetune.py b/mmgpt/train/instruction_finetune.py index 4ece1bf..f429288 100644 --- a/mmgpt/train/instruction_finetune.py +++ b/mmgpt/train/instruction_finetune.py @@ -311,7 +311,7 @@ def apply_decay(x): "model_state_dict": get_checkpoint(ddp_model), "optimizer_state_dict": optimizer.state_dict(), "lr_scheduler_state_dict": lr_scheduler.state_dict(), - "tuning_config": tuning_config, + "tuning_config": tuning_config.tuning_config, } print(f"Saving checkpoint to {args.run_name}/checkpoint_{epoch}.pt") @@ -324,7 +324,7 @@ def apply_decay(x): os.remove(f"{args.run_name}/checkpoint_{epoch-1}.pt") if args.rank == 0: torch.save( - {"model_state_dict": get_checkpoint(ddp_model.module), "tuning_config": tuning_config}, + {"model_state_dict": get_checkpoint(ddp_model.module), "tuning_config": tuning_config.tuning_config}, f"{args.run_name}/final_weights.pt", ) if args.report_to_wandb and args.save_checkpoints_to_wandb: