diff --git a/src/diffusers/hooks/_common.py b/src/diffusers/hooks/_common.py
index 08f474fc1cc7..ca7934e5c313 100644
--- a/src/diffusers/hooks/_common.py
+++ b/src/diffusers/hooks/_common.py
@@ -16,11 +16,11 @@
 
 import torch
 
-from ..models.attention import FeedForward, LuminaFeedForward
+from ..models.attention import AttentionModuleMixin, FeedForward, LuminaFeedForward
 from ..models.attention_processor import Attention, MochiAttention
 
 
-_ATTENTION_CLASSES = (Attention, MochiAttention)
+_ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin)
 _FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)
 
 _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
@@ -35,6 +35,19 @@
     }
 )
 
+# Layers supported for group offloading and layerwise casting
+_GO_LC_SUPPORTED_PYTORCH_LAYERS = (
+    torch.nn.Conv1d,
+    torch.nn.Conv2d,
+    torch.nn.Conv3d,
+    torch.nn.ConvTranspose1d,
+    torch.nn.ConvTranspose2d,
+    torch.nn.ConvTranspose3d,
+    torch.nn.Linear,
+    # TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
+    # because of double invocation of the same norm layer in CogVideoXLayerNorm
+)
+
 
 def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]:
     for submodule_name, submodule in module.named_modules():
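Both tuples are consumed as `isinstance` filters while walking a model's module tree. A minimal sketch of that pattern, assuming the private `diffusers.hooks._common` module keeps the names introduced above (not taken from this patch):

```python
import torch

from diffusers.hooks._common import _ATTENTION_CLASSES, _GO_LC_SUPPORTED_PYTORCH_LAYERS


def iter_offloadable_leaves(model: torch.nn.Module):
    # Leaf layers that group offloading and layerwise casting both target.
    for name, submodule in model.named_modules():
        if isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
            yield name, submodule


def iter_attention_modules(model: torch.nn.Module):
    # Now also matches AttentionModuleMixin subclasses, since the tuple was extended.
    for name, submodule in model.named_modules():
        if isinstance(submodule, _ATTENTION_CLASSES):
            yield name, submodule
```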
diff --git a/src/diffusers/hooks/faster_cache.py b/src/diffusers/hooks/faster_cache.py
index a6c250b50ca4..53e5bd792c6a 100644
--- a/src/diffusers/hooks/faster_cache.py
+++ b/src/diffusers/hooks/faster_cache.py
@@ -19,9 +19,9 @@
 import torch
 
 from ..models.attention import AttentionModuleMixin
-from ..models.attention_processor import Attention, MochiAttention
 from ..models.modeling_outputs import Transformer2DModelOutput
 from ..utils import logging
+from ._common import _ATTENTION_CLASSES
 from .hooks import HookRegistry, ModelHook
 
 
@@ -30,7 +30,6 @@
 
 _FASTER_CACHE_DENOISER_HOOK = "faster_cache_denoiser"
 _FASTER_CACHE_BLOCK_HOOK = "faster_cache_block"
-_ATTENTION_CLASSES = (Attention, MochiAttention)
 _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
     "^blocks.*attn",
     "^transformer_blocks.*attn",
@@ -489,9 +488,10 @@ def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> None:
     Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.
 
     Args:
-        pipeline (`DiffusionPipeline`):
-            The diffusion pipeline to apply FasterCache to.
-        config (`Optional[FasterCacheConfig]`, `optional`, defaults to `None`):
+        module (`torch.nn.Module`):
+            The pytorch module to apply FasterCache to. Typically, this should be a transformer architecture supported
+            in Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
+        config (`FasterCacheConfig`):
             The configuration to use for FasterCache.
 
     Example:
@@ -568,7 +568,7 @@ def high_frequency_weight_callback(module: torch.nn.Module) -> float:
     _apply_faster_cache_on_denoiser(module, config)
 
     for name, submodule in module.named_modules():
-        if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
+        if not isinstance(submodule, _ATTENTION_CLASSES):
             continue
         if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
             _apply_faster_cache_on_attention_class(name, submodule, config)
@@ -589,7 +589,7 @@ def _apply_faster_cache_on_denoiser(module: torch.nn.Module, config: FasterCacheConfig) -> None:
     registry.register_hook(hook, _FASTER_CACHE_DENOISER_HOOK)
 
 
-def _apply_faster_cache_on_attention_class(name: str, module: Attention, config: FasterCacheConfig) -> None:
+def _apply_faster_cache_on_attention_class(name: str, module: AttentionModuleMixin, config: FasterCacheConfig) -> None:
     is_spatial_self_attention = (
         any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
         and config.spatial_attention_block_skip_range is not None
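With the `module`-first signature, FasterCache is attached to the denoiser itself rather than to a pipeline object. A hedged usage sketch; the model ID and configuration values are illustrative and follow the option names defined by `FasterCacheConfig`:

```python
import torch

from diffusers import CogVideoXPipeline
from diffusers.hooks import FasterCacheConfig, apply_faster_cache

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = FasterCacheConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(-1, 681),
    current_timestep_callback=lambda: pipe.current_timestep,
    attention_weight_callback=lambda _: 0.3,
    tensor_format="BFCHW",
)
# The hook is applied to the transformer (a supported torch.nn.Module), not the pipeline.
apply_faster_cache(pipe.transformer, config)
```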
diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py
index 40ae8c5a263a..862d44059301 100644
--- a/src/diffusers/hooks/first_block_cache.py
+++ b/src/diffusers/hooks/first_block_cache.py
@@ -192,6 +192,38 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
 
 
 def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
+    """
+    Applies [First Block
+    Cache](https://github.com/chengzeyi/ParaAttention/blob/4de137c5b96416489f06e43e19f2c14a772e28fd/README.md#first-block-cache-our-dynamic-caching)
+    to a given module.
+
+    First Block Cache builds on the ideas of [TeaCache](https://huggingface.co/papers/2411.19108). It is much simpler
+    to implement generically for a wide range of models and has been integrated first for experimental purposes.
+
+    Args:
+        module (`torch.nn.Module`):
+            The pytorch module to apply FBCache to. Typically, this should be a transformer architecture supported in
+            Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
+        config (`FirstBlockCacheConfig`):
+            The configuration to use for applying the FBCache method.
+
+    Example:
+    ```python
+    >>> import torch
+    >>> from diffusers import CogView4Pipeline
+    >>> from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig
+
+    >>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
+    >>> pipe.to("cuda")
+
+    >>> apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))
+
+    >>> prompt = "A photo of an astronaut riding a horse on mars"
+    >>> image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
+    >>> image.save("output.png")
+    ```
+    """
+
     state_manager = StateManager(FBCSharedBlockState, (), {})
     remaining_blocks = []
 
diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index 1248bedf861c..3015409afc9a 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -23,6 +23,7 @@
 import torch
 
 from ..utils import get_logger, is_accelerate_available
+from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
 from .hooks import HookRegistry, ModelHook
 
 
@@ -39,13 +40,6 @@
 _LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
 _LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
 _GROUP_ID_LAZY_LEAF = "lazy_leafs"
-_SUPPORTED_PYTORCH_LAYERS = (
-    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
-    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
-    torch.nn.Linear,
-    # TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
-    # because of double invocation of the same norm layer in CogVideoXLayerNorm
-)
 # fmt: on
 
 
@@ -683,7 +677,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
     # Create module groups for leaf modules and apply group offloading hooks
     modules_with_group_offloading = set()
     for name, submodule in module.named_modules():
-        if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
+        if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
             continue
         group = ModuleGroup(
             modules=[submodule],
diff --git a/src/diffusers/hooks/layerwise_casting.py b/src/diffusers/hooks/layerwise_casting.py
index 1747a5c489cb..a036ad37dc2f 100644
--- a/src/diffusers/hooks/layerwise_casting.py
+++ b/src/diffusers/hooks/layerwise_casting.py
@@ -18,6 +18,7 @@
 
 import torch
 
 from ..utils import get_logger, is_peft_available, is_peft_version
+from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
 from .hooks import HookRegistry, ModelHook
 
@@ -27,12 +28,6 @@
 # fmt: off
 _LAYERWISE_CASTING_HOOK = "layerwise_casting"
 _PEFT_AUTOCAST_DISABLE_HOOK = "peft_autocast_disable"
-SUPPORTED_PYTORCH_LAYERS = (
-    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
-    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
-    torch.nn.Linear,
-)
-
 DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")
 # fmt: on
 
@@ -186,7 +181,7 @@ def _apply_layerwise_casting(
         logger.debug(f'Skipping layerwise casting for layer "{_prefix}"')
         return
 
-    if isinstance(module, SUPPORTED_PYTORCH_LAYERS):
+    if isinstance(module, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
         logger.debug(f'Applying layerwise casting to layer "{_prefix}"')
         apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking)
         return
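Group offloading (at `leaf_level`) and layerwise casting now select leaf layers from the same `_GO_LC_SUPPORTED_PYTORCH_LAYERS` tuple. A hedged sketch of the two public entry points, shown independently on a freshly loaded model; the model ID and dtypes are illustrative:

```python
import torch

from diffusers import CogVideoXTransformer3DModel

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Option A: leaf-level group offloading. Every supported leaf layer
# (Conv*/ConvTranspose*/Linear) becomes its own on/offload group.
transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
)

# Option B (on a separately loaded model): layerwise casting walks the same leaf
# layer types, storing weights in float8 and upcasting to the compute dtype in forward.
# transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
```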
diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py
index 1c8787194196..ee3f41033171 100644
--- a/src/diffusers/hooks/pyramid_attention_broadcast.py
+++ b/src/diffusers/hooks/pyramid_attention_broadcast.py
@@ -21,6 +21,12 @@
 from ..models.attention import AttentionModuleMixin
 from ..models.attention_processor import Attention, MochiAttention
 from ..utils import logging
+from ._common import (
+    _ATTENTION_CLASSES,
+    _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
+    _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+    _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+)
 from .hooks import HookRegistry, ModelHook
 
 
@@ -28,10 +34,6 @@
 
 
 _PYRAMID_ATTENTION_BROADCAST_HOOK = "pyramid_attention_broadcast"
-_ATTENTION_CLASSES = (Attention, MochiAttention)
-_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
-_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
-_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
 
 
 @dataclass
@@ -61,11 +63,11 @@ class PyramidAttentionBroadcastConfig:
         cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
             The range of timesteps to skip in the cross-attention layer. The attention computations will be
             conditionally skipped if the current timestep is within the specified range.
-        spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
+        spatial_attention_block_identifiers (`Tuple[str, ...]`):
             The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
-        temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks",)`):
+        temporal_attention_block_identifiers (`Tuple[str, ...]`):
             The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
-        cross_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
+        cross_attention_block_identifiers (`Tuple[str, ...]`):
             The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
     """
 
@@ -77,9 +79,9 @@ class PyramidAttentionBroadcastConfig:
     temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
     cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
 
-    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
-    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
-    cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
+    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
+    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS
+    cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS
 
     current_timestep_callback: Callable[[], int] = None
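The Pyramid Attention Broadcast defaults now come from the shared identifiers in `hooks._common` rather than module-local constants, so the config matches the same block layouts as the other hooks. A hedged usage sketch; the model ID and skip ranges are illustrative:

```python
import torch

from diffusers import CogVideoXPipeline
from diffusers.hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
    current_timestep_callback=lambda: pipe.current_timestep,
)
# Block identifiers are left at their defaults, i.e. the shared _common tuples.
apply_pyramid_attention_broadcast(pipe.transformer, config)
```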
diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py
index 0bc1690658e7..3d9444975d99 100644
--- a/src/diffusers/utils/testing_utils.py
+++ b/src/diffusers/utils/testing_utils.py
@@ -1394,9 +1394,9 @@ def get_device_properties() -> DeviceProperties:
 DevicePropertiesUserDict = UserDict
 
 if is_torch_available():
+    from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
     from diffusers.hooks.group_offloading import (
         _GROUP_ID_LAZY_LEAF,
-        _SUPPORTED_PYTORCH_LAYERS,
         _compute_group_hash,
         _find_parent_module_in_module_dict,
         _gather_buffers_with_no_group_offloading_parent,
@@ -1440,13 +1440,13 @@ def get_hashed_filename(group_id: str) -> str:
    elif offload_type == "leaf_level":
        # Handle leaf-level module groups
        for name, submodule in module.named_modules():
-            if isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
+            if isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
                # These groups will always have parameters, so a file is expected
                expected_files.add(get_hashed_filename(name))
 
        # Handle groups for non-leaf parameters/buffers
        modules_with_group_offloading = {
-            name for name, sm in module.named_modules() if isinstance(sm, _SUPPORTED_PYTORCH_LAYERS)
+            name for name, sm in module.named_modules() if isinstance(sm, _GO_LC_SUPPORTED_PYTORCH_LAYERS)
        }
        parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
        buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
""" + from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS from diffusers.hooks.layerwise_casting import ( _PEFT_AUTOCAST_DISABLE_HOOK, DEFAULT_SKIP_MODULES_PATTERN, - SUPPORTED_PYTORCH_LAYERS, apply_layerwise_casting, ) @@ -2180,7 +2181,7 @@ def test_layerwise_casting_peft_input_autocast_denoiser(self): def check_module(denoiser): # This will also check if the peft layers are in torch.float8_e4m3fn dtype (unlike test_layerwise_casting_inference_denoiser) for name, module in denoiser.named_modules(): - if not isinstance(module, SUPPORTED_PYTORCH_LAYERS): + if not isinstance(module, _GO_LC_SUPPORTED_PYTORCH_LAYERS): continue dtype_to_check = storage_dtype if any(re.search(pattern, name) for pattern in patterns_to_check): diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 435bd32c6083..36b563ba9f8e 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1530,7 +1530,8 @@ def test_fn(storage_dtype, compute_dtype): @torch.no_grad() def test_layerwise_casting_inference(self): - from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS + from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS + from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN torch.manual_seed(0) config, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -1544,7 +1545,7 @@ def check_linear_dtype(module, storage_dtype, compute_dtype): if getattr(module, "_skip_layerwise_casting_patterns", None) is not None: patterns_to_check += tuple(module._skip_layerwise_casting_patterns) for name, submodule in module.named_modules(): - if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS): + if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS): continue dtype_to_check = storage_dtype if any(re.search(pattern, name) for pattern in patterns_to_check):