src/diffusers/hooks/group_offloading.py (17 changes: 15 additions, 2 deletions)
```diff
@@ -245,7 +245,6 @@ def _offload_to_memory(self):
                 param.data = self.cpu_param_dict[param]
             for buffer in self.buffers:
                 buffer.data = self.cpu_param_dict[buffer]
-
         else:
             for group_module in self.modules:
                 group_module.to(self.offload_device, non_blocking=False)
```
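For context, the `param.data = self.cpu_param_dict[param]` pattern in this hunk swaps each parameter's storage back to a pre-allocated CPU copy instead of allocating fresh host memory on every offload. Below is a minimal sketch of the idea, with hypothetical names (`linear`, `pin`), assuming the CPU copies are pinned when CUDA is available so that later onloads can use `non_blocking=True`:

```python
import torch

linear = torch.nn.Linear(4, 4)

# Pre-allocate the CPU copies once, at setup time. Pinned (page-locked)
# memory requires CUDA and is what makes later host-to-device copies
# eligible for non_blocking=True on a side stream.
pin = torch.cuda.is_available()
cpu_param_dict = {
    p: (p.data.cpu().pin_memory() if pin else p.data.cpu().clone())
    for p in linear.parameters()
}

# "Offload": repoint each parameter's .data at its pre-allocated CPU copy.
# No new host memory is allocated on this hot path.
for p in linear.parameters():
    p.data = cpu_param_dict[p]
```

Allocating and pinning once up front is what keeps repeated offload/onload cycles cheap.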
```diff
@@ -303,9 +302,23 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
         if self.group.onload_leader == module:
             if self.group.onload_self:
                 self.group.onload_()
-            if self.next_group is not None and not self.next_group.onload_self:
+
+            should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
+            if should_onload_next_group:
                 self.next_group.onload_()
+
+            should_synchronize = (
+                not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
+            )
+            if should_synchronize:
+                # If this group didn't onload itself, it was asynchronously onloaded by the
+                # previous group. We need to synchronize the side stream to ensure the parameters
+                # are completely loaded before proceeding with the forward pass. Without this,
+                # uninitialized weights would be used in the computation, leading to incorrect results.
+                # We only need to synchronize here if it isn't already done by the sync call inside
+                # self.next_group.onload_, hence the `not should_onload_next_group` check.
+                self.group.stream.synchronize()
 
         args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
         kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
         return args, kwargs
```
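The synchronization added in this hunk matters because, with prefetching, a group's weights are copied to the device on a side stream by the previous group's hook; the compute (default) stream must not touch those weights until that copy has finished. Below is a minimal sketch of the hazard and the fix, with hypothetical names (`side_stream`, `weight_cpu`, `weight_gpu`), assuming a CUDA device:

```python
import torch

device = torch.device("cuda")
side_stream = torch.cuda.Stream()

weight_cpu = torch.randn(1024, 1024, pin_memory=True)
weight_gpu = torch.empty_like(weight_cpu, device=device)

# Asynchronous onload, issued ahead of time (in the hook, by the previous
# group). The copy runs on the side stream, overlapping with compute.
with torch.cuda.stream(side_stream):
    weight_gpu.copy_(weight_cpu, non_blocking=True)

# ... the previous group's forward runs on the default stream meanwhile ...

# Before this group's forward reads weight_gpu, wait for the side stream.
# Skipping this can let the matmul read a half-copied tensor.
side_stream.synchronize()
out = torch.randn(8, 1024, device=device) @ weight_gpu
```

Here `side_stream.synchronize()` plays the role of `self.group.stream.synchronize()` in the diff; per the added comment, the `not should_onload_next_group` guard simply avoids a redundant sync when `self.next_group.onload_()` already synchronizes internally.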