[refactor] condense group offloading #11990

Open · wants to merge 2 commits into base: main
152 changes: 64 additions & 88 deletions src/diffusers/hooks/group_offloading.py
@@ -101,7 +101,7 @@ def __init__(
self.offload_to_disk_path = offload_to_disk_path
self._is_offloaded_to_disk = False

if self.offload_to_disk_path:
if self.offload_to_disk_path is not None:
# Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
self.group_id = group_id if group_id is not None else str(id(self))
short_hash = _compute_group_hash(self.group_id)
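
The comment above alludes to a truthiness pitfall. A minimal, standalone sketch (values are illustrative, not from the PR) of why `group_id or str(id(self))` would silently discard an empty-string id:

```python
# Illustrative only: an empty string is falsy, so `or` would replace it.
group_id = ""
fallback = "140212345"  # stand-in for str(id(self))

wrong = group_id or fallback                            # "140212345": the empty id is lost
right = group_id if group_id is not None else fallback  # "": the empty id is preserved
assert wrong == fallback and right == ""
```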
@@ -121,6 +121,12 @@ def __init__(
else:
self.cpu_param_dict = self._init_cpu_param_dict()

self._torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
if hasattr(torch, "accelerator")
else torch.cuda
)

def _init_cpu_param_dict(self):
cpu_param_dict = {}
if self.stream is None:
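
Caching the accelerator module once in `__init__` replaces the repeated lookups the old `onload_` and `_offload_to_memory` performed. A rough, standalone sketch of that lookup (assumes a recent PyTorch that exposes `torch.accelerator`; older builds fall back to CUDA):

```python
import torch

def get_accelerator_module():
    # torch.accelerator only exists on recent PyTorch builds
    if hasattr(torch, "accelerator"):
        return getattr(torch, torch.accelerator.current_accelerator().type)
    return torch.cuda

mod = get_accelerator_module()
# the same module then provides the stream APIs used in this file, e.g.:
side_stream = mod.Stream()         # transfer stream (needs an accelerator present)
current = mod.current_stream()     # stream passed to record_stream()/synchronize()
```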
@@ -144,112 +150,76 @@ def _init_cpu_param_dict(self):

@contextmanager
def _pinned_memory_tensors(self):
pinned_dict = {}
try:
for param, tensor in self.cpu_param_dict.items():
if not tensor.is_pinned():
pinned_dict[param] = tensor.pin_memory()
else:
pinned_dict[param] = tensor

pinned_dict = {
param: tensor.pin_memory() if not tensor.is_pinned() else tensor
for param, tensor in self.cpu_param_dict.items()
}
yield pinned_dict

finally:
pinned_dict = None

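For context on the pinned-memory path: page-locked host memory is what makes the later `non_blocking` copies genuinely asynchronous, and the `is_pinned()` guard above avoids re-pinning tensors that are already page-locked. A standalone sketch (CUDA assumed, not part of the diff):

```python
import torch

cpu_tensor = torch.randn(1024, 1024)               # pageable host memory
pinned = cpu_tensor.pin_memory()                    # page-locked copy of the data
gpu_tensor = pinned.to("cuda", non_blocking=True)   # can overlap with compute on another stream
torch.cuda.current_stream().synchronize()           # or use record_stream(), shown further down
```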
def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
def _transfer_tensor_to_device(self, tensor, source_tensor):
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream and current_stream is not None:
tensor.data.record_stream(current_stream)
if self.record_stream:
tensor.data.record_stream(self._torch_accelerator_module.current_stream())

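`record_stream()` is the alternative to a hard synchronization when a tensor copied on the transfer stream is later consumed on the compute stream. A standalone sketch of the pattern (CUDA assumed; not lifted from the diff):

```python
import torch

copy_stream = torch.cuda.Stream()
host = torch.randn(1024, 1024).pin_memory()

with torch.cuda.stream(copy_stream):
    device_tensor = host.to("cuda", non_blocking=True)    # copied on copy_stream

torch.cuda.current_stream().wait_stream(copy_stream)      # order compute after the copy
device_tensor.record_stream(torch.cuda.current_stream())  # allocator will not reuse this memory
result = device_tensor @ device_tensor                     # until the compute stream is done with it
```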
def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None):
def _process_tensors_from_modules(self, pinned_memory=None):
for group_module in self.modules:
for param in group_module.parameters():
source = pinned_memory[param] if pinned_memory else param.data
self._transfer_tensor_to_device(param, source, current_stream)
self._transfer_tensor_to_device(param, source)
for buffer in group_module.buffers():
source = pinned_memory[buffer] if pinned_memory else buffer.data
self._transfer_tensor_to_device(buffer, source, current_stream)
self._transfer_tensor_to_device(buffer, source)

for param in self.parameters:
source = pinned_memory[param] if pinned_memory else param.data
self._transfer_tensor_to_device(param, source, current_stream)
self._transfer_tensor_to_device(param, source)

for buffer in self.buffers:
source = pinned_memory[buffer] if pinned_memory else buffer.data
self._transfer_tensor_to_device(buffer, source, current_stream)
self._transfer_tensor_to_device(buffer, source)

def _onload_from_disk(self, current_stream):
def _onload_from_disk(self):
if self.stream is not None:
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")

for key, tensor_obj in self.key_to_tensor.items():
self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]

with self._pinned_memory_tensors() as pinned_memory:
for key, tensor_obj in self.key_to_tensor.items():
self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream)

self.cpu_param_dict.clear()

else:
onload_device = (
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
)
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
for key, tensor_obj in self.key_to_tensor.items():
tensor_obj.data = loaded_tensors[key]
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()

def _onload_from_memory(self, current_stream):
if self.stream is not None:
with self._pinned_memory_tensors() as pinned_memory:
self._process_tensors_from_modules(pinned_memory, current_stream)
else:
self._process_tensors_from_modules(None, current_stream)
context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None

@torch.compiler.disable()
def onload_(self):
torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
if hasattr(torch, "accelerator")
else torch.cuda
)
context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
with context:
# Load to CPU (if using streams) or directly to target device, pin, and async copy to device
device = self.onload_device if self.stream is None else "cpu"
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)

if self.offload_to_disk_path:
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()

with context:
if self.stream is not None:
# Load to CPU, pin, and async copy to device for overlapping transfer and compute
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
for key, tensor_obj in self.key_to_tensor.items():
pinned_tensor = loaded_cpu_tensors[key].pin_memory()
tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
tensor_obj.data.record_stream(current_stream)
else:
# Load directly to the target device (synchronous)
onload_device = (
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
)
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
for key, tensor_obj in self.key_to_tensor.items():
tensor_obj.data = loaded_tensors[key]
return
for key, tensor_obj in self.key_to_tensor.items():
pinned_tensor = loaded_tensors[key].pin_memory()
tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
tensor_obj.data.record_stream(current_stream)
else:
onload_device = (
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
)
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
for key, tensor_obj in self.key_to_tensor.items():
tensor_obj.data = loaded_tensors[key]

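The reworked disk path above loads to CPU with safetensors, pins the tensors, and issues the async copies inside the stream context. A condensed standalone sketch of that sequence (file name and device are illustrative):

```python
import safetensors.torch
import torch

state = safetensors.torch.load_file("group_0_abc123.safetensors", device="cpu")

copy_stream = torch.cuda.Stream()
with torch.cuda.stream(copy_stream):
    on_device = {
        key: tensor.pin_memory().to("cuda", non_blocking=True)
        for key, tensor in state.items()
    }
torch.cuda.current_stream().wait_stream(copy_stream)  # before consuming on the compute stream
```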
def _onload_from_memory(self):
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()

context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
with context:
if self.offload_to_disk_path:
self._onload_from_disk(current_stream)
if self.stream is not None:
with self._pinned_memory_tensors() as pinned_memory:
self._process_tensors_from_modules(pinned_memory)
else:
self._onload_from_memory(current_stream)
self._process_tensors_from_modules(None)

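Both onload paths now begin with `self.stream.synchronize()`, presumably because the transfer stream is shared across groups: a new host-to-device copy should only be enqueued once the previous group's copy has finished. A minimal standalone sketch of that reuse (CUDA assumed):

```python
import torch

copy_stream = torch.cuda.Stream()

def prefetch(host_tensor):
    copy_stream.synchronize()            # previous group's transfer has completed
    with torch.cuda.stream(copy_stream):
        return host_tensor.pin_memory().to("cuda", non_blocking=True)
```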
def _offload_to_disk(self):
# TODO: we can potentially optimize this code path by checking if the _all_ the desired
@@ -270,14 +240,10 @@ def _offload_to_disk(self):
tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)

def _offload_to_memory(self):
torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
if hasattr(torch, "accelerator")
else torch.cuda
)
if self.stream is not None:
if not self.record_stream:
torch_accelerator_module.current_stream().synchronize()
self._torch_accelerator_module.current_stream().synchronize()

for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
Expand All @@ -288,15 +254,23 @@ def _offload_to_memory(self):

else:
for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=self.non_blocking)
group_module.to(self.offload_device, non_blocking=False)
for param in self.parameters:
param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
param.data = param.data.to(self.offload_device, non_blocking=False)
for buffer in self.buffers:
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
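
The non-streamed offload branch now forces blocking copies. A standalone sketch of the hazard that avoids (CUDA assumed): after an asynchronous device-to-CPU copy, the host tensor is only guaranteed to hold the data once the issuing stream has been synchronized.

```python
import torch

x = torch.randn(4, 4, device="cuda")

cpu_blocking = x.to("cpu", non_blocking=False)  # safe to read immediately
cpu_async = x.to("cpu", non_blocking=True)      # copy is merely enqueued
torch.cuda.current_stream().synchronize()       # only now is cpu_async guaranteed valid
assert torch.equal(cpu_blocking, cpu_async)
```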

@torch.compiler.disable()
def onload_(self):
r"""Onloads the group of parameters to the onload_device."""
if self.offload_to_disk_path is not None:
self._onload_from_disk()
else:
self._onload_from_memory()

@torch.compiler.disable()
def offload_(self):
r"""Offloads the group of modules to the offload_device."""
r"""Offloads the group of parameters to the offload_device."""
if self.offload_to_disk_path:
self._offload_to_disk()
else:
@@ -462,8 +436,8 @@ def pre_forward(self, module, *args, **kwargs):

def apply_group_offloading(
module: torch.nn.Module,
onload_device: torch.device,
offload_device: torch.device = torch.device("cpu"),
onload_device: Union[str, torch.device],
offload_device: Union[str, torch.device] = torch.device("cpu"),
offload_type: Union[str, GroupOffloadingType] = "block_level",
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
@@ -549,6 +523,8 @@ def apply_group_offloading(
```
"""

onload_device = torch.device(onload_device) if isinstance(onload_device, str) else onload_device
offload_device = torch.device(offload_device) if isinstance(offload_device, str) else offload_device
offload_type = GroupOffloadingType(offload_type)

stream = None
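
With the relaxed signature, callers can pass plain strings for the devices. A usage sketch (the model is illustrative; in practice this is applied to a pipeline's transformer or UNet, and `use_stream=True` requires an accelerator):

```python
import torch
from diffusers.hooks import apply_group_offloading

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList(torch.nn.Linear(8, 8) for _ in range(4))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

model = TinyModel()

apply_group_offloading(
    model,
    onload_device="cuda",      # strings are now normalized to torch.device internally
    offload_device="cpu",
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,
)
```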