Skip to content

Commit 7e9b039

Browse files
committed
Fixing after rebase.
1 parent 4bdc45c commit 7e9b039

File tree

3 files changed

+18
-12
lines changed

3 files changed

+18
-12
lines changed

src/spikeinterface/core/motion.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,13 @@ def __repr__(self):
6767
else:
6868
rigid_txt = f"non-rigid - {nbins} spatial bins"
6969

70-
interval_s = self.temporal_bins_s[0][1] - self.temporal_bins_s[0][0]
70+
if self.temporal_bins_s[0].size > 1:
71+
interval_s = self.temporal_bins_s[0][1] - self.temporal_bins_s[0][0]
72+
else:
73+
# If there is only one temporal bin (entire session), assume the bin
74+
# left edge is zero, and take twice it for the bin size.
75+
interval_s = self.temporal_bins_s[0][0] * 2
76+
7177
txt = f"Motion {rigid_txt} - interval {interval_s}s - {self.num_segments} segments"
7278
return txt
7379

@@ -149,6 +155,12 @@ def get_displacement_at_time_and_depth(self, times_s, locations_um, segment_inde
149155
# reshape to grid domain shape if necessary
150156
displacement = displacement.reshape(out_shape)
151157

158+
# TODO: hacky
159+
if self.temporal_bins_s[segment_index].size == 1 and self.spatial_bins_um.size == 1:
160+
assert np.all(np.isnan(displacement))
161+
assert self.displacement[segment_index].size == 1
162+
displacement[:] = self.displacement[segment_index]
163+
152164
return displacement
153165

154166
def to_dict(self):

src/spikeinterface/sortingcomponents/motion/motion_interpolation.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -419,14 +419,10 @@ def __init__(
419419
interpolation_time_bin_centers_s = motion.temporal_bins_s
420420
interpolation_time_bin_edges_s = motion.temporal_bin_edges_s
421421
else:
422-
interpolation_time_bin_centers_s, interpolation_time_bin_edges_s = (
423-
ensure_time_bins( # TODO: something is very wrong here with the typing
424-
interpolation_time_bin_centers_s, interpolation_time_bin_edges_s
425-
)
422+
interpolation_time_bin_centers_s, interpolation_time_bin_edges_s = ensure_time_bins(
423+
interpolation_time_bin_centers_s, interpolation_time_bin_edges_s
426424
)
427425

428-
assert len(recording._recording_segments) == 1, "multi segment not supported" # ??
429-
430426
for segment_index, parent_segment in enumerate(recording._recording_segments):
431427
# finish the per-segment part of the time bin logic
432428
if interpolation_time_bin_centers_s is None:
@@ -440,8 +436,8 @@ def __init__(
440436
)
441437
assert segment_interpolation_time_bin_edges_s.shape == (segment_interpolation_time_bins_s.shape[0] + 1,)
442438
else:
443-
segment_interpolation_time_bins_s = interpolation_time_bin_centers_s # [segment_index]
444-
segment_interpolation_time_bin_edges_s = interpolation_time_bin_edges_s # [segment_index]
439+
segment_interpolation_time_bins_s = interpolation_time_bin_centers_s[segment_index]
440+
segment_interpolation_time_bin_edges_s = interpolation_time_bin_edges_s[segment_index]
445441

446442
rec_segment = InterpolateMotionRecordingSegment(
447443
parent_segment,

src/spikeinterface/sortingcomponents/motion/motion_utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,6 @@ def get_spatial_windows(
8686
window_centers = np.arange(num_windows + 1) * win_step_um + min_ + border
8787
windows = []
8888

89-
print("CENTERS: ", window_centers.size)
90-
9189
for win_center in window_centers:
9290
if win_shape == "gaussian":
9391
win = np.exp(-((spatial_bin_centers - win_center) ** 2) / (2 * win_scale_um**2))
@@ -253,7 +251,7 @@ def make_2d_motion_histogram(
253251
arr[:, 1] = peak_locations[direction]
254252

255253
if weight_with_amplitude:
256-
weights = np.abs(peaks["amplitude"]) * 10
254+
weights = np.abs(peaks["amplitude"])
257255
else:
258256
weights = None
259257

0 commit comments

Comments
 (0)