2 changes: 2 additions & 0 deletions setup.py
@@ -146,6 +146,7 @@
"phonemizer",
"opencv-python",
"timm",
"flashpack",
]

# this is a lookup table with items like:
@@ -248,6 +249,7 @@ def run(self):
extras["optimum_quanto"] = deps_list("optimum_quanto", "accelerate")
extras["torchao"] = deps_list("torchao", "accelerate")
extras["nvidia_modelopt"] = deps_list("nvidia_modelopt[hf]")
extras["flashpack"] = deps_list("flashpack")

if os.name == "nt": # windows
extras["flax"] = [] # jax is not supported on windows
1 change: 1 addition & 0 deletions src/diffusers/dependency_versions_table.py
@@ -53,4 +53,5 @@
"phonemizer": "phonemizer",
"opencv-python": "opencv-python",
"timm": "timm",
"flashpack": "flashpack",
}
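With the dependency entries above in place, `flashpack` becomes an optional extra of diffusers. A minimal availability sketch follows; `is_flashpack_available` is the helper imported from `diffusers.utils` later in this diff, and the `importlib` fallback is only an illustrative stand-in for environments without this PR, not the library's actual implementation:

```python
# Availability sketch. The helper is assumed to live in diffusers.utils (as the
# imports added in this PR suggest); the fallback mimics it with importlib.
import importlib.util

try:
    from diffusers.utils import is_flashpack_available
except ImportError:
    def is_flashpack_available() -> bool:
        # True when the `flashpack` package (e.g. installed via
        # `pip install diffusers[flashpack]`) can be found.
        return importlib.util.find_spec("flashpack") is not None

print("flashpack available:", is_flashpack_available())
```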
31 changes: 31 additions & 0 deletions src/diffusers/models/model_loading_utils.py
@@ -33,6 +33,7 @@
from ..quantizers import DiffusersQuantizer
from ..utils import (
DEFAULT_HF_PARALLEL_LOADING_WORKERS,
FLASHPACK_FILE_EXTENSION,
GGUF_FILE_EXTENSION,
SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_FILE_EXTENSION,
@@ -42,6 +43,7 @@
deprecate,
is_accelerate_available,
is_accelerate_version,
is_flashpack_available,
is_gguf_available,
is_torch_available,
is_torch_version,
@@ -177,6 +179,8 @@ def load_state_dict(
return safetensors.torch.load_file(checkpoint_file, device=map_location)
elif file_extension == GGUF_FILE_EXTENSION:
return load_gguf_checkpoint(checkpoint_file)
elif file_extension == FLASHPACK_FILE_EXTENSION:
return load_flashpack_checkpoint(checkpoint_file)
else:
extra_args = {}
weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
@@ -682,6 +686,33 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
return parsed_parameters


def load_flashpack_checkpoint(flashpack_checkpoint_path: str):
"""
Load a FlashPack file and return a dictionary of parsed parameters containing tensors.

Args:
flashpack_checkpoint_path (`str`):
The path to the FlashPack file to load.
"""

if is_flashpack_available() and is_torch_available():
import flashpack
else:
logger.error(
"Loading a FlashPack checkpoint in PyTorch, requires both PyTorch and flashpack to be installed. Please see "
"https://pytorch.org/ and https://github.com/fal-ai/flashpack for installation instructions."
)
raise ImportError("Please install torch and flashpack to load a FlashPack checkpoint in PyTorch.")

flash_tensor, meta = flashpack.deserialization.read_flashpack_file(
path=flashpack_checkpoint_path,
)
state_dict = {}
for name, view in flashpack.deserialization.iterate_from_flash_tensor(flash_tensor, meta):
state_dict[name] = view
return state_dict


def _find_mismatched_keys(state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes):
mismatched_keys = []
if not ignore_mismatched_sizes:
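A hedged usage sketch of the loader added above: `load_state_dict` now routes checkpoints with the FlashPack extension to `load_flashpack_checkpoint`, which returns an ordinary name-to-tensor mapping. The checkpoint path and the `.flashpack` extension below are placeholders/assumptions (the literal value of `FLASHPACK_FILE_EXTENSION` is not shown in this diff), and the function is assumed importable from `diffusers.models.model_loading_utils` as defined here:

```python
# Usage sketch for the new FlashPack loader (placeholder path; requires the
# flashpack package and a checkpoint produced by the save path added below).
from diffusers.models.model_loading_utils import load_flashpack_checkpoint

state_dict = load_flashpack_checkpoint("unet/diffusion_pytorch_model.flashpack")
for name, tensor in list(state_dict.items())[:5]:
    # Each entry is a tensor view produced by iterate_from_flash_tensor().
    print(name, tuple(tensor.shape), tensor.dtype)
```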
189 changes: 140 additions & 49 deletions src/diffusers/models/modeling_utils.py
@@ -42,6 +42,7 @@
from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
CONFIG_NAME,
FLASHPACK_WEIGHTS_NAME,
FLAX_WEIGHTS_NAME,
HF_ENABLE_PARALLEL_LOADING,
SAFE_WEIGHTS_INDEX_NAME,
@@ -55,6 +56,7 @@
is_accelerate_available,
is_bitsandbytes_available,
is_bitsandbytes_version,
is_flashpack_available,
is_peft_available,
is_torch_version,
logging,
@@ -652,6 +654,7 @@ def save_pretrained(
variant: Optional[str] = None,
max_shard_size: Union[int, str] = "10GB",
push_to_hub: bool = False,
use_flashpack: bool = False,
**kwargs,
):
"""
@@ -704,7 +707,12 @@
" the logger on the traceback to understand the reason why the quantized model is not serializable."
)

weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = WEIGHTS_NAME
if use_flashpack:
weights_name = FLASHPACK_WEIGHTS_NAME
elif safe_serialization:
weights_name = SAFETENSORS_WEIGHTS_NAME

weights_name = _add_variant(weights_name, variant)
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
@@ -731,58 +739,74 @@
# Save the model
state_dict = model_to_save.state_dict()

# Save the model
state_dict_split = split_torch_state_dict_into_shards(
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
)

# Clean the folder from a previous save
if is_main_process:
for filename in os.listdir(save_directory):
if filename in state_dict_split.filename_to_tensors.keys():
continue
full_filename = os.path.join(save_directory, filename)
if not os.path.isfile(full_filename):
continue
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
weights_without_ext = weights_without_ext.replace("{suffix}", "")
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
if (
filename.startswith(weights_without_ext)
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
):
os.remove(full_filename)

for filename, tensors in state_dict_split.filename_to_tensors.items():
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
filepath = os.path.join(save_directory, filename)
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this enough.
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
if use_flashpack:
if is_flashpack_available():
import flashpack
else:
torch.save(shard, filepath)
logger.error(
"Saving a FlashPack checkpoint in PyTorch, requires both PyTorch and flashpack to be installed. Please see "
"https://pytorch.org/ and https://github.com/fal-ai/flashpack for installation instructions."
)
raise ImportError("Please install torch and flashpack to load a FlashPack checkpoint in PyTorch.")

if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
flashpack.serialization.pack_to_file(
state_dict_or_model=state_dict,
destination_path=os.path.join(save_directory, weights_name),
target_dtype=self.dtype,
)
else:
path_to_weights = os.path.join(save_directory, weights_name)
logger.info(f"Model weights saved in {path_to_weights}")
# Save the model
state_dict_split = split_torch_state_dict_into_shards(
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
)

# Clean the folder from a previous save
if is_main_process:
for filename in os.listdir(save_directory):
if filename in state_dict_split.filename_to_tensors.keys():
continue
full_filename = os.path.join(save_directory, filename)
if not os.path.isfile(full_filename):
continue
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
weights_without_ext = weights_without_ext.replace("{suffix}", "")
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
if (
filename.startswith(weights_without_ext)
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
):
os.remove(full_filename)

for filename, tensors in state_dict_split.filename_to_tensors.items():
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
filepath = os.path.join(save_directory, filename)
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this is enough.
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
else:
torch.save(shard, filepath)

if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
else:
path_to_weights = os.path.join(save_directory, weights_name)
logger.info(f"Model weights saved in {path_to_weights}")

if push_to_hub:
# Create a new empty model card and eventually tag it
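For context, a hedged sketch of how a caller might exercise the new `use_flashpack` save path added above. The model id and output directory are placeholders; with `use_flashpack=True`, sharding is bypassed and the whole state dict is packed into a single file named by `FLASHPACK_WEIGHTS_NAME`, converted to the model's dtype via `target_dtype=self.dtype`:

```python
# Hedged save example for the new flag (placeholder model id and output dir;
# requires the flashpack package to be installed).
import torch
from diffusers import UNet2DConditionModel

model = UNet2DConditionModel.from_pretrained(
    "some-org/some-model", subfolder="unet", torch_dtype=torch.float16
)
# Writes one FlashPack file instead of sharded safetensors/bin shards.
model.save_pretrained("./unet-flashpack", use_flashpack=True)
```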
@@ -919,6 +943,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
use_flashpack (`bool`, *optional*, defaults to `False`):
If set to `True`, the model is loaded from FlashPack weights.
flashpack_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
Keyword arguments passed to [`flashpack.deserialization.assign_from_file`](https://github.com/fal-ai/flashpack/blob/f1aa91c5cd9532a3dbf5bcc707ab9b01c274b76c/src/flashpack/deserialization.py#L408-L422).


> [!TIP]
> To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in
> with `hf auth login`. You can also activate the special
@@ -963,6 +992,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
disable_mmap = kwargs.pop("disable_mmap", False)
parallel_config: Optional[Union[ParallelConfig, ContextParallelConfig]] = kwargs.pop("parallel_config", None)
use_flashpack = kwargs.pop("use_flashpack", False)
flashpack_kwargs = kwargs.pop("flashpack_kwargs", {})

is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING
if is_parallel_loading_enabled and not low_cpu_mem_usage:
@@ -1191,6 +1222,30 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
subfolder=subfolder or "",
dduf_entries=dduf_entries,
)
elif use_flashpack:
try:
resolved_model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=FLASHPACK_WEIGHTS_NAME,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_entries=dduf_entries,
)

except IOError as e:
logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
if not allow_pickle:
raise
logger.warning(
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
)
elif use_safetensors:
try:
resolved_model_file = _get_model_file(
@@ -1254,6 +1309,42 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
with ContextManagers(init_contexts):
model = cls.from_config(config, **unused_kwargs)

if use_flashpack:
if is_flashpack_available():
import flashpack
else:
logger.error(
"Loading a FlashPack checkpoint in PyTorch, requires both PyTorch and flashpack to be installed. Please see "
"https://pytorch.org/ and https://github.com/fal-ai/flashpack for installation instructions."
)
raise ImportError("Please install torch and flashpack to load a FlashPack checkpoint in PyTorch.")

if device_map is None:
logger.warning(
"`device_map` has not been provided for FlashPack, model will be on `cpu` - provide `device_map` to fully utilize "
"the benefit of FlashPack."
)
flashpack_device = None
else:
flashpack_device = device_map[""]
if flashpack_device in ["auto", "balanced", "balanced_low_0", "sequential"]:
raise ValueError(
"FlashPack `device_map` should be a device, not one of `auto`, `balanced`, `balanced_low_0`, `sequential`."
)

flashpack.mixin.assign_from_file(
model=model,
path=resolved_model_file[0],
device=flashpack_device,
**flashpack_kwargs,
)

if output_loading_info:
logger.warning("`output_loading_info` is not supported with FlashPack.")
return model, {}

return model

if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)
