@@ -1064,6 +1064,41 @@ def save_function(weights, filename):
1064
1064
save_function (state_dict , save_path )
1065
1065
logger .info (f"Model weights saved in { save_path } " )
1066
1066
1067
+ @classmethod
1068
+ def _save_lora_weights (
1069
+ cls ,
1070
+ save_directory : Union [str , os .PathLike ],
1071
+ lora_layers : Dict [str , Dict [str , Union [torch .nn .Module , torch .Tensor ]]],
1072
+ lora_metadata : Dict [str , Optional [dict ]],
1073
+ is_main_process : bool = True ,
1074
+ weight_name : str = None ,
1075
+ save_function : Callable = None ,
1076
+ safe_serialization : bool = True ,
1077
+ ):
1078
+ """
1079
+ Helper method to pack and save LoRA weights and metadata. This method centralizes the saving logic for all
1080
+ pipeline types.
1081
+ """
1082
+ state_dict = {}
1083
+ final_lora_adapter_metadata = {}
1084
+
1085
+ for prefix , layers in lora_layers .items ():
1086
+ state_dict .update (cls .pack_weights (layers , prefix ))
1087
+
1088
+ for prefix , metadata in lora_metadata .items ():
1089
+ if metadata :
1090
+ final_lora_adapter_metadata .update (_pack_dict_with_prefix (metadata , prefix ))
1091
+
1092
+ cls .write_lora_layers (
1093
+ state_dict = state_dict ,
1094
+ save_directory = save_directory ,
1095
+ is_main_process = is_main_process ,
1096
+ weight_name = weight_name ,
1097
+ save_function = save_function ,
1098
+ safe_serialization = safe_serialization ,
1099
+ lora_adapter_metadata = final_lora_adapter_metadata if final_lora_adapter_metadata else None ,
1100
+ )
1101
+
1067
1102
@classmethod
1068
1103
def _optionally_disable_offloading (cls , _pipeline ):
1069
1104
return _func_optionally_disable_offloading (_pipeline = _pipeline )
0 commit comments