Skip to content

Commit e62cc32

Browse files
committed
WIP for merging AP/LFP channels
1 parent 4d93f4d commit e62cc32

File tree

2 files changed

+38
-35
lines changed

2 files changed

+38
-35
lines changed

spikeinterface/preprocessing/merge_ap_lfp.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,29 +58,33 @@ def __init__(self, ap_recording_segment: BaseRecordingSegment, lfp_recording_seg
5858
self.lfp_filter = lfp_filter
5959
self.margin = margin
6060

61+
self.AP_TO_LFP = int(round(ap_recording_segment.sampling_frequency / lfp_recording_segment.sampling_frequency))
62+
6163

6264
def get_num_samples(self) -> int:
63-
return self.ap_recording.get_num_samples()
65+
# Trunk the recording to have a number of samples that is a multiple of 'AP_TO_LFP'.
66+
return self.ap_recording.get_num_samples() - (self.ap_recording.get_num_samples() % self.AP_TO_LFP)
6467

6568

6669
def get_traces(self, start_frame: Union[int, None] = None, end_frame: Union[int, None] = None,
6770
channel_indices: Union[List, None] = None) -> np.ndarray:
68-
AP_TO_LFP = int(round(self.ap_recording.sampling_frequency / self.lfp_recording.sampling_frequency))
6971
if start_frame is None:
7072
start_frame = 0
7173
if end_frame is None:
7274
end_frame = self.get_num_samples()
7375

74-
assert end_frame % AP_TO_LFP == 0 # Fix this.
75-
76-
ap_traces, left_margin, right_margin = get_chunk_with_margin(self.ap_recording, start_frame, end_frame, channel_indices, self.margin + AP_TO_LFP)
76+
ap_traces, left_margin, right_margin = get_chunk_with_margin(self.ap_recording, start_frame, end_frame, channel_indices, self.margin + self.AP_TO_LFP)
7777

78-
left_leftover = (AP_TO_LFP - (start_frame - left_margin) % AP_TO_LFP) % AP_TO_LFP
78+
left_leftover = (self.AP_TO_LFP - (start_frame - left_margin) % self.AP_TO_LFP) % self.AP_TO_LFP
7979
left_margin -= left_leftover
80+
right_leftover = (end_frame + right_margin) % self.AP_TO_LFP
81+
right_margin -= right_leftover
8082

83+
if right_leftover > 0:
84+
ap_traces = ap_traces[:right_leftover]
8185
ap_traces = ap_traces[left_leftover:]
8286

83-
lfp_traces = self.lfp_recording.get_traces((start_frame - left_margin) // AP_TO_LFP, (end_frame + right_margin) // AP_TO_LFP, channel_indices)
87+
lfp_traces = self.lfp_recording.get_traces((start_frame - left_margin) // self.AP_TO_LFP, (end_frame + right_margin) // self.AP_TO_LFP, channel_indices)
8488

8589
ap_fourier = np.fft.rfft(ap_traces, axis=0)
8690
lfp_fourier = np.fft.rfft(lfp_traces, axis=0)
@@ -96,12 +100,11 @@ def get_traces(self, start_frame: Union[int, None] = None, end_frame: Union[int,
96100
reconstructed_lfp_fourier = lfp_fourier / lfp_filter[:, None]
97101

98102
# Compute aliasing of high frequencies on LFP channels
99-
# TODO: There may be a faster way than computing the Fourier transform
100103
lfp_nyquist = self.lfp_recording.sampling_frequency / 2
101104
fourier_aliased = reconstructed_ap_fourier.copy()
102105
fourier_aliased[ap_freq <= lfp_nyquist] = 0.0
103106
fourier_aliased *= self.lfp_filter(ap_freq)[:, None]
104-
traces_aliased = np.fft.irfft(fourier_aliased, axis=0)[::AP_TO_LFP]
107+
traces_aliased = np.fft.irfft(fourier_aliased, axis=0)[::self.AP_TO_LFP]
105108
fourier_aliased = np.fft.rfft(traces_aliased, axis=0) / lfp_filter[:, None]
106109
fourier_aliased = fourier_aliased[:np.searchsorted(ap_freq, lfp_nyquist, side="right")]
107110
lfp_aa_fourier = reconstructed_lfp_fourier - fourier_aliased
@@ -116,7 +119,7 @@ def get_traces(self, start_frame: Union[int, None] = None, end_frame: Union[int,
116119
fourier_reconstructed = np.empty(reconstructed_ap_fourier.shape, dtype=np.complex128)
117120
idx = np.searchsorted(ap_freq, lfp_nyquist, side="right")
118121
fourier_reconstructed[idx:] = reconstructed_ap_fourier[idx:]
119-
fourier_reconstructed[:idx] = AP_TO_LFP * lfp_aa_fourier * ratio[:idx] + reconstructed_ap_fourier[:idx] * (1 - ratio[:idx])
122+
fourier_reconstructed[:idx] = self.AP_TO_LFP * lfp_aa_fourier * ratio[:idx] + reconstructed_ap_fourier[:idx] * (1 - ratio[:idx])
120123

121124
# To get back to the 0.5 - 10,000 Hz original filter
122125
# filter_reconstructed = generate_RC_filter(ap_freq, [0.5, 10000])[:, None]

spikeinterface/preprocessing/tests/test_merge_ap_lfp.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -67,33 +67,33 @@ def test_MergeApLfpRecording():
6767

6868
assert np.all(np.abs(merged_traces - chunked_traces)[500:-500] < 0.05)
6969

70-
import plotly.graph_objects as go
71-
fig = go.Figure()
72-
73-
fig.add_trace(go.Scatter(
74-
x=np.arange(sf*T),
75-
y=merged_traces[:, 0],
76-
mode="lines",
77-
name="Non-chunked"
78-
))
79-
fig.add_trace(go.Scatter(
80-
x=np.arange(sf*T),
81-
y=chunked_traces[:, 0],
82-
mode="lines",
83-
name="Chunked"
84-
))
85-
fig.add_trace(go.Scatter(
86-
x=np.arange(sf*T),
87-
y=merged_traces[:, 0] - chunked_traces[:, 0],
88-
mode="lines",
89-
name="Difference"
90-
))
91-
92-
for i in range(1, T):
93-
fig.add_vline(x=i*sf, line_dash="dash", line_color="rgba(0, 0, 0, 0.3)")
70+
# import plotly.graph_objects as go
71+
# fig = go.Figure()
72+
73+
# fig.add_trace(go.Scatter(
74+
# x=np.arange(sf*T),
75+
# y=merged_traces[:, 0],
76+
# mode="lines",
77+
# name="Non-chunked"
78+
# ))
79+
# fig.add_trace(go.Scatter(
80+
# x=np.arange(sf*T),
81+
# y=chunked_traces[:, 0],
82+
# mode="lines",
83+
# name="Chunked"
84+
# ))
85+
# fig.add_trace(go.Scatter(
86+
# x=np.arange(sf*T),
87+
# y=merged_traces[:, 0] - chunked_traces[:, 0],
88+
# mode="lines",
89+
# name="Difference"
90+
# ))
91+
92+
# for i in range(1, T):
93+
# fig.add_vline(x=i*sf, line_dash="dash", line_color="rgba(0, 0, 0, 0.3)")
9494

9595
# fig.update_xaxes(type="log")
96-
fig.show()
96+
# fig.show()
9797

9898

9999
if __name__ == '__main__':

0 commit comments

Comments
 (0)