Skip to content

Commit 588ff5f

Browse files
authored
Merge pull request #2784 from h-mayorquin/add_name_to_repr
Add name as an extractor attribute for `__repr__` purposes
2 parents a1958ce + 4e0f587 commit 588ff5f

File tree

6 files changed

+57
-15
lines changed

6 files changed

+57
-15
lines changed

src/spikeinterface/core/base.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class BaseExtractor:
4141
# This replaces the old key_properties
4242
# These are annotations/properties that always need to be
4343
# dumped (for instance locations, groups, is_fileterd, etc.)
44-
_main_annotations = []
44+
_main_annotations = ["name"]
4545
_main_properties = []
4646

4747
# these properties are skipped by default in copy_metadata
@@ -79,6 +79,19 @@ def __init__(self, main_ids: Sequence) -> None:
7979
# preferred context for multiprocessing
8080
self._preferred_mp_context = None
8181

82+
@property
83+
def name(self):
84+
name = self._annotations.get("name", None)
85+
return name if name is not None else self.__class__.__name__
86+
87+
@name.setter
88+
def name(self, value):
89+
if value is not None:
90+
self.annotate(name=value)
91+
else:
92+
# we remove the annotation if it exists
93+
_ = self._annotations.pop("name", None)
94+
8295
def get_num_segments(self) -> int:
8396
# This is implemented in BaseRecording or BaseSorting
8497
raise NotImplementedError

src/spikeinterface/core/baserecording.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class BaseRecording(BaseRecordingSnippets):
2323
Internally handle list of RecordingSegment
2424
"""
2525

26-
_main_annotations = ["is_filtered"]
26+
_main_annotations = BaseRecordingSnippets._main_annotations + ["is_filtered"]
2727
_main_properties = ["group", "location", "gain_to_uV", "offset_to_uV"]
2828
_main_features = [] # recording do not handle features
2929

@@ -45,9 +45,6 @@ def __init__(self, sampling_frequency: float, channel_ids: list, dtype):
4545
self.annotate(is_filtered=False)
4646

4747
def __repr__(self):
48-
49-
class_name = self.__class__.__name__
50-
name_to_display = class_name
5148
num_segments = self.get_num_segments()
5249

5350
txt = self._repr_header()
@@ -57,7 +54,7 @@ def __repr__(self):
5754
split_index = txt.rfind("-", 0, 100) # Find the last "-" before character 100
5855
if split_index != -1:
5956
first_line = txt[:split_index]
60-
recording_string_space = len(name_to_display) + 2 # Length of name_to_display plus ": "
57+
recording_string_space = len(self.name) + 2 # Length of self.name plus ": "
6158
white_space_to_align_with_first_line = " " * recording_string_space
6259
second_line = white_space_to_align_with_first_line + txt[split_index + 1 :].lstrip()
6360
txt = first_line + "\n" + second_line
@@ -97,21 +94,21 @@ def list_to_string(lst, max_size=6):
9794
return txt
9895

9996
def _repr_header(self):
100-
class_name = self.__class__.__name__
101-
name_to_display = class_name
10297
num_segments = self.get_num_segments()
10398
num_channels = self.get_num_channels()
104-
sf_khz = self.get_sampling_frequency() / 1000.0
99+
sf_hz = self.get_sampling_frequency()
100+
sf_khz = sf_hz / 1000
105101
dtype = self.get_dtype()
106102

107103
total_samples = self.get_total_samples()
108104
total_duration = self.get_total_duration()
109105
total_memory_size = self.get_total_memory_size()
106+
sampling_frequency_repr = f"{sf_khz:0.1f}kHz" if sf_hz > 10_000.0 else f"{sf_hz:0.1f}Hz"
110107

111108
txt = (
112-
f"{name_to_display}: "
109+
f"{self.name}: "
113110
f"{num_channels} channels - "
114-
f"{sf_khz:0.1f}kHz - "
111+
f"{sampling_frequency_repr} - "
115112
f"{num_segments} segments - "
116113
f"{total_samples:,} samples - "
117114
f"{convert_seconds_to_str(total_duration)} - "

src/spikeinterface/core/basesnippets.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ class BaseSnippets(BaseRecordingSnippets):
1414
Abstract class representing several multichannel snippets.
1515
"""
1616

