Skip to content
14 changes: 9 additions & 5 deletions caiman/motion_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __init__(self, fname, min_mov=None, dview=None, max_shifts=(6, 6), niter_rig
strides=(96, 96), overlaps=(32, 32), splits_els=14, num_splits_to_process_els=None,
upsample_factor_grid=4, max_deviation_rigid=3, shifts_opencv=True, nonneg_movie=True, gSig_filt=None,
use_cuda=False, border_nan=True, pw_rigid=False, num_frames_split=80, var_name_hdf5='mov',is3D=False,
indices=(slice(None), slice(None))):
indices=(slice(None), slice(None)), subidx=slice(None, None, 1)):
"""
Constructor class for motion correction operations

Expand Down Expand Up @@ -206,6 +206,7 @@ def __init__(self, fname, min_mov=None, dview=None, max_shifts=(6, 6), niter_rig
self.var_name_hdf5 = var_name_hdf5
self.is3D = bool(is3D)
self.indices = indices
self.subidx = subidx
if self.use_cuda and not HAS_CUDA:
logging.debug("pycuda is unavailable. Falling back to default FFT.")

Expand Down Expand Up @@ -310,7 +311,8 @@ def motion_correct_rigid(self, template=None, save_movie=False) -> None:
border_nan=self.border_nan,
var_name_hdf5=self.var_name_hdf5,
is3D=self.is3D,
indices=self.indices)
indices=self.indices,
subidx=self.subidx)
if template is None:
self.total_template_rig = _total_template_rig

Expand Down Expand Up @@ -2816,7 +2818,9 @@ def motion_correct_batch_rigid(fname, max_shifts, dview=None, splits=56, num_spl
Exception 'The movie contains nans. Nans are not allowed!'

"""

if subidx.start is not None:
print(f"Computing template from frames {subidx}.")

dims, T = cm.source_extraction.cnmf.utilities.get_file_size(fname, var_name_hdf5=var_name_hdf5)
Ts = np.arange(T)[subidx].shape[0]
step = Ts // 10 if is3D else Ts // 50
Expand Down Expand Up @@ -2878,10 +2882,10 @@ def motion_correct_batch_rigid(fname, max_shifts, dview=None, splits=56, num_spl

fname_tot_rig, res_rig = motion_correction_piecewise(fname, splits, strides=None, overlaps=None,
add_to_movie=add_to_movie, template=old_templ, max_shifts=max_shifts, max_deviation_rigid=0,
dview=dview, save_movie=save_movie, base_name=base_name, subidx = subidx,
dview=dview, save_movie=save_movie, base_name=base_name,
num_splits=num_splits_to_process, shifts_opencv=shifts_opencv, nonneg_movie=nonneg_movie, gSig_filt=gSig_filt,
use_cuda=use_cuda, border_nan=border_nan, var_name_hdf5=var_name_hdf5, is3D=is3D,
indices=indices)
indices=indices, subidx=None)
if is3D:
new_templ = np.nanmedian(np.stack([r[-1] for r in res_rig]), 0)
else:
Expand Down