Skip to content

Conversation

@hlky
Copy link
Contributor

@hlky hlky commented Nov 22, 2025

What does this PR do?

Adds support for FlashPack

FlashPack could be used as weight format only (see: load_flashpack_checkpoint) - keeping only weight format code would be a cleaner integration, the model loading is indeed faster[1] however part of the performance difference seems to be just due to the complexity of existing from_pretrained code, for example I noticed that _caching_allocator_warmup slows things down, also empty_device_cache is called in _load_pretrained_model whereas FlashPack's code doesn't thus the empty cache time is excluded from FlashPack's benchmark results.

  • Pipeline
  • Cleanup (is_flashpack_available check logging, etc)
  • FlashPack arguments (silent, num_streams, use_distributed_loading etc.)
Benchmark

Changes from FlashPack's version:

  • Use Diffusers
  • Remove transformers related code
  • Add SD v1.5
import csv
import gc
import os
import shutil
import tempfile
import time

import torch

from diffusers.models import AutoModel as DiffusersAutoModel
from huggingface_hub import snapshot_download


def test_model(
    repo_id: str,
    subfolder: str | None = None,
    accelerate_device: str | torch.device = "cuda",
    flashpack_device: str | torch.device = "cuda",
    dtype: torch.dtype | None = None,
    allow_pattern: str = "*",
) -> tuple[float, float, int]:
    """
    Test a model from a repository.
    """
    allow_patterns = [f"{subfolder}/{allow_pattern}"]
    if allow_pattern != "*":
        allow_patterns = [f"{subfolder}/{allow_pattern}", f"{subfolder}/config.json"]
    repo_dir = snapshot_download(
        repo_id, allow_patterns=None if subfolder is None else allow_patterns
    )
    model_dir = repo_dir if subfolder is None else os.path.join(repo_dir, subfolder)
    saved_flashpack_path = os.path.join(model_dir, "model.flashpack")
    saved_flashpack_config_path = os.path.join(model_dir, "flashpack_config.json")

    with tempfile.TemporaryDirectory() as tmpdir:
        # Make a new model directory with the model in it so it isn't cached
        temp_model_dir = os.path.join(tmpdir, "model")
        flashpack_dir = os.path.join(tmpdir, "flashpack")
        os.makedirs(flashpack_dir, exist_ok=True)
        print("Copying model to temporary directory")
        shutil.copytree(model_dir, temp_model_dir)

        # Load from the temporary model directory
        print("Loading model from temporary directory using from_pretrained")
        start_time = time.time()
        model = DiffusersAutoModel.from_pretrained(
            temp_model_dir,
            torch_dtype=dtype,
            device_map={"": accelerate_device},
            variant="fp16" if allow_pattern != "*" else None,
        )

        end_time = time.time()
        accelerate_time = end_time - start_time
        print(f"Time taken with from_pretrained: {accelerate_time} seconds")

        if os.path.exists(saved_flashpack_path) and os.path.exists(
            saved_flashpack_config_path
        ):
            print("Copying flashpack to temporary directory")
            shutil.copy(
                saved_flashpack_path, os.path.join(flashpack_dir, "model.flashpack")
            )
            shutil.copy(
                saved_flashpack_config_path, os.path.join(flashpack_dir, "config.json")
            )
        else:
            print("Packing model to flashpack")
            pack_start_time = time.time()
            model.save_pretrained(
                flashpack_dir,
                use_flashpack=True,
            )
            pack_end_time = time.time()
            print(
                f"Time taken with save_pretrained_flashpack: {pack_end_time - pack_start_time} seconds"
            )
            # Copy back to the original model directory
            shutil.copy(
                os.path.join(flashpack_dir, "model.flashpack"), saved_flashpack_path
            )
            shutil.copy(
                os.path.join(flashpack_dir, "config.json"), saved_flashpack_config_path
            )

        del model
        sync_and_flush()

        print("Loading model from flashpack directory using from_pretrained_flashpack")
        flashpack_start_time = time.time()
        flashpack_model = DiffusersAutoModel.from_pretrained(
            flashpack_dir,
            torch_dtype=dtype,
            device_map={"": flashpack_device},
            use_flashpack=True,
        )

        flashpack_end_time = time.time()
        flashpack_time = flashpack_end_time - flashpack_start_time
        print(f"Time taken with from_pretrained_flashpack: {flashpack_time} seconds")

        total_numel = 0
        for param in flashpack_model.parameters():
            total_numel += param.numel()

        total_bytes = total_numel * dtype.itemsize

        del flashpack_model
        sync_and_flush()

        return accelerate_time, flashpack_time, total_bytes