17-
_main_annotations = []
1817
_main_properties = ["group", "location", "gain_to_uV", "offset_to_uV"]
1918
_main_features = []
2019

src/spikeinterface/core/basesorting.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,10 @@ def __init__(self, sampling_frequency: float, unit_ids: List):
3030
self._cached_spike_trains = {}
3131

3232
def __repr__(self):
33-
clsname = self.__class__.__name__
3433
nseg = self.get_num_segments()
3534
nunits = self.get_num_units()
3635
sf_khz = self.get_sampling_frequency() / 1000.0
37-
txt = f"{clsname}: {nunits} units - {nseg} segments - {sf_khz:0.1f}kHz"
36+
txt = f"{self.name}: {nunits} units - {nseg} segments - {sf_khz:0.1f}kHz"
3837
if "file_path" in self._kwargs:
3938
txt += "\n file_path: {}".format(self._kwargs["file_path"])
4039
return txt

src/spikeinterface/core/generate.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ def generate_recording(
8080
probe.set_device_channel_indices(np.arange(num_channels))
8181
recording.set_probe(probe, in_place=True)
8282

83+
recording.name = "SyntheticRecording"
84+
8385
return recording
8486

8587

@@ -2122,4 +2124,7 @@ def generate_ground_truth_recording(
21222124
recording.set_channel_gains(1.0)
21232125
recording.set_channel_offsets(0.0)
21242126

2127+
recording.name = "GroundTruthRecording"
2128+
sorting.name = "GroundTruthSorting"
2129+
21252130
return recording, sorting

src/spikeinterface/core/tests/test_base.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
"""
55

66
from typing import Sequence
7+
import numpy as np
78
from spikeinterface.core.base import BaseExtractor
8-
from spikeinterface.core import generate_recording, concatenate_recordings
9+
from spikeinterface.core import generate_recording, generate_ground_truth_recording, concatenate_recordings
910

1011

1112
class DummyDictExtractor(BaseExtractor):
@@ -65,6 +66,34 @@ def test_check_if_serializable():
6566
assert not extractor.check_serializability("json")
6667

6768

69+
def test_name_and_repr():
70+
test_recording, test_sorting = generate_ground_truth_recording(seed=0, durations=[2])
71+
assert test_recording.name == "GroundTruthRecording"
72+
assert test_sorting.name == "GroundTruthSorting"
73+
74+
# set a different name
75+
test_recording.name = "MyRecording"
76+
assert test_recording.name == "MyRecording"
77+
78+
# to/from dict
79+
test_recording_dict = test_recording.to_dict()
80+
test_recording2 = BaseExtractor.from_dict(test_recording_dict)
81+
assert test_recording2.name == "MyRecording"
82+
83+
# repr
84+
rec_str = str(test_recording2)
85+
assert "MyRecording" in rec_str
86+
test_recording2.name = None
87+
assert "MyRecording" not in str(test_recording2)
88+
assert test_recording2.__class__.__name__ in str(test_recording2)
89+
# above 10khz, sampling frequency is printed in kHz
90+
assert f"kHz" in rec_str
91+
# below 10khz sampling frequency is printed in Hz
92+
test_rec_low_fs = generate_recording(seed=0, durations=[2], sampling_frequency=5000)
93+
rec_str = str(test_rec_low_fs)
94+
assert "Hz" in rec_str
95+
96+
6897
if __name__ == "__main__":
6998
test_check_if_memory_serializable()
7099
test_check_if_serializable()

0 commit comments

Comments
 (0)