diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index 6c89101f5e98..bd67f46d2b86 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -101,7 +101,7 @@ def __init__(
         self.offload_to_disk_path = offload_to_disk_path
         self._is_offloaded_to_disk = False
 
-        if self.offload_to_disk_path:
+        if self.offload_to_disk_path is not None:
             # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
             self.group_id = group_id if group_id is not None else str(id(self))
             short_hash = _compute_group_hash(self.group_id)
@@ -121,6 +121,12 @@ def __init__(
         else:
             self.cpu_param_dict = self._init_cpu_param_dict()
 
+        self._torch_accelerator_module = (
+            getattr(torch, torch.accelerator.current_accelerator().type)
+            if hasattr(torch, "accelerator")
+            else torch.cuda
+        )
+
     def _init_cpu_param_dict(self):
         cpu_param_dict = {}
         if self.stream is None:
@@ -144,112 +150,76 @@ def _init_cpu_param_dict(self):
 
     @contextmanager
     def _pinned_memory_tensors(self):
-        pinned_dict = {}
         try:
-            for param, tensor in self.cpu_param_dict.items():
-                if not tensor.is_pinned():
-                    pinned_dict[param] = tensor.pin_memory()
-                else:
-                    pinned_dict[param] = tensor
-
+            pinned_dict = {
+                param: tensor.pin_memory() if not tensor.is_pinned() else tensor
+                for param, tensor in self.cpu_param_dict.items()
+            }
             yield pinned_dict
-
         finally:
             pinned_dict = None
 
-    def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
+    def _transfer_tensor_to_device(self, tensor, source_tensor):
         tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
-        if self.record_stream and current_stream is not None:
-            tensor.data.record_stream(current_stream)
+        if self.record_stream:
+            tensor.data.record_stream(self._torch_accelerator_module.current_stream())
 
-    def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None):
+    def _process_tensors_from_modules(self, pinned_memory=None):
         for group_module in self.modules:
             for param in group_module.parameters():
                 source = pinned_memory[param] if pinned_memory else param.data
-                self._transfer_tensor_to_device(param, source, current_stream)
+                self._transfer_tensor_to_device(param, source)
             for buffer in group_module.buffers():
                 source = pinned_memory[buffer] if pinned_memory else buffer.data
-                self._transfer_tensor_to_device(buffer, source, current_stream)
+                self._transfer_tensor_to_device(buffer, source)
 
         for param in self.parameters:
             source = pinned_memory[param] if pinned_memory else param.data
-            self._transfer_tensor_to_device(param, source, current_stream)
+            self._transfer_tensor_to_device(param, source)
 
         for buffer in self.buffers:
             source = pinned_memory[buffer] if pinned_memory else buffer.data
-            self._transfer_tensor_to_device(buffer, source, current_stream)
+            self._transfer_tensor_to_device(buffer, source)
 
-    def _onload_from_disk(self, current_stream):
+    def _onload_from_disk(self):
         if self.stream is not None:
-            loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
-
-            for key, tensor_obj in self.key_to_tensor.items():
-                self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
-
-            with self._pinned_memory_tensors() as pinned_memory:
-                for key, tensor_obj in self.key_to_tensor.items():
-                    self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream)
-
-            self.cpu_param_dict.clear()
-
-        else:
-            onload_device = (
-                self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
-            )
-            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
-            for key, tensor_obj in self.key_to_tensor.items():
-                tensor_obj.data = loaded_tensors[key]
+            # Wait for previous Host->Device transfer to complete
+            self.stream.synchronize()
 
-    def _onload_from_memory(self, current_stream):
-        if self.stream is not None:
-            with self._pinned_memory_tensors() as pinned_memory:
-                self._process_tensors_from_modules(pinned_memory, current_stream)
-        else:
-            self._process_tensors_from_modules(None, current_stream)
+        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
+        current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None
 
-    @torch.compiler.disable()
-    def onload_(self):
-        torch_accelerator_module = (
-            getattr(torch, torch.accelerator.current_accelerator().type)
-            if hasattr(torch, "accelerator")
-            else torch.cuda
-        )
-        context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
-        current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
+        with context:
+            # Load to CPU (if using streams) or directly to target device, pin, and async copy to device
+            device = self.onload_device if self.stream is None else "cpu"
+            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)
 
-        if self.offload_to_disk_path:
             if self.stream is not None:
-                # Wait for previous Host->Device transfer to complete
-                self.stream.synchronize()
-
-            with context:
-                if self.stream is not None:
-                    # Load to CPU, pin, and async copy to device for overlapping transfer and compute
-                    loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
-                    for key, tensor_obj in self.key_to_tensor.items():
-                        pinned_tensor = loaded_cpu_tensors[key].pin_memory()
-                        tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            tensor_obj.data.record_stream(current_stream)
-                else:
-                    # Load directly to the target device (synchronous)
-                    onload_device = (
-                        self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
-                    )
-                    loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
-                    for key, tensor_obj in self.key_to_tensor.items():
-                        tensor_obj.data = loaded_tensors[key]
-            return
+                for key, tensor_obj in self.key_to_tensor.items():
+                    pinned_tensor = loaded_tensors[key].pin_memory()
+                    tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+                    if self.record_stream:
+                        tensor_obj.data.record_stream(current_stream)
+            else:
+                onload_device = (
+                    self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
+                )
+                loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+                for key, tensor_obj in self.key_to_tensor.items():
+                    tensor_obj.data = loaded_tensors[key]
 
+    def _onload_from_memory(self):
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
             self.stream.synchronize()
 
+        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
         with context:
-            if self.offload_to_disk_path:
-                self._onload_from_disk(current_stream)
+            if self.stream is not None:
+                with self._pinned_memory_tensors() as pinned_memory:
+                    self._process_tensors_from_modules(pinned_memory)
             else:
-                self._onload_from_memory(current_stream)
+                self._process_tensors_from_modules(None)
 
     def _offload_to_disk(self):
         # TODO: we can potentially optimize this code path by checking if the _all_ the desired
@@ -270,14 +240,10 @@ def _offload_to_disk(self):
             tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
 
     def _offload_to_memory(self):
-        torch_accelerator_module = (
-            getattr(torch, torch.accelerator.current_accelerator().type)
-            if hasattr(torch, "accelerator")
-            else torch.cuda
-        )
         if self.stream is not None:
             if not self.record_stream:
-                torch_accelerator_module.current_stream().synchronize()
+                self._torch_accelerator_module.current_stream().synchronize()
+
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
@@ -288,15 +254,23 @@ def _offload_to_memory(self):
 
         else:
             for group_module in self.modules:
-                group_module.to(self.offload_device, non_blocking=self.non_blocking)
+                group_module.to(self.offload_device, non_blocking=False)
             for param in self.parameters:
-                param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+                param.data = param.data.to(self.offload_device, non_blocking=False)
             for buffer in self.buffers:
-                buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+                buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
+
+    @torch.compiler.disable()
+    def onload_(self):
+        r"""Onloads the group of parameters to the onload_device."""
+        if self.offload_to_disk_path is not None:
+            self._onload_from_disk()
+        else:
+            self._onload_from_memory()
 
     @torch.compiler.disable()
     def offload_(self):
-        r"""Offloads the group of modules to the offload_device."""
+        r"""Offloads the group of parameters to the offload_device."""
         if self.offload_to_disk_path:
             self._offload_to_disk()
         else:
@@ -462,8 +436,8 @@ def pre_forward(self, module, *args, **kwargs):
 
 def apply_group_offloading(
     module: torch.nn.Module,
-    onload_device: torch.device,
-    offload_device: torch.device = torch.device("cpu"),
+    onload_device: Union[str, torch.device],
+    offload_device: Union[str, torch.device] = torch.device("cpu"),
     offload_type: Union[str, GroupOffloadingType] = "block_level",
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
@@ -549,6 +523,8 @@ def apply_group_offloading(
         ```
     """
 
+    onload_device = torch.device(onload_device) if isinstance(onload_device, str) else onload_device
+    offload_device = torch.device(offload_device) if isinstance(offload_device, str) else offload_device
     offload_type = GroupOffloadingType(offload_type)
 
     stream = None
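
Note on the cached `_torch_accelerator_module`: the patch hoists the per-call `getattr(torch, ...)` lookup out of `onload_`/`_offload_to_memory` and into `__init__`, so every transfer reuses one handle. A minimal standalone sketch of the same fallback (variable names below are illustrative; like the patch itself, it assumes an accelerator is actually present, since `torch.accelerator.current_accelerator()` returns `None` on CPU-only builds):

```python
import torch

# Same selection the hook stores as self._torch_accelerator_module:
# prefer the generic torch.accelerator namespace when available
# (recent PyTorch), otherwise fall back to torch.cuda.
torch_accelerator_module = (
    getattr(torch, torch.accelerator.current_accelerator().type)
    if hasattr(torch, "accelerator")
    else torch.cuda
)

# The hook uses the handle for stream bookkeeping, e.g.:
transfer_stream = torch_accelerator_module.Stream()      # torch.cuda.Stream() / torch.xpu.Stream() / ...
torch_accelerator_module.current_stream().synchronize()  # what _offload_to_memory does when record_stream is off
```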
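And since `apply_group_offloading` now accepts plain strings for `onload_device`/`offload_device` (normalized via `torch.device(...)` at the top of the function), a hypothetical call could look like the sketch below. The toy module is a placeholder, not from the patch; `block_level` offloading groups blocks found in `ModuleList`/`Sequential` containers, `num_blocks_per_group` at a time:

```python
import torch
from diffusers.hooks.group_offloading import apply_group_offloading

# Placeholder module standing in for a real transformer.
class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(64, 64) for _ in range(4)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

model = ToyModel()

apply_group_offloading(
    model,
    onload_device="cuda",  # str now allowed; previously a torch.device was required
    offload_device="cpu",
    offload_type="block_level",
    num_blocks_per_group=2,
)
```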