From 15f98dbc7dacd774558cfe1fb27204b9d66f8341 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Thu, 24 Jul 2025 23:50:11 +0200
Subject: [PATCH 1/6] update

---
 src/diffusers/hooks/group_offloading.py | 138 ++++++++++--------------
 1 file changed, 58 insertions(+), 80 deletions(-)

diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index 6c89101f5e98..6fd24b10724d 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -101,7 +101,7 @@ def __init__(
         self.offload_to_disk_path = offload_to_disk_path
         self._is_offloaded_to_disk = False

-        if self.offload_to_disk_path:
+        if self.offload_to_disk_path is not None:
             # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
             self.group_id = group_id if group_id is not None else str(id(self))
             short_hash = _compute_group_hash(self.group_id)
@@ -121,6 +121,12 @@ def __init__(
         else:
             self.cpu_param_dict = self._init_cpu_param_dict()

+        self._torch_accelerator_module = (
+            getattr(torch, torch.accelerator.current_accelerator().type)
+            if hasattr(torch, "accelerator")
+            else torch.cuda
+        )
+
     def _init_cpu_param_dict(self):
         cpu_param_dict = {}
         if self.stream is None:
@@ -144,16 +150,12 @@ def _init_cpu_param_dict(self):

     @contextmanager
     def _pinned_memory_tensors(self):
-        pinned_dict = {}
         try:
-            for param, tensor in self.cpu_param_dict.items():
-                if not tensor.is_pinned():
-                    pinned_dict[param] = tensor.pin_memory()
-                else:
-                    pinned_dict[param] = tensor
-
+            pinned_dict = {
+                param: tensor.pin_memory() if not tensor.is_pinned() else tensor
+                for param, tensor in self.cpu_param_dict.items()
+            }
             yield pinned_dict
-
         finally:
             pinned_dict = None

@@ -179,77 +181,47 @@ def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None)
             source = pinned_memory[buffer] if pinned_memory else buffer.data
             self._transfer_tensor_to_device(buffer, source, current_stream)

-    def _onload_from_disk(self, current_stream):
+    def _onload_from_disk(self):
         if self.stream is not None:
-            loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
-
-            for key, tensor_obj in self.key_to_tensor.items():
-                self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
-
-            with self._pinned_memory_tensors() as pinned_memory:
-                for key, tensor_obj in self.key_to_tensor.items():
-                    self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream)
-
-            self.cpu_param_dict.clear()
-
-        else:
-            onload_device = (
-                self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
-            )
-            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
-            for key, tensor_obj in self.key_to_tensor.items():
-                tensor_obj.data = loaded_tensors[key]
+            # Wait for previous Host->Device transfer to complete
+            self.stream.synchronize()

-    def _onload_from_memory(self, current_stream):
-        if self.stream is not None:
-            with self._pinned_memory_tensors() as pinned_memory:
-                self._process_tensors_from_modules(pinned_memory, current_stream)
-        else:
-            self._process_tensors_from_modules(None, current_stream)
+        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
+        current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None

-    @torch.compiler.disable()
-    def onload_(self):
-        torch_accelerator_module = (
-            getattr(torch, torch.accelerator.current_accelerator().type)
-            if hasattr(torch, "accelerator")
-            else torch.cuda
-        )
-        context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
-        current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
+        with context:
+            # Load to CPU (if using streams) or directly to target device, pin, and async copy to device
+            device = self.onload_device if self.stream is None else "cpu"
+            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)

-        if self.offload_to_disk_path:
             if self.stream is not None:
-                # Wait for previous Host->Device transfer to complete
-                self.stream.synchronize()
-
-            with context:
-                if self.stream is not None:
-                    # Load to CPU, pin, and async copy to device for overlapping transfer and compute
-                    loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
-                    for key, tensor_obj in self.key_to_tensor.items():
-                        pinned_tensor = loaded_cpu_tensors[key].pin_memory()
-                        tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            tensor_obj.data.record_stream(current_stream)
-                else:
-                    # Load directly to the target device (synchronous)
-                    onload_device = (
-                        self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
-                    )
-                    loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
-                    for key, tensor_obj in self.key_to_tensor.items():
-                        tensor_obj.data = loaded_tensors[key]
-            return
+                for key, tensor_obj in self.key_to_tensor.items():
+                    pinned_tensor = loaded_tensors[key].pin_memory()
+                    tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+                    if self.record_stream:
+                        tensor_obj.data.record_stream(current_stream)
+            else:
+                onload_device = (
+                    self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
+                )
+                loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+                for key, tensor_obj in self.key_to_tensor.items():
+                    tensor_obj.data = loaded_tensors[key]

+    def _onload_from_memory(self):
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
             self.stream.synchronize()

+        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
+        current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None
+
         with context:
-            if self.offload_to_disk_path:
-                self._onload_from_disk(current_stream)
+            if self.stream is not None:
+                with self._pinned_memory_tensors() as pinned_memory:
+                    self._process_tensors_from_modules(pinned_memory, current_stream)
             else:
-                self._onload_from_memory(current_stream)
+                self._process_tensors_from_modules(None, current_stream)

     def _offload_to_disk(self):
         # TODO: we can potentially optimize this code path by checking if the _all_ the desired
@@ -270,14 +242,10 @@ def _offload_to_disk(self):
             tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)

     def _offload_to_memory(self):
-        torch_accelerator_module = (
-            getattr(torch, torch.accelerator.current_accelerator().type)
-            if hasattr(torch, "accelerator")
-            else torch.cuda
-        )
         if self.stream is not None:
             if not self.record_stream:
-                torch_accelerator_module.current_stream().synchronize()
+                self._torch_accelerator_module.current_stream().synchronize()
+
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
@@ -288,15 +256,23 @@ def _offload_to_memory(self):
         else:
            for group_module in self.modules:
-                group_module.to(self.offload_device, non_blocking=self.non_blocking)
+                group_module.to(self.offload_device, non_blocking=False)
             for param in self.parameters:
-                param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+                param.data = param.data.to(self.offload_device, non_blocking=False)
             for buffer in self.buffers:
-                buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+                buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
+
+    @torch.compiler.disable()
+    def onload_(self):
+        r"""Onloads the group of parameters to the onload_device."""
+        if self.offload_to_disk_path is not None:
+            self._onload_from_disk()
+        else:
+            self._onload_from_memory()

     @torch.compiler.disable()
     def offload_(self):
-        r"""Offloads the group of modules to the offload_device."""
+        r"""Offloads the group of parameters to the offload_device."""
         if self.offload_to_disk_path:
             self._offload_to_disk()
         else:
@@ -462,8 +438,8 @@ def pre_forward(self, module, *args, **kwargs):

 def apply_group_offloading(
     module: torch.nn.Module,
-    onload_device: torch.device,
-    offload_device: torch.device = torch.device("cpu"),
+    onload_device: Union[str, torch.device],
+    offload_device: Union[str, torch.device] = torch.device("cpu"),
     offload_type: Union[str, GroupOffloadingType] = "block_level",
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
@@ -549,6 +525,8 @@ def apply_group_offloading(
     ```
     """

+    onload_device = torch.device(onload_device) if isinstance(onload_device, str) else onload_device
+    offload_device = torch.device(offload_device) if isinstance(offload_device, str) else offload_device
     offload_type = GroupOffloadingType(offload_type)

     stream = None

From c6d61fa8b3f3288cda1b15f3e6d2760a1fc0bbc7 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Fri, 25 Jul 2025 02:45:01 +0200
Subject: [PATCH 2/6] update

---
 src/diffusers/hooks/group_offloading.py | 22 ++++++++++------------
 1 file changed, 10 insertions(+), 12 deletions(-)

diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index 6fd24b10724d..bd67f46d2b86 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -159,27 +159,27 @@ def _pinned_memory_tensors(self):
         finally:
             pinned_dict = None

-    def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
+    def _transfer_tensor_to_device(self, tensor, source_tensor):
         tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
-        if self.record_stream and current_stream is not None:
-            tensor.data.record_stream(current_stream)
+        if self.record_stream:
+            tensor.data.record_stream(self._torch_accelerator_module.current_stream())

-    def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None):
+    def _process_tensors_from_modules(self, pinned_memory=None):
         for group_module in self.modules:
             for param in group_module.parameters():
                 source = pinned_memory[param] if pinned_memory else param.data
-                self._transfer_tensor_to_device(param, source, current_stream)
+                self._transfer_tensor_to_device(param, source)
             for buffer in group_module.buffers():
                 source = pinned_memory[buffer] if pinned_memory else buffer.data
-                self._transfer_tensor_to_device(buffer, source, current_stream)
+                self._transfer_tensor_to_device(buffer, source)

         for param in self.parameters:
             source = pinned_memory[param] if pinned_memory else param.data
-            self._transfer_tensor_to_device(param, source, current_stream)
+            self._transfer_tensor_to_device(param, source)
        for buffer in self.buffers:
             source = pinned_memory[buffer] if pinned_memory else buffer.data
-            self._transfer_tensor_to_device(buffer, source, current_stream)
+            self._transfer_tensor_to_device(buffer, source)

     def _onload_from_disk(self):
         if self.stream is not None:
@@ -214,14 +214,12 @@ def _onload_from_memory(self):
             self.stream.synchronize()

         context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
-        current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None
-
         with context:
             if self.stream is not None:
                 with self._pinned_memory_tensors() as pinned_memory:
-                    self._process_tensors_from_modules(pinned_memory, current_stream)
+                    self._process_tensors_from_modules(pinned_memory)
             else:
-                self._process_tensors_from_modules(None, current_stream)
+                self._process_tensors_from_modules(None)

     def _offload_to_disk(self):
         # TODO: we can potentially optimize this code path by checking if the _all_ the desired

From 447e881bacd9da9201c47bcb4526880dd9f37b15 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Mon, 4 Aug 2025 23:14:55 +0200
Subject: [PATCH 3/6] refactor

---
 src/diffusers/hooks/group_offloading.py | 24 ++++++++++--------------
 1 file changed, 10 insertions(+), 14 deletions(-)

diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index ac948bef1d5c..5d1dda46ace8 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -281,11 +281,9 @@ class GroupOffloadingHook(ModelHook):

     _is_stateful = False

-    def __init__(
-        self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig
-    ) -> None:
+    def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None:
         self.group = group
-        self.next_group = next_group
+        self.next_group: Optional[ModuleGroup] = None
         self.config = config

     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
@@ -609,7 +607,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
     # Apply group offloading hooks to the module groups
     for i, group in enumerate(matched_module_groups):
         for group_module in group.modules:
-            _apply_group_offloading_hook(group_module, group, None, config=config)
+            _apply_group_offloading_hook(group_module, group, config=config)

     # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
     # when the forward pass of this module is called. This is because the top-level module is not
@@ -638,9 +636,9 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
         group_id=f"{module.__class__.__name__}_unmatched_group",
     )
     if config.stream is None:
-        _apply_group_offloading_hook(module, unmatched_group, None, config=config)
+        _apply_group_offloading_hook(module, unmatched_group, config=config)
     else:
-        _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
+        _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)


 def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
@@ -669,7 +667,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
             onload_self=True,
             group_id=name,
         )
-        _apply_group_offloading_hook(submodule, group, None, config=config)
+        _apply_group_offloading_hook(submodule, group, config=config)
         modules_with_group_offloading.add(name)

     # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
@@ -716,7 +714,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
             onload_self=True,
             group_id=name,
         )
-        _apply_group_offloading_hook(parent_module, group, None, config=config)
+        _apply_group_offloading_hook(parent_module, group, config=config)

     if config.stream is not None:
         # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
@@ -738,13 +736,12 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
             onload_self=True,
             group_id=_GROUP_ID_LAZY_LEAF,
         )
-        _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
+        _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)


 def _apply_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
-    next_group: Optional[ModuleGroup] = None,
     *,
     config: GroupOffloadingConfig,
 ) -> None:
@@ -753,14 +750,13 @@ def _apply_group_offloading_hook(
     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group, config=config)
+        hook = GroupOffloadingHook(group, config=config)
         registry.register_hook(hook, _GROUP_OFFLOADING)


 def _apply_lazy_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
-    next_group: Optional[ModuleGroup] = None,
     *,
     config: GroupOffloadingConfig,
 ) -> None:
@@ -769,7 +765,7 @@ def _apply_lazy_group_offloading_hook(
     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
    if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group, config=config)
+        hook = GroupOffloadingHook(group, config=config)
         registry.register_hook(hook, _GROUP_OFFLOADING)

     lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()

From ac74eedee2757555066c3e9937abff61ff7d3f98 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Wed, 6 Aug 2025 11:47:56 +0200
Subject: [PATCH 4/6] add test

---
 tests/hooks/test_group_offloading.py | 89 ++++++++++++++++++++++++++++
 1 file changed, 89 insertions(+)

diff --git a/tests/hooks/test_group_offloading.py b/tests/hooks/test_group_offloading.py
index 7f778be980b7..a7bec39108c8 100644
--- a/tests/hooks/test_group_offloading.py
+++ b/tests/hooks/test_group_offloading.py
@@ -17,7 +17,9 @@
 import unittest

 import torch
+from parameterized import parameterized

+from diffusers.hooks import HookRegistry, ModelHook
 from diffusers.models import ModelMixin
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.utils import get_logger
@@ -99,6 +101,29 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x


+# Test for https://github.com/huggingface/diffusers/pull/12077
+class DummyModelWithLayerNorm(ModelMixin):
+    def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
+        super().__init__()
+
+        self.linear_1 = torch.nn.Linear(in_features, hidden_features)
+        self.activation = torch.nn.ReLU()
+        self.blocks = torch.nn.ModuleList(
+            [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
+        )
+        self.layer_norm = torch.nn.LayerNorm(hidden_features, elementwise_affine=True)
+        self.linear_2 = torch.nn.Linear(hidden_features, out_features)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.linear_1(x)
+        x = self.activation(x)
+        for block in self.blocks:
+            x = block(x)
+        x = self.layer_norm(x)
+        x = self.linear_2(x)
+        return x
+
+
 class DummyPipeline(DiffusionPipeline):
     model_cpu_offload_seq = "model"

@@ -113,6 +138,16 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor:
         return x


+class LayerOutputTrackerHook(ModelHook):
+    def __init__(self):
+        super().__init__()
+        self.outputs = []
+
+    def post_forward(self, module, output):
+        self.outputs.append(output)
+        return output
+
+
 @require_torch_accelerator
 class GroupOffloadTests(unittest.TestCase):
     in_features = 64
@@ -258,6 +293,7 @@ def test_error_raised_if_group_offloading_applied_on_sequential_offloaded_module
     def test_block_level_stream_with_invocation_order_different_from_initialization_order(self):
         if torch.device(torch_device).type not in ["cuda", "xpu"]:
             return
+
         model = DummyModelWithMultipleBlocks(
             in_features=self.in_features,
             hidden_features=self.hidden_features,
@@ -274,3 +310,56 @@ def test_block_level_stream_with_invocation_order_different_from_initialization_

         with context:
             model(self.input)
+
+    @parameterized.expand([("block_level",), ("leaf_level",)])
+    def test_block_level_offloading_with_parameter_only_module_group(self, offload_type: str):
+        if torch.device(torch_device).type not in ["cuda", "xpu"]:
+            return
+
+        def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm):
+            for name, module in model.named_modules():
+                registry = HookRegistry.check_if_exists_or_initialize(module)
+                hook = LayerOutputTrackerHook()
+                registry.register_hook(hook, "layer_output_tracker")
+
+        model_ref = DummyModelWithLayerNorm(128, 256, 128, 2)
+        model = DummyModelWithLayerNorm(128, 256, 128, 2)
+
+        model.load_state_dict(model_ref.state_dict(), strict=True)
+
+        model_ref.to(torch_device)
+        model.enable_group_offload(torch_device, offload_type=offload_type, num_blocks_per_group=1, use_stream=True)
+
+        apply_layer_output_tracker_hook(model_ref)
+        apply_layer_output_tracker_hook(model)
+
+        x = torch.randn(2, 128).to(torch_device)
+
+        out_ref = model_ref(x)
+        out = model(x)
+        self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match.")
+
+        num_repeats = 4
+        for i in range(num_repeats):
+            out_ref = model_ref(x)
+            out = model(x)
+
+        self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match after multiple invocations.")
+
+        for (ref_name, ref_module), (name, module) in zip(model_ref.named_modules(), model.named_modules()):
+            assert ref_name == name
+            if not isinstance(ref_module, (torch.nn.Linear, torch.nn.LayerNorm)):
+                continue
+            ref_outputs = (
+                HookRegistry.check_if_exists_or_initialize(ref_module).get_hook("layer_output_tracker").outputs
+            )
+            outputs = HookRegistry.check_if_exists_or_initialize(module).get_hook("layer_output_tracker").outputs
+            cumulated_absmax = 0.0
+            for i in range(len(outputs)):
+                diff = ref_outputs[0] - outputs[i]
+                absdiff = diff.abs()
+                absmax = absdiff.max().item()
+                cumulated_absmax += absmax
+            self.assertLess(
+                cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
+            )

From d4f5ec03e28c182237ae6f164323dc4f646bdf16 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Wed, 6 Aug 2025 11:49:13 +0200
Subject: [PATCH 5/6] address review comment

---
 src/diffusers/hooks/group_offloading.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index 5d1dda46ace8..6b6871f9dc2a 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -185,7 +185,7 @@ def _onload_from_disk(self):

         with context:
             # Load to CPU (if using streams) or directly to target device, pin, and async copy to device
-            device = self.onload_device if self.stream is None else "cpu"
+            device = str(self.onload_device) if self.stream is None else "cpu"
             loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)

From 7452ba5b2cc55c5cc33289f84180485e6d7ca959 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Wed, 6 Aug 2025 11:51:35 +0200
Subject: [PATCH 6/6] nit

---
 tests/hooks/test_group_offloading.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tests/hooks/test_group_offloading.py b/tests/hooks/test_group_offloading.py
index a7bec39108c8..ea08dec19cfc 100644
--- a/tests/hooks/test_group_offloading.py
+++ b/tests/hooks/test_group_offloading.py
@@ -348,8 +348,6 @@ def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm):

         for (ref_name, ref_module), (name, module) in zip(model_ref.named_modules(), model.named_modules()):
             assert ref_name == name
-            if not isinstance(ref_module, (torch.nn.Linear, torch.nn.LayerNorm)):
-                continue
             ref_outputs = (
                 HookRegistry.check_if_exists_or_initialize(ref_module).get_hook("layer_output_tracker").outputs
             )
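
For reference, a minimal sketch of how the API touched by this series is exercised end to end. The call shape follows the diffs above; the toy model, the string device arguments, and the exact import path are illustrative assumptions rather than part of the patches:

    import torch

    # Assumed import path; the same behavior is reachable through
    # ModelMixin.enable_group_offload, as the test in PATCH 4/6 does.
    from diffusers.hooks import apply_group_offloading

    # Hypothetical toy model: two Sequential "blocks" so block_level grouping
    # has ModuleList/Sequential children to match.
    model = torch.nn.Sequential(
        torch.nn.Sequential(torch.nn.Linear(128, 256), torch.nn.ReLU()),
        torch.nn.Sequential(torch.nn.Linear(256, 128), torch.nn.ReLU()),
    )

    # After PATCH 1/6, plain strings are accepted here and coerced with torch.device(...).
    apply_group_offloading(
        model,
        onload_device="cuda",        # previously annotated as torch.device only
        offload_device="cpu",
        offload_type="block_level",
        num_blocks_per_group=1,
        use_stream=True,             # overlap Host->Device transfers with compute; needs CUDA/XPU
    )

    # Disk-backed variant: parameter groups are saved as safetensors files and
    # reloaded by _onload_from_disk (see PATCH 1/6 and PATCH 5/6).
    # apply_group_offloading(model, onload_device="cuda", offload_type="leaf_level",
    #                        offload_to_disk_path="offload_dir")

    x = torch.randn(2, 128, device="cuda")  # keep inputs on the onload device
    y = model(x)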