17 changes: 15 additions & 2 deletions src/diffusers/hooks/_common.py
@@ -16,11 +16,11 @@

import torch

from ..models.attention import FeedForward, LuminaFeedForward
from ..models.attention import AttentionModuleMixin, FeedForward, LuminaFeedForward
from ..models.attention_processor import Attention, MochiAttention


_ATTENTION_CLASSES = (Attention, MochiAttention)
_ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin)
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)

_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
@@ -35,6 +35,19 @@
}
)

# Layers supported for group offloading and layerwise casting
_GO_LC_SUPPORTED_PYTORCH_LAYERS = (
torch.nn.Conv1d,
torch.nn.Conv2d,
torch.nn.Conv3d,
torch.nn.ConvTranspose1d,
torch.nn.ConvTranspose2d,
torch.nn.ConvTranspose3d,
torch.nn.Linear,
# TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
# because of double invocation of the same norm layer in CogVideoXLayerNorm
)


def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]:
for submodule_name, submodule in module.named_modules():
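For context on how the new shared allowlist is meant to be consumed, here is a minimal sketch (not part of this diff) of the `isinstance` filtering that both group offloading and layerwise casting perform over a module tree; the toy `Sequential` model is purely illustrative:

```python
import torch

from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS


def iter_supported_leaves(module: torch.nn.Module):
    # Yield (name, submodule) pairs for layers the group-offloading and
    # layerwise-casting hooks know how to wrap: convs, transposed convs, linears.
    for name, submodule in module.named_modules():
        if isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
            yield name, submodule


model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
print([name for name, _ in iter_supported_leaves(model)])  # ['0'] -- the LayerNorm is excluded
```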
14 changes: 7 additions & 7 deletions src/diffusers/hooks/faster_cache.py
@@ -19,9 +19,9 @@
import torch

from ..models.attention import AttentionModuleMixin
from ..models.attention_processor import Attention, MochiAttention
from ..models.modeling_outputs import Transformer2DModelOutput
from ..utils import logging
from ._common import _ATTENTION_CLASSES
from .hooks import HookRegistry, ModelHook


@@ -30,7 +30,6 @@

_FASTER_CACHE_DENOISER_HOOK = "faster_cache_denoiser"
_FASTER_CACHE_BLOCK_HOOK = "faster_cache_block"
_ATTENTION_CLASSES = (Attention, MochiAttention)
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
"^blocks.*attn",
"^transformer_blocks.*attn",
@@ -489,9 +488,10 @@ def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> No
Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.

Args:
pipeline (`DiffusionPipeline`):
The diffusion pipeline to apply FasterCache to.
config (`Optional[FasterCacheConfig]`, `optional`, defaults to `None`):
module (`torch.nn.Module`):
The PyTorch module to apply FasterCache to. Typically, this should be a transformer architecture supported
in Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
config (`FasterCacheConfig`):
The configuration to use for FasterCache.

Example:
@@ -568,7 +568,7 @@ def high_frequency_weight_callback(module: torch.nn.Module) -> float:
_apply_faster_cache_on_denoiser(module, config)

for name, submodule in module.named_modules():
if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
if not isinstance(submodule, _ATTENTION_CLASSES):
continue
if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
_apply_faster_cache_on_attention_class(name, submodule, config)
@@ -589,7 +589,7 @@ def _apply_faster_cache_on_denoiser(module: torch.nn.Module, config: FasterCache
registry.register_hook(hook, _FASTER_CACHE_DENOISER_HOOK)


def _apply_faster_cache_on_attention_class(name: str, module: Attention, config: FasterCacheConfig) -> None:
def _apply_faster_cache_on_attention_class(name: str, module: AttentionModuleMixin, config: FasterCacheConfig) -> None:
is_spatial_self_attention = (
any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
and config.spatial_attention_block_skip_range is not None
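Since the docstring's own example is collapsed in this view, here is a hedged usage sketch of `apply_faster_cache`; the model ID and the `FasterCacheConfig` keyword arguments mirror the library's published example and should be treated as assumptions rather than part of this diff:

```python
import torch

from diffusers import CogVideoXPipeline
from diffusers.hooks import FasterCacheConfig, apply_faster_cache

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = FasterCacheConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(-1, 681),
    # Assumes the pipeline exposes `current_timestep`, as in the library's example.
    current_timestep_callback=lambda: pipe.current_timestep,
    attention_weight_callback=lambda _: 0.3,
)

# The hook is applied to the denoiser module, not the pipeline.
apply_faster_cache(pipe.transformer, config)
```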
32 changes: 32 additions & 0 deletions src/diffusers/hooks/first_block_cache.py
@@ -192,6 +192,38 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):


def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
"""
Applies [First Block
Cache](https://github.com/chengzeyi/ParaAttention/blob/4de137c5b96416489f06e43e19f2c14a772e28fd/README.md#first-block-cache-our-dynamic-caching)
to a given module.

First Block Cache builds on the ideas of [TeaCache](https://huggingface.co/papers/2411.19108). It is much simpler
to implement generically for a wide range of models and has been integrated first for experimental purposes.

Args:
module (`torch.nn.Module`):
The PyTorch module to apply FBCache to. Typically, this should be a transformer architecture supported in
Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
config (`FirstBlockCacheConfig`):
The configuration to use for applying the FBCache method.

Example:
```python
>>> import torch
>>> from diffusers import CogView4Pipeline
>>> from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig

>>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")

>>> apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))

>>> prompt = "A photo of an astronaut riding a horse on mars"
>>> image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
>>> image.save("output.png")
```
"""

state_manager = StateManager(FBCSharedBlockState, (), {})
remaining_blocks = []

10 changes: 2 additions & 8 deletions src/diffusers/hooks/group_offloading.py
@@ -23,6 +23,7 @@
import torch

from ..utils import get_logger, is_accelerate_available
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from .hooks import HookRegistry, ModelHook


@@ -39,13 +40,6 @@
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
_GROUP_ID_LAZY_LEAF = "lazy_leafs"
_SUPPORTED_PYTORCH_LAYERS = (
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
torch.nn.Linear,
# TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
# because of double invocation of the same norm layer in CogVideoXLayerNorm
)
# fmt: on


@@ -683,7 +677,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
# Create module groups for leaf modules and apply group offloading hooks
modules_with_group_offloading = set()
for name, submodule in module.named_modules():
if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
continue
group = ModuleGroup(
modules=[submodule],
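This file's diff only swaps the local constant for the shared one; which leaf layers get wrapped is unchanged. As a reminder of how leaf-level offloading is enabled, a minimal sketch follows, assuming the public `apply_group_offloading` helper keeps its current keyword interface (the model ID and `use_stream` flag are illustrative):

```python
import torch

from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks import apply_group_offloading

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)

# "leaf_level" wraps each supported layer (convs, transposed convs, linears)
# individually, so only the layer currently executing must sit on the GPU.
apply_group_offloading(
    transformer,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
)
```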
9 changes: 2 additions & 7 deletions src/diffusers/hooks/layerwise_casting.py
@@ -18,6 +18,7 @@
import torch

from ..utils import get_logger, is_peft_available, is_peft_version
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from .hooks import HookRegistry, ModelHook


@@ -27,12 +28,6 @@
# fmt: off
_LAYERWISE_CASTING_HOOK = "layerwise_casting"
_PEFT_AUTOCAST_DISABLE_HOOK = "peft_autocast_disable"
SUPPORTED_PYTORCH_LAYERS = (
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
torch.nn.Linear,
)

DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")
# fmt: on

@@ -186,7 +181,7 @@ def _apply_layerwise_casting(
logger.debug(f'Skipping layerwise casting for layer "{_prefix}"')
return

if isinstance(module, SUPPORTED_PYTORCH_LAYERS):
if isinstance(module, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
logger.debug(f'Applying layerwise casting to layer "{_prefix}"')
apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking)
return
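Behaviour is unchanged for callers; only the supported-layers constant moved into `_common`. For reference, a minimal sketch of layerwise casting, assuming the public `apply_layerwise_casting` helper keeps its current signature (the model ID is illustrative):

```python
import torch

from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks import apply_layerwise_casting

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Weights are stored in float8 and upcast to bfloat16 just-in-time for compute;
# by default the library skips sensitive layers (norms, embeddings, proj_in/proj_out)
# via DEFAULT_SKIP_MODULES_PATTERN.
apply_layerwise_casting(
    transformer,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
)
```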
22 changes: 12 additions & 10 deletions src/diffusers/hooks/pyramid_attention_broadcast.py
@@ -21,17 +21,19 @@
from ..models.attention import AttentionModuleMixin
from ..models.attention_processor import Attention, MochiAttention
from ..utils import logging
from ._common import (
_ATTENTION_CLASSES,
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
)
from .hooks import HookRegistry, ModelHook


logger = logging.get_logger(__name__) # pylint: disable=invalid-name


_PYRAMID_ATTENTION_BROADCAST_HOOK = "pyramid_attention_broadcast"
_ATTENTION_CLASSES = (Attention, MochiAttention)
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")


@dataclass
Expand Down Expand Up @@ -61,11 +63,11 @@ class PyramidAttentionBroadcastConfig:
cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
The range of timesteps to skip in the cross-attention layer. The attention computations will be
conditionally skipped if the current timestep is within the specified range.
spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
spatial_attention_block_identifiers (`Tuple[str, ...]`):
The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks",)`):
temporal_attention_block_identifiers (`Tuple[str, ...]`):
The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
cross_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
cross_attention_block_identifiers (`Tuple[str, ...]`):
The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
"""

@@ -77,9 +79,9 @@ class PyramidAttentionBroadcastConfig:
temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)

spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS
cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS

current_timestep_callback: Callable[[], int] = None

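With the identifier defaults now sourced from `_common`, a typical configuration does not need to pass any block identifiers. A hedged sketch mirroring the library's published PAB example (the model ID and skip ranges are assumptions, not part of this diff):

```python
import torch

from diffusers import CogVideoXPipeline
from diffusers.hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
    # Assumes the pipeline exposes `current_timestep`, as in the library's example.
    current_timestep_callback=lambda: pipe.current_timestep,
)
apply_pyramid_attention_broadcast(pipe.transformer, config)
```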
6 changes: 3 additions & 3 deletions src/diffusers/utils/testing_utils.py
@@ -1394,9 +1394,9 @@ def get_device_properties() -> DeviceProperties:
DevicePropertiesUserDict = UserDict

if is_torch_available():
from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from diffusers.hooks.group_offloading import (
_GROUP_ID_LAZY_LEAF,
_SUPPORTED_PYTORCH_LAYERS,
_compute_group_hash,
_find_parent_module_in_module_dict,
_gather_buffers_with_no_group_offloading_parent,
@@ -1440,13 +1440,13 @@ def get_hashed_filename(group_id: str) -> str:
elif offload_type == "leaf_level":
# Handle leaf-level module groups
for name, submodule in module.named_modules():
if isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
if isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
# These groups will always have parameters, so a file is expected
expected_files.add(get_hashed_filename(name))

# Handle groups for non-leaf parameters/buffers
modules_with_group_offloading = {
name for name, sm in module.named_modules() if isinstance(sm, _SUPPORTED_PYTORCH_LAYERS)
name for name, sm in module.named_modules() if isinstance(sm, _GO_LC_SUPPORTED_PYTORCH_LAYERS)
}
parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
9 changes: 5 additions & 4 deletions tests/lora/utils.py
@@ -2109,14 +2109,15 @@ def test_correct_lora_configs_with_different_ranks(self):
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))

def test_layerwise_casting_inference_denoiser(self):
from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN

def check_linear_dtype(module, storage_dtype, compute_dtype):
patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
if getattr(module, "_skip_layerwise_casting_patterns", None) is not None:
patterns_to_check += tuple(module._skip_layerwise_casting_patterns)
for name, submodule in module.named_modules():
if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS):
if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
continue
dtype_to_check = storage_dtype
if "lora" in name or any(re.search(pattern, name) for pattern in patterns_to_check):
@@ -2167,10 +2168,10 @@ def test_layerwise_casting_peft_input_autocast_denoiser(self):
See the docstring of [`hooks.layerwise_casting.PeftInputAutocastDisableHook`] for more details.
"""

from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from diffusers.hooks.layerwise_casting import (
_PEFT_AUTOCAST_DISABLE_HOOK,
DEFAULT_SKIP_MODULES_PATTERN,
SUPPORTED_PYTORCH_LAYERS,
apply_layerwise_casting,
)

@@ -2180,7 +2181,7 @@ def check_module(denoiser):
def check_module(denoiser):
# This will also check if the peft layers are in torch.float8_e4m3fn dtype (unlike test_layerwise_casting_inference_denoiser)
for name, module in denoiser.named_modules():
if not isinstance(module, SUPPORTED_PYTORCH_LAYERS):
if not isinstance(module, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
continue
dtype_to_check = storage_dtype
if any(re.search(pattern, name) for pattern in patterns_to_check):
5 changes: 3 additions & 2 deletions tests/models/test_modeling_common.py
@@ -1530,7 +1530,8 @@ def test_fn(storage_dtype, compute_dtype):

@torch.no_grad()
def test_layerwise_casting_inference(self):
from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN

torch.manual_seed(0)
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1544,7 +1545,7 @@ def check_linear_dtype(module, storage_dtype, compute_dtype):
if getattr(module, "_skip_layerwise_casting_patterns", None) is not None:
patterns_to_check += tuple(module._skip_layerwise_casting_patterns)
for name, submodule in module.named_modules():
if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS):
if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
continue
dtype_to_check = storage_dtype
if any(re.search(pattern, name) for pattern in patterns_to_check):