def test_wan_small_transformer() -> tuple[float, float, int]:
    return test_model(
        repo_id="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
        subfolder="transformer",
        accelerate_device="cuda:0" if torch.cuda.is_available() else "cpu",
        flashpack_device="cuda:0" if torch.cuda.is_available() else "cpu",
        dtype=torch.bfloat16,
    )


def test_wan_large_transformer() -> tuple[float, float, int]:
    return test_model(
        repo_id="Wan-AI/Wan2.1-T2V-14B-Diffusers",
        subfolder="transformer",
        accelerate_device="cuda:0" if torch.cuda.is_available() else "cpu",
        flashpack_device="cuda:0" if torch.cuda.is_available() else "cpu",
        dtype=torch.bfloat16,
    )

def test_stable_diffusion_v1_5() -> tuple[float, float, int]:
    return test_model(
        repo_id="stable-diffusion-v1-5/stable-diffusion-v1-5",
        subfolder="unet",
        accelerate_device="cuda:0" if torch.cuda.is_available() else "cpu",
        flashpack_device="cuda:0" if torch.cuda.is_available() else "cpu",
        dtype=torch.float16,
        allow_pattern="*.fp16.safetensors",
    )

def test_flux_transformer() -> tuple[float, float, int]:
    return test_model(
        repo_id="black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        accelerate_device="cuda:0" if torch.cuda.is_available() else "cpu",
        flashpack_device="cuda:0" if torch.cuda.is_available() else "cpu",
        dtype=torch.bfloat16,
    )


def print_test_result(
    model_name: str,
    accelerate_time: float,
    flashpack_time: float,
    total_bytes: int,
) -> None:
    print(f"{model_name}: Accelerate time: {accelerate_time} seconds")
    print(f"{model_name}: Flashpack time: {flashpack_time} seconds")
    accelerate_gbps = (total_bytes / 1000**3) / accelerate_time
    flashpack_gbps = (total_bytes / 1000**3) / flashpack_time
    print(f"{model_name}: Accelerate GB/s: {accelerate_gbps} GB/s")
    print(f"{model_name}: Flashpack GB/s: {flashpack_gbps} GB/s")


def sync_and_flush() -> None:
    torch.cuda.empty_cache()
    gc.collect()
    os.system("sync")
    if os.geteuid() == 0:
        os.system("echo 3 | tee /proc/sys/vm/drop_caches")


if __name__ == "__main__":
    with open("benchmark_results.csv", "a") as f:
        writer = csv.writer(f)
        writer.writerow(["model", "accelerate_time", "flashpack_time", "total_bytes"])
        for i in range(10):
            for test_model_name, test_func in [
                ("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", test_wan_small_transformer),
                ("stable-diffusion-v1-5/stable-diffusion-v1-5", test_stable_diffusion_v1_5),
                # ("black-forest-labs/FLUX.1-dev", test_flux_transformer),
            ]:
                accelerate_time, flashpack_time, total_bytes = test_func()
                writer.writerow(
                    [test_model_name, accelerate_time, flashpack_time, total_bytes]
                )
                print_test_result(
                    test_model_name, accelerate_time, flashpack_time, total_bytes
                )

======================================================================
SUMMARY STATISTICS
======================================================================

Model                                    Size (GB)    Accel (s)    Flash (s)    Speedup   
----------------------------------------------------------------------
Stable Diffusion v1.5  (fp16)                        1.60       0.250       0.263      0.95x
Wan2.1 1.3B DiT                                2.64       1.315       0.600      2.19x

======================================================================

Model                                    Accel GB/s      Flash GB/s     
----------------------------------------------------------------------
Stable Diffusion v1.5  (fp16)                           6.41           6.13
Wan2.1 1.3B DiT                                   2.33           4.44
======================================================================
benchmark_comparison

[1] For bfloat16 - with float16 existing code appears to be faster

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant