@@ -155,8 +155,6 @@ def _pinned_memory_tensors(self):
155155
156156 def _transfer_tensor_to_device (self , tensor , source_tensor ):
157157 tensor .data = source_tensor .to (self .onload_device , non_blocking = self .non_blocking )
158- if self .record_stream :
159- tensor .data .record_stream (self ._torch_accelerator_module .current_stream ())
160158
161159 def _process_tensors_from_modules (self , pinned_memory = None ):
162160 for group_module in self .modules :
@@ -238,12 +236,20 @@ def _offload_to_memory(self):
238236 if not self .record_stream :
239237 self ._torch_accelerator_module .current_stream ().synchronize ()
240238
239+ current_stream = self ._torch_accelerator_module .current_stream ()
240+
241241 for group_module in self .modules :
242242 for param in group_module .parameters ():
243+ if self .record_stream and param .device .type == 'cuda' :
244+ param .data .record_stream (current_stream )
243245 param .data = self .cpu_param_dict [param ]
244246 for param in self .parameters :
247+ if self .record_stream and param .device .type == 'cuda' :
248+ param .data .record_stream (current_stream )
245249 param .data = self .cpu_param_dict [param ]
246250 for buffer in self .buffers :
251+ if self .record_stream and buffer .device .type == 'cuda' :
252+ buffer .data .record_stream (current_stream )
247253 buffer .data = self .cpu_param_dict [buffer ]
248254 else :
249255 for group_module in self .modules :
0 commit comments