Commit 6f3ac30

[refactor] some shared parts between hooks + docs (#11968)
* update
* try test fix
* add missing link
* fix tests
* Update src/diffusers/hooks/first_block_cache.py
* make style
1 parent a6d9f6a commit 6f3ac30

File tree

9 files changed: +81 additions, −43 deletions

src/diffusers/hooks/_common.py
src/diffusers/hooks/faster_cache.py
src/diffusers/hooks/first_block_cache.py
src/diffusers/hooks/group_offloading.py
src/diffusers/hooks/layerwise_casting.py
src/diffusers/hooks/pyramid_attention_broadcast.py
src/diffusers/utils/testing_utils.py
tests/lora/utils.py
tests/models/test_modeling_common.py


src/diffusers/hooks/_common.py

Lines changed: 15 additions & 2 deletions
@@ -16,11 +16,11 @@
 
 import torch
 
-from ..models.attention import FeedForward, LuminaFeedForward
+from ..models.attention import AttentionModuleMixin, FeedForward, LuminaFeedForward
 from ..models.attention_processor import Attention, MochiAttention
 
 
-_ATTENTION_CLASSES = (Attention, MochiAttention)
+_ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin)
 _FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)
 
 _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
@@ -35,6 +35,19 @@
     }
 )
 
+# Layers supported for group offloading and layerwise casting
+_GO_LC_SUPPORTED_PYTORCH_LAYERS = (
+    torch.nn.Conv1d,
+    torch.nn.Conv2d,
+    torch.nn.Conv3d,
+    torch.nn.ConvTranspose1d,
+    torch.nn.ConvTranspose2d,
+    torch.nn.ConvTranspose3d,
+    torch.nn.Linear,
+    # TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
+    # because of double invocation of the same norm layer in CogVideoXLayerNorm
+)
+
 
 def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]:
     for submodule_name, submodule in module.named_modules():
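The constants above are private, shared helpers. As a rough sketch (not part of this commit), a hook module that needs to treat attention layers and offloadable leaf layers differently could filter submodules as below; the `split_targets` helper is hypothetical, and the underscore-prefixed names may change without notice.

```python
# Illustrative sketch only: _ATTENTION_CLASSES and _GO_LC_SUPPORTED_PYTORCH_LAYERS
# are private helpers from diffusers.hooks._common used by the hook modules.
import torch

from diffusers.hooks._common import _ATTENTION_CLASSES, _GO_LC_SUPPORTED_PYTORCH_LAYERS


def split_targets(model: torch.nn.Module):
    """Return (attention-like module names, offloadable/castable leaf layer names)."""
    attention_names, leaf_names = [], []
    for name, submodule in model.named_modules():
        if isinstance(submodule, _ATTENTION_CLASSES):
            attention_names.append(name)
        elif isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
            leaf_names.append(name)
    return attention_names, leaf_names
```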

src/diffusers/hooks/faster_cache.py

Lines changed: 7 additions & 7 deletions
@@ -19,9 +19,9 @@
 import torch
 
 from ..models.attention import AttentionModuleMixin
-from ..models.attention_processor import Attention, MochiAttention
 from ..models.modeling_outputs import Transformer2DModelOutput
 from ..utils import logging
+from ._common import _ATTENTION_CLASSES
 from .hooks import HookRegistry, ModelHook
 
 
@@ -30,7 +30,6 @@
 
 _FASTER_CACHE_DENOISER_HOOK = "faster_cache_denoiser"
 _FASTER_CACHE_BLOCK_HOOK = "faster_cache_block"
-_ATTENTION_CLASSES = (Attention, MochiAttention)
 _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
     "^blocks.*attn",
     "^transformer_blocks.*attn",
@@ -489,9 +488,10 @@ def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> None:
     Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.
 
     Args:
-        pipeline (`DiffusionPipeline`):
-            The diffusion pipeline to apply FasterCache to.
-        config (`Optional[FasterCacheConfig]`, `optional`, defaults to `None`):
+        module (`torch.nn.Module`):
+            The pytorch module to apply FasterCache to. Typically, this should be a transformer architecture supported
+            in Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
+        config (`FasterCacheConfig`):
             The configuration to use for FasterCache.
 
     Example:
@@ -568,7 +568,7 @@ def high_frequency_weight_callback(module: torch.nn.Module) -> float:
     _apply_faster_cache_on_denoiser(module, config)
 
     for name, submodule in module.named_modules():
-        if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
+        if not isinstance(submodule, _ATTENTION_CLASSES):
            continue
        if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
            _apply_faster_cache_on_attention_class(name, submodule, config)
@@ -589,7 +589,7 @@ def _apply_faster_cache_on_denoiser(module: torch.nn.Module, config: FasterCacheConfig) -> None:
     registry.register_hook(hook, _FASTER_CACHE_DENOISER_HOOK)
 
 
-def _apply_faster_cache_on_attention_class(name: str, module: Attention, config: FasterCacheConfig) -> None:
+def _apply_faster_cache_on_attention_class(name: str, module: AttentionModuleMixin, config: FasterCacheConfig) -> None:
     is_spatial_self_attention = (
         any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
         and config.spatial_attention_block_skip_range is not None
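For reference, a hedged usage sketch of the corrected `apply_faster_cache(module, config)` signature documented above. The checkpoint and the specific `FasterCacheConfig` values are illustrative, drawn from the library's existing examples rather than from this commit.

```python
# Usage sketch: apply FasterCache to a transformer module. Checkpoint and config
# values are illustrative; see the FasterCacheConfig docstring for all options.
import torch

from diffusers import CogVideoXPipeline
from diffusers.hooks import FasterCacheConfig, apply_faster_cache

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = FasterCacheConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(-1, 681),
    current_timestep_callback=lambda: pipe.current_timestep,
    attention_weight_callback=lambda _: 0.3,
)
apply_faster_cache(pipe.transformer, config)
```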

src/diffusers/hooks/first_block_cache.py

Lines changed: 32 additions & 0 deletions
@@ -192,6 +192,38 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
 
 
 def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
+    """
+    Applies [First Block
+    Cache](https://github.com/chengzeyi/ParaAttention/blob/4de137c5b96416489f06e43e19f2c14a772e28fd/README.md#first-block-cache-our-dynamic-caching)
+    to a given module.
+
+    First Block Cache builds on the ideas of [TeaCache](https://huggingface.co/papers/2411.19108). It is much simpler
+    to implement generically for a wide range of models and has been integrated first for experimental purposes.
+
+    Args:
+        module (`torch.nn.Module`):
+            The pytorch module to apply FBCache to. Typically, this should be a transformer architecture supported in
+            Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
+        config (`FirstBlockCacheConfig`):
+            The configuration to use for applying the FBCache method.
+
+    Example:
+    ```python
+    >>> import torch
+    >>> from diffusers import CogView4Pipeline
+    >>> from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig
+
+    >>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
+    >>> pipe.to("cuda")
+
+    >>> apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))
+
+    >>> prompt = "A photo of an astronaut riding a horse on mars"
+    >>> image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
+    >>> image.save("output.png")
+    ```
+    """
+
     state_manager = StateManager(FBCSharedBlockState, (), {})
     remaining_blocks = []

src/diffusers/hooks/group_offloading.py

Lines changed: 2 additions & 8 deletions
@@ -23,6 +23,7 @@
 import torch
 
 from ..utils import get_logger, is_accelerate_available
+from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
 from .hooks import HookRegistry, ModelHook
 
 
@@ -39,13 +40,6 @@
 _LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
 _LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
 _GROUP_ID_LAZY_LEAF = "lazy_leafs"
-_SUPPORTED_PYTORCH_LAYERS = (
-    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
-    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
-    torch.nn.Linear,
-    # TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
-    # because of double invocation of the same norm layer in CogVideoXLayerNorm
-)
 # fmt: on
 
 
@@ -683,7 +677,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
     # Create module groups for leaf modules and apply group offloading hooks
     modules_with_group_offloading = set()
     for name, submodule in module.named_modules():
-        if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
+        if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
             continue
         group = ModuleGroup(
             modules=[submodule],
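The leaf-level path above only creates offloading groups for layers whose types appear in the shared tuple. A usage sketch of the public entry point, assuming `apply_group_offloading` keeps its current keyword arguments (`onload_device`, `offload_device`, `offload_type`, `use_stream`); the checkpoint is a placeholder.

```python
# Sketch: enable leaf-level group offloading on a transformer. Internally, only
# submodules whose types are in _GO_LC_SUPPORTED_PYTORCH_LAYERS get their own group.
import torch

from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks import apply_group_offloading

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)
apply_group_offloading(
    transformer,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
)
```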

src/diffusers/hooks/layerwise_casting.py

Lines changed: 2 additions & 7 deletions
@@ -18,6 +18,7 @@
 import torch
 
 from ..utils import get_logger, is_peft_available, is_peft_version
+from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
 from .hooks import HookRegistry, ModelHook
 
 
@@ -27,12 +28,6 @@
 # fmt: off
 _LAYERWISE_CASTING_HOOK = "layerwise_casting"
 _PEFT_AUTOCAST_DISABLE_HOOK = "peft_autocast_disable"
-SUPPORTED_PYTORCH_LAYERS = (
-    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
-    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
-    torch.nn.Linear,
-)
-
 DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")
 # fmt: on
 
@@ -186,7 +181,7 @@ def _apply_layerwise_casting(
         logger.debug(f'Skipping layerwise casting for layer "{_prefix}"')
         return
 
-    if isinstance(module, SUPPORTED_PYTORCH_LAYERS):
+    if isinstance(module, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
         logger.debug(f'Applying layerwise casting to layer "{_prefix}"')
         apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking)
         return
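A usage sketch of the public API that relies on this filter, assuming the current `apply_layerwise_casting` keyword arguments (`storage_dtype`, `compute_dtype`, `skip_modules_pattern`); the checkpoint and skip patterns are illustrative.

```python
# Sketch: cast supported leaf layers (the shared _GO_LC_SUPPORTED_PYTORCH_LAYERS types)
# to fp8 for storage while computing in bf16. Values here are illustrative.
import torch

from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks import apply_layerwise_casting

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)
apply_layerwise_casting(
    transformer,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    skip_modules_pattern=["patch_embed", "norm", "proj_out"],
)
```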

src/diffusers/hooks/pyramid_attention_broadcast.py

Lines changed: 12 additions & 10 deletions
@@ -21,17 +21,19 @@
 from ..models.attention import AttentionModuleMixin
 from ..models.attention_processor import Attention, MochiAttention
 from ..utils import logging
+from ._common import (
+    _ATTENTION_CLASSES,
+    _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
+    _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+    _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+)
 from .hooks import HookRegistry, ModelHook
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
 _PYRAMID_ATTENTION_BROADCAST_HOOK = "pyramid_attention_broadcast"
-_ATTENTION_CLASSES = (Attention, MochiAttention)
-_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
-_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
-_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
 
 
 @dataclass
@@ -61,11 +63,11 @@ class PyramidAttentionBroadcastConfig:
         cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
             The range of timesteps to skip in the cross-attention layer. The attention computations will be
             conditionally skipped if the current timestep is within the specified range.
-        spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
+        spatial_attention_block_identifiers (`Tuple[str, ...]`):
             The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
-        temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks",)`):
+        temporal_attention_block_identifiers (`Tuple[str, ...]`):
             The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
-        cross_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
+        cross_attention_block_identifiers (`Tuple[str, ...]`):
             The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
     """
 
@@ -77,9 +79,9 @@ class PyramidAttentionBroadcastConfig:
     temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
     cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
 
-    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
-    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
-    cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
+    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
+    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS
+    cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS
 
     current_timestep_callback: Callable[[], int] = None
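With this change the config defaults come from the shared identifier tuples in `_common.py`, so most callers never set the identifier fields explicitly. A hedged usage sketch; the checkpoint and skip ranges are illustrative values taken from the library's existing examples.

```python
# Sketch: apply Pyramid Attention Broadcast with the default (shared) block identifiers.
import torch

from diffusers import CogVideoXPipeline
from diffusers.hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
    current_timestep_callback=lambda: pipe.current_timestep,
)
apply_pyramid_attention_broadcast(pipe.transformer, config)
```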

src/diffusers/utils/testing_utils.py

Lines changed: 3 additions & 3 deletions
@@ -1394,9 +1394,9 @@ def get_device_properties() -> DeviceProperties:
     DevicePropertiesUserDict = UserDict
 
 if is_torch_available():
+    from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
     from diffusers.hooks.group_offloading import (
         _GROUP_ID_LAZY_LEAF,
-        _SUPPORTED_PYTORCH_LAYERS,
         _compute_group_hash,
         _find_parent_module_in_module_dict,
         _gather_buffers_with_no_group_offloading_parent,
@@ -1440,13 +1440,13 @@ def get_hashed_filename(group_id: str) -> str:
     elif offload_type == "leaf_level":
         # Handle leaf-level module groups
         for name, submodule in module.named_modules():
-            if isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
+            if isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
                 # These groups will always have parameters, so a file is expected
                 expected_files.add(get_hashed_filename(name))
 
         # Handle groups for non-leaf parameters/buffers
         modules_with_group_offloading = {
-            name for name, sm in module.named_modules() if isinstance(sm, _SUPPORTED_PYTORCH_LAYERS)
+            name for name, sm in module.named_modules() if isinstance(sm, _GO_LC_SUPPORTED_PYTORCH_LAYERS)
         }
         parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
         buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
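The expectation encoded above: under `offload_type="leaf_level"`, each submodule whose type is in the shared tuple contributes one expected offload file, with extra groups for parameters and buffers that fall outside those layers. A hypothetical helper making the per-layer count explicit (not part of this commit):

```python
# Hypothetical helper mirroring the test utility above: one serialized group
# (and hence one expected file) per supported leaf layer in the module.
from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS


def expected_leaf_group_count(module) -> int:
    return sum(1 for _, sm in module.named_modules() if isinstance(sm, _GO_LC_SUPPORTED_PYTORCH_LAYERS))
```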

tests/lora/utils.py

Lines changed: 5 additions & 4 deletions
@@ -2109,14 +2109,15 @@ def test_correct_lora_configs_with_different_ranks(self):
         self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
 
     def test_layerwise_casting_inference_denoiser(self):
-        from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
+        from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
+        from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN
 
         def check_linear_dtype(module, storage_dtype, compute_dtype):
             patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
             if getattr(module, "_skip_layerwise_casting_patterns", None) is not None:
                 patterns_to_check += tuple(module._skip_layerwise_casting_patterns)
             for name, submodule in module.named_modules():
-                if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS):
+                if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
                     continue
                 dtype_to_check = storage_dtype
                 if "lora" in name or any(re.search(pattern, name) for pattern in patterns_to_check):
@@ -2167,10 +2168,10 @@ def test_layerwise_casting_peft_input_autocast_denoiser(self):
         See the docstring of [`hooks.layerwise_casting.PeftInputAutocastDisableHook`] for more details.
         """
 
+        from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
         from diffusers.hooks.layerwise_casting import (
             _PEFT_AUTOCAST_DISABLE_HOOK,
             DEFAULT_SKIP_MODULES_PATTERN,
-            SUPPORTED_PYTORCH_LAYERS,
             apply_layerwise_casting,
         )
 
@@ -2180,7 +2181,7 @@ def test_layerwise_casting_peft_input_autocast_denoiser(self):
         def check_module(denoiser):
             # This will also check if the peft layers are in torch.float8_e4m3fn dtype (unlike test_layerwise_casting_inference_denoiser)
             for name, module in denoiser.named_modules():
-                if not isinstance(module, SUPPORTED_PYTORCH_LAYERS):
+                if not isinstance(module, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
                     continue
                 dtype_to_check = storage_dtype
                 if any(re.search(pattern, name) for pattern in patterns_to_check):

tests/models/test_modeling_common.py

Lines changed: 3 additions & 2 deletions
@@ -1530,7 +1530,8 @@ def test_fn(storage_dtype, compute_dtype):
 
     @torch.no_grad()
     def test_layerwise_casting_inference(self):
-        from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
+        from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
+        from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN
 
         torch.manual_seed(0)
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1544,7 +1545,7 @@ def check_linear_dtype(module, storage_dtype, compute_dtype):
             if getattr(module, "_skip_layerwise_casting_patterns", None) is not None:
                 patterns_to_check += tuple(module._skip_layerwise_casting_patterns)
             for name, submodule in module.named_modules():
-                if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS):
+                if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
                     continue
                 dtype_to_check = storage_dtype
                 if any(re.search(pattern, name) for pattern in patterns_to_check):
