
Commit d8310a8

[lora] factor out the overlaps in save_lora_weights(). (#12027)
* factor out the overlaps in save_lora_weights().
* remove comment.
* remove comment.
* up
* fix-copies
1 parent 78031c2 commit d8310a8

File tree

2 files changed: 205 additions, 256 deletions

src/diffusers/loaders/lora_base.py

Lines changed: 35 additions & 0 deletions
@@ -1064,6 +1064,41 @@ def save_function(weights, filename):
         save_function(state_dict, save_path)
         logger.info(f"Model weights saved in {save_path}")
 
+    @classmethod
+    def _save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        lora_layers: Dict[str, Dict[str, Union[torch.nn.Module, torch.Tensor]]],
+        lora_metadata: Dict[str, Optional[dict]],
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+    ):
+        """
+        Helper method to pack and save LoRA weights and metadata. This method centralizes the saving logic for all
+        pipeline types.
+        """
+        state_dict = {}
+        final_lora_adapter_metadata = {}
+
+        for prefix, layers in lora_layers.items():
+            state_dict.update(cls.pack_weights(layers, prefix))
+
+        for prefix, metadata in lora_metadata.items():
+            if metadata:
+                final_lora_adapter_metadata.update(_pack_dict_with_prefix(metadata, prefix))
+
+        cls.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+            lora_adapter_metadata=final_lora_adapter_metadata if final_lora_adapter_metadata else None,
+        )
+
     @classmethod
     def _optionally_disable_offloading(cls, _pipeline):
         return _func_optionally_disable_offloading(_pipeline=_pipeline)
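
For context, after this refactor a pipeline-specific save_lora_weights() only needs to assemble the per-component lora_layers and lora_metadata dicts and hand them to the shared helper. The sketch below illustrates that delegation pattern; the class name, the exact keyword arguments, and the "unet"/"text_encoder" prefixes are illustrative assumptions, not the verbatim contents of the second changed file in this commit.

    # Sketch only: a pipeline-level save_lora_weights() delegating to the new
    # _save_lora_weights() helper. Class name and keyword arguments here are
    # illustrative assumptions, not the exact diffusers source.
    import os
    from typing import Callable, Dict, Optional, Union

    import torch

    from diffusers.loaders.lora_base import LoraBaseMixin


    class ExampleLoraLoaderMixin(LoraBaseMixin):
        unet_name = "unet"
        text_encoder_name = "text_encoder"

        @classmethod
        def save_lora_weights(
            cls,
            save_directory: Union[str, os.PathLike],
            unet_lora_layers: Dict[str, torch.Tensor] = None,
            text_encoder_lora_layers: Dict[str, torch.Tensor] = None,
            unet_lora_adapter_metadata: Optional[dict] = None,
            text_encoder_lora_adapter_metadata: Optional[dict] = None,
            is_main_process: bool = True,
            weight_name: str = None,
            save_function: Callable = None,
            safe_serialization: bool = True,
        ):
            lora_layers = {}
            lora_metadata = {}

            # Collect whichever components were passed, keyed by their state-dict prefix.
            if unet_lora_layers:
                lora_layers[cls.unet_name] = unet_lora_layers
                lora_metadata[cls.unet_name] = unet_lora_adapter_metadata
            if text_encoder_lora_layers:
                lora_layers[cls.text_encoder_name] = text_encoder_lora_layers
                lora_metadata[cls.text_encoder_name] = text_encoder_lora_adapter_metadata

            if not lora_layers:
                raise ValueError("You must pass at least one set of LoRA layers.")

            # Packing, metadata prefixing, and serialization are now centralized
            # in the helper added by this commit.
            cls._save_lora_weights(
                save_directory=save_directory,
                lora_layers=lora_layers,
                lora_metadata=lora_metadata,
                is_main_process=is_main_process,
                weight_name=weight_name,
                save_function=save_function,
                safe_serialization=safe_serialization,
            )

Calling ExampleLoraLoaderMixin.save_lora_weights(save_directory="./lora-out", unet_lora_layers={...}) (the directory name and layer dict being placeholders) would then pack the layers under the "unet" prefix and write a single weights file via write_lora_layers(), which is exactly the flow the helper in the diff above centralizes.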
