
Conversation

@rycerzes

What does this PR do?

Fixes #12319

This PR fixes the device mismatch error that occurs when using block_level group offloading with models containing standalone computational layers (such as the VAE's post_quant_conv and quant_conv).

Problem

When using block_level offloading, the implementation only matched ModuleList and Sequential containers, leaving standalone layers (like Conv2d) unmanaged. These layers remained on CPU while their inputs were on CUDA, causing:

RuntimeError: Input type (CUDABFloat16Type) and weight type (CPUBFloat16Type) should be the same

The block-level offloading logic in group_offloading.py only looked for ModuleList and Sequential containers when creating groups. Standalone computational layers such as the VAE's post_quant_conv (a Conv2d layer) were not included in any group, so they never received hooks to manage their device placement and remained on CPU while their inputs were transferred to CUDA.
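For reference, a minimal sketch that reproduces the failure (the checkpoint choice, dtype, and latent shape are illustrative assumptions, not taken from this PR; it assumes a CUDA device):

import torch
from diffusers import AutoencoderKL

# Hypothetical checkpoint choice, used only to illustrate the failure mode.
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", torch_dtype=torch.bfloat16)
vae.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
)

# Decoding runs through post_quant_conv, which (before this fix) never received
# a group offloading hook and therefore stayed on CPU while the latents are on CUDA.
latents = torch.randn(1, 4, 64, 64, dtype=torch.bfloat16, device="cuda")
image = vae.decode(latents).sample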

The block-level offloading logic has been modified to:

  1. Identify standalone computational layers that don't belong to any ModuleList/Sequential container
  2. Group them into an unmatched_group that gets proper hook management
  3. Apply hooks to ensure proper device placement during forward pass

Key Changes:

  • Updated _create_groups_for_block_level_offloading() function to collect unmatched computational layers
  • Added logic to create a group for standalone layers using the same _GO_LC_SUPPORTED_PYTORCH_LAYERS filter
  • Ensured the unmatched group is properly integrated into the hook chain
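To illustrate the idea (a minimal sketch only, not the actual diffusers code; the SUPPORTED_LAYERS tuple below is a stand-in for the real _GO_LC_SUPPORTED_PYTORCH_LAYERS constant):

import torch.nn as nn

# Illustrative stand-in for the supported-layer filter used by group offloading.
SUPPORTED_LAYERS = (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)

def collect_unmatched_layers(module: nn.Module):
    # Modules that already live inside a ModuleList/Sequential child are handled
    # by the existing block-level grouping.
    matched_ids = set()
    for child in module.children():
        if isinstance(child, (nn.ModuleList, nn.Sequential)):
            matched_ids.update(id(m) for m in child.modules())

    # Every remaining supported computational layer goes into a single
    # "unmatched" group so it also receives a device-management hook.
    return [
        (name, submodule)
        for name, submodule in module.named_modules()
        if isinstance(submodule, SUPPORTED_LAYERS) and id(submodule) not in matched_ids
    ]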

Testing

Tested with:

  • SDXL VAE (AutoencoderKL) which has standalone post_quant_conv and quant_conv layers
  • Created test cases for models with both standalone and deeply nested layer structures
  • Confirmed both streaming and non-streaming modes work correctly

Test Coverage:

  • test_group_offloading_models_with_standalone_and_deeply_nested_layers - Verifies the fix works with complex model architectures
  • All existing group offloading tests continue to pass

Expected Behavior After Fix

Before: Block-level offloading fails with device mismatch error when models have standalone computational layers

After: Block-level offloading works correctly with all model architectures, including those with:

  • Standalone Conv2d, Linear, and other computational layers
  • Nested ModuleList/Sequential containers
  • Mixed architectures with both standalone and containerized layers

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul this closes #12319

@sayakpaul
Member

@vladmandic would you be interested in testing this out a bit?

@rycerzes
Author

@sayakpaul this patch should also fix #12096 since both have the same root cause (standalone conv layers not tracked in block-level offloading), and this handles both Conv2d (SDXL) and Conv3d (Wan).

The fix should work for WanVACEPipeline as well.

@sayakpaul
Member

Very cool! Feel free to add some lightweight tests to this PR and the final outputs so that we can also test ourselves.

@sayakpaul sayakpaul requested a review from DN6 November 21, 2025 07:06
@rycerzes
Author

Very cool! Feel free to add some lightweight tests to this PR and the final outputs so that we can also test ourselves.

Yes, I added tests in test_group_offloading.py covering the core fix (test_block_level_stream_with_invocation_order_different_from_initialization_order) plus edge cases for VAE-like models with standalone layers, deeply nested structures, and parameter-only modules.
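For context, a minimal sketch of the kind of edge-case model these tests exercise (illustrative only; the class name and layer sizes are assumptions, not the actual test code):

import torch.nn as nn

class TinyVAELikeModel(nn.Module):
    # Mimics a VAE: blocks inside a ModuleList plus standalone conv layers
    # (quant_conv/post_quant_conv) that sit outside any container.
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Conv2d(4, 4, kernel_size=3, padding=1) for _ in range(2)])
        self.quant_conv = nn.Conv2d(4, 4, kernel_size=1)
        self.post_quant_conv = nn.Conv2d(4, 4, kernel_size=1)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return self.post_quant_conv(self.quant_conv(x))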

I also created a standalone test script that validates SDXL VAE and AutoencoderKLWan with both block_level and leaf_level offloading. Output of the script.

Pytest output

pytest tests/hooks/test_group_offloading.py -v
==================================================================== test session starts =====================================================================
platform win32 -- Python 3.13.3, pytest-9.0.1, pluggy-1.6.0 -- D:\Github\oss\diffusers\.venv\Scripts\python.exe
cachedir: .pytest_cache
rootdir: D:\Github\oss\diffusers
configfile: pyproject.toml
plugins: anyio-4.11.0, timeout-2.4.0, xdist-3.8.0, requests-mock-1.10.0
collected 20 items                                                                                                                                            

tests/hooks/test_group_offloading.py::GroupOffloadTests::test_block_level_offloading_with_parameter_only_module_group_0_block_level PASSED              [  5%]
tests/hooks/test_group_offloading.py::GroupOffloadTests::test_block_level_offloading_with_parameter_only_module_group_1_leaf_level PASSED               [ 10%]
tests/hooks/test_group_offloading.py::GroupOffloadTests::test_block_level_stream_with_invocation_order_different_from_initialization_order PASSED       [ 15%]
tests/hooks/test_group_offloading.py::GroupOffloadTests::test_error_raised_if_group_offloading_applied_on_model_offloaded_module PASSED                 [ 20%]
tests/hooks/test_group_offloading.py::GroupOffloadTests::test_error_raised_if_group_offloading_applied_on_sequential_offloaded_module PASSED            [ 25%]
tests/hooks/test_group_offloading.py::GroupOffloadTests::test_error_raised_if_model_offloading_applied_on_group_offloaded_module PASSED                 [ 30%]
tests/hooks/test_group_offloading.py::GroupOffloadTests::test_error_raised_if_sequential_offloading_applied_on_group_offloaded_module PASSED            [ 35%]
tests/hooks/test_group_offloading.py::GroupOffloadTests::test_error_raised_if_streams_used_and_no_accelerator_device PASSED                             [ 40%]
tests/hooks/test_group_offloading.py::GroupOffloadTests::test_error_raised_if_supports_group_offloading_false PASSED                                    [ 45%]
tests/hooks/test_group_offloading.py::GroupOffloadTests::test_model_with_deeply_nested_blocks PASSED                                                    [ 50%]
tests/hooks/test_group_offloading.py::GroupOffloadTests::test_model_with_only_standalone_layers PASSED                                                  [ 55%]
tests/hooks/test_group_offloading.py::GroupOffloadTests::test_multiple_invocations_with_vae_like_model PASSED                                           [ 60%]
tests/hooks/test_group_offloading.py::GroupOffloadTests::test_nested_container_parameters_offloading PASSED                                             [ 65%]
tests/hooks/test_group_offloading.py::GroupOffloadTests::test_offloading_forward_pass PASSED                                                            [ 70%]
tests/hooks/test_group_offloading.py::GroupOffloadTests::test_standalone_conv_layers_with_both_offload_types_0_block_level PASSED                       [ 75%]
tests/hooks/test_group_offloading.py::GroupOffloadTests::test_standalone_conv_layers_with_both_offload_types_1_leaf_level PASSED                        [ 80%]
tests/hooks/test_group_offloading.py::GroupOffloadTests::test_vae_like_model_with_standalone_conv_layers PASSED                                         [ 85%]
tests/hooks/test_group_offloading.py::GroupOffloadTests::test_vae_like_model_without_streams PASSED                                                     [ 90%]
tests/hooks/test_group_offloading.py::GroupOffloadTests::test_warning_logged_if_group_offloaded_module_moved_to_accelerator PASSED                      [ 95%]
tests/hooks/test_group_offloading.py::GroupOffloadTests::test_warning_logged_if_group_offloaded_pipe_moved_to_accelerator PASSED                        [100%]

===================================================================== 20 passed in 4.34s =====================================================================

@sayakpaul
Member

Thanks for the comprehensive testing! I meant to ask for an even more minimal test script that utilizes group offloading with block_level and generates an output as expected. Something like:

from diffusers import DiffusionPipeline
import torch 

pipe = DiffusionPipeline.from_pretrained("...", torch_dtype=torch.bfloat16)
pipe.transformer.enable_group_offload(...)

# move rest of the components to CUDA
...

# inference
pipe(...)

Member

@sayakpaul sayakpaul left a comment


Thanks for getting started on the PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@DN6
Collaborator

DN6 commented Nov 24, 2025

Hi @rycerzes thank you for taking the time to put this together.

I'm not sure this is the most ideal solution for the problem. The solution works for the top-level modules in the VAEs, but if we apply this recursive fallback to transformer models like Flux, it would result in separate hooks for context_embedder, x_embedder, etc., as opposed to having them all in a single unmatched group.

self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
text_time_guidance_cls = (
    CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
)
self.time_text_embed = text_time_guidance_cls(
    embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
)
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
self.x_embedder = nn.Linear(in_channels, self.inner_dim)

Rather than silently falling back to leaf offloading (which isn't the behaviour expected), I would recommend a simpler approach that directly identifies blocks in the Autoencoder for offloading (since they are a bit of an edge case for block offloading).

I would set an attribute in the Autoencoder models _group_offload_block_modules

class AutoencoderKL(...):
    _group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"]

Then update apply_group_offloading to accept a block_modules argument

def apply_group_offloading(
    module: torch.nn.Module,
    onload_device: Union[str, torch.device],
    offload_device: Union[str, torch.device] = torch.device("cpu"),
    offload_type: Union[str, GroupOffloadingType] = "block_level",
    num_blocks_per_group: Optional[int] = None,
    non_blocking: bool = False,
    use_stream: bool = False,
    record_stream: bool = False,
    low_cpu_mem_usage: bool = False,
    offload_to_disk_path: Optional[str] = None,
    block_modules: Optional[list] = None # accepts list of blocks to offload 
) -> None:

Set the modules in enable_group_offload

        block_modules = block_modules or getattr(self, "_group_offload_block_modules", None)
        apply_group_offloading(
            module=self,
            onload_device=onload_device,
            offload_device=offload_device,
            offload_type=offload_type,
            num_blocks_per_group=num_blocks_per_group,
            non_blocking=non_blocking,
            use_stream=use_stream,
            record_stream=record_stream,
            low_cpu_mem_usage=low_cpu_mem_usage,
            offload_to_disk_path=offload_to_disk_path,
            block_modules=block_modules
        )

Then in the actual group offloading step I think we can either offload the entire module if it's in _group_offload_block_modules (since AEs are typically small, this might be okay) or
recursively apply block offloading to the submodule.

    block_modules = set(config.block_modules) if config.block_modules is not None else set()
    for name, submodule in module.named_children():
        if name in block_modules:
            # offload the entire submodule
            # or recursively apply group offloading to the submodule
            _apply_group_offloading_block_level(submodule, config)

@rycerzes
Author

rycerzes commented Nov 24, 2025

Thanks for the detailed review and feedback @DN6 and @sayakpaul!

@DN6, I understand how the explicit opt-in approach is safer and cleaner to avoid unintended side effects on other models. I will refactor the implementation to look for a _group_offload_block_modules attribute in the model (or accept it as an argument) and use that to identify the blocks to offload, falling back to the current behavior if not present.

@sayakpaul, I will update the tests to use AutoencoderKL with a small configuration instead of the DummyVAELikeModel to ensure we are testing against the actual model structure while keeping the tests lightweight. I'm going to refactor the implementation based on @DN6's feedback to use an explicit opt-in mechanism instead of the recursive detection. This will likely change the code significantly, so some of the specific line comments might become moot, but I'll address the testing feedback (using small configs instead of dummy models, cleaning up iterations) in the new iteration.

I'm starting on these changes now.

@rycerzes rycerzes force-pushed the fix/broken-group-offloading-using-block_level branch from 664c492 to 09dd19b Compare November 24, 2025 19:56
@rycerzes rycerzes requested a review from sayakpaul November 24, 2025 19:57
@rycerzes
Author

Hi @sayakpaul @DN6, the changes have been implemented:
• Switched to explicit opt-in via _group_offload_block_modules
• Updated apply_group_offloading to support block_modules
• Updated tests to use AutoencoderKL & adjusted coverage accordingly
• Refactored per review notes and cleaned up iteration logic

Please let me know if anything else is needed; happy to iterate further!

@Cedric-Perauer

Cedric-Perauer commented Nov 27, 2025

Hi @rycerzes, there still seems to be an issue with the KLWan VAE when applying offloading like this:

list(map(lambda module: apply_group_offloading(
    module,
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="block_level",
    use_stream=False,
    num_blocks_per_group = 2
), [pipe.text_encoder, pipe.transformer, pipe.vae]))

It seems to fail at conv_in of the encoder.

return forward_call(*args, **kwargs)
  File "/home/cedric/Downloads/diffusers/src/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 592, in forward
    x = self.conv_in(x, feat_cache[idx])
  File "/home/cedric/anaconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/cedric/anaconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/cedric/Downloads/diffusers/src/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 175, in forward
    return super().forward(x)
  File "/home/cedric/anaconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 717, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/cedric/anaconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 712, in _conv_forward
    return F.conv3d(
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

It seems like the KLWan VAE is not included in the test cases either. Can you confirm this issue? Happy to help support this in the next few days if needed. Thank you!

@rycerzes
Author

rycerzes commented Nov 27, 2025

Hi @Cedric-Perauer, thanks for reporting this! I was able to reproduce the issue with the Wan VAE.

The problem occurs because when you call apply_group_offloading() directly without passing block_modules, the function doesn't know which submodules should be treated as blocks. The encoder and decoder in AutoencoderKLWan are not ModuleList or Sequential types, so they end up in the "unmatched modules" group instead of being properly offloaded as blocks. This causes the encoder.conv_in layer to stay on CPU while the input is on CUDA, hence the device mismatch error.

Following @DN6's suggestion, the implementation now uses an explicit opt-in approach where models define _group_offload_block_modules to specify which submodules should be treated as blocks. AutoencoderKLWan already has this defined: _group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"].

Two working solutions:

  1. Use the high-level API - it handles this automatically:
pipe.vae.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="block_level",
    num_blocks_per_group=2,
)
  2. Pass block_modules explicitly when using apply_group_offloading() directly:
block_modules = getattr(pipe.vae, "_group_offload_block_modules", None)

apply_group_offloading(
    pipe.vae,
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="block_level",
    num_blocks_per_group=2,
    block_modules=block_modules,
)

I've tested both approaches with the Wan VAE and confirmed they work. Let me know if you have any other questions!

It seems like the KLWan VAE is not included in the test cases either. Can you confirm this issue? Happy to help support this in the next few days if needed. Thank you!

The KLWan VAE is not explicitly tested in the group offloading test suite, but the same code paths are tested via AutoencoderKL. Do I need to implement explicit tests for the KLWan VAE? @sayakpaul @DN6
