Skip to content

Commit e98b9eb

Browse files
authored
Ensure root and group objects required are written (#306)
1 parent e1fe2ae commit e98b9eb

File tree

2 files changed

+125
-4
lines changed

2 files changed

+125
-4
lines changed

nptdms/test/writer/test_acceptance_tests.py renamed to nptdms/test/writer/test_writer.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,3 +403,89 @@ def test_specifying_invalid_version():
403403
error_message = str(exception.value)
404404

405405
assert "4712,4713" in error_message
406+
407+
408+
def test_root_object_added():
409+
""" When not explicitly included, a root object should be added
410+
"""
411+
group = GroupObject("group")
412+
channel = ChannelObject("group", "a", np.linspace(0.0, 1.0, 10))
413+
414+
output_file = BytesIO()
415+
with TdmsWriter(output_file) as tdms_writer:
416+
tdms_writer.write_segment([group, channel])
417+
tdms_writer.write_segment([group, channel])
418+
419+
output_file.seek(0)
420+
421+
tdms_file = TdmsFile(output_file)
422+
first_segment_objects = tdms_file._reader._segments[0].ordered_objects
423+
second_segment_objects = tdms_file._reader._segments[1].ordered_objects
424+
425+
assert first_segment_objects[0].path == "/"
426+
assert not any(obj.path == "/" for obj in second_segment_objects)
427+
428+
429+
def test_group_object_added():
430+
""" When not explicitly included, a group object should be added
431+
"""
432+
root = RootObject()
433+
channel = ChannelObject("group", "a", np.linspace(0.0, 1.0, 10))
434+
435+
output_file = BytesIO()
436+
with TdmsWriter(output_file) as tdms_writer:
437+
tdms_writer.write_segment([root, channel])
438+
tdms_writer.write_segment([root, channel])
439+
440+
output_file.seek(0)
441+
442+
tdms_file = TdmsFile(output_file)
443+
first_segment_objects = tdms_file._reader._segments[0].ordered_objects
444+
second_segment_objects = tdms_file._reader._segments[1].ordered_objects
445+
446+
assert first_segment_objects[1].path == "/'group'"
447+
assert not any(obj.path == "/'group'" for obj in second_segment_objects)
448+
449+
450+
def test_group_not_duplicated():
451+
root = RootObject()
452+
group = GroupObject("group")
453+
channel = ChannelObject("group", "a", np.linspace(0.0, 1.0, 10))
454+
455+
output_file = BytesIO()
456+
with TdmsWriter(output_file) as tdms_writer:
457+
tdms_writer.write_segment([root, group, channel])
458+
tdms_writer.write_segment([channel])
459+
460+
output_file.seek(0)
461+
462+
tdms_file = TdmsFile(output_file)
463+
first_segment_objects = tdms_file._reader._segments[0].ordered_objects
464+
second_segment_objects = tdms_file._reader._segments[1].ordered_objects
465+
466+
assert len(first_segment_objects) == 3
467+
assert len(second_segment_objects) == 1
468+
469+
470+
def test_root_and_groups_ordered_first():
471+
"""
472+
The root and group objects should always come first
473+
"""
474+
root = RootObject()
475+
group = GroupObject("group")
476+
channel_0 = ChannelObject("group", "b", np.linspace(0.0, 1.0, 10))
477+
channel_1 = ChannelObject("group", "a", np.linspace(0.0, 1.0, 10))
478+
479+
output_file = BytesIO()
480+
with TdmsWriter(output_file) as tdms_writer:
481+
tdms_writer.write_segment([channel_0, group, channel_1, root])
482+
483+
output_file.seek(0)
484+
485+
tdms_file = TdmsFile(output_file)
486+
first_segment_objects = tdms_file._reader._segments[0].ordered_objects
487+
488+
assert first_segment_objects[0].path == "/"
489+
assert first_segment_objects[1].path == "/'group'"
490+
assert first_segment_objects[2].path == "/'group'/'b'"
491+
assert first_segment_objects[3].path == "/'group'/'a'"

nptdms/writer.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,11 @@ def __init__(self, file, mode='w', version=4712, index_file=False):
6464
It's important that if you are appending segments to an
6565
existing TDMS file, this matches the existing file version (this can be queried with the
6666
:py:attr:`~nptdms.TdmsFile.tdms_version` property).
67-
:param index_file: Whether or not to write a index file besides the data file. Index files
67+
:param index_file: Whether to write an index file besides the data file. Index files
6868
can be used to accelerate reading speeds for faster channel extraction and data positions inside
69-
the data files. If ``file```variable is a path ``index_file`` can be ``True`` to store a ``.tdms_index``
70-
file at the same folder location or ``False`` to only write the data ``.tdms`` file. If ``file`` variable
71-
is a readable object ``index_file`` can either be a readable object to write into or ``False`` to omit.
69+
the data files. If ``file```variable is a path, ``index_file`` can be ``True`` to store a ``.tdms_index``
70+
file at the same folder location or ``False`` to only write the data ``.tdms`` file. If ``file``
71+
is a readable object, ``index_file`` can either be a readable object to write into or ``False`` to omit.
7272
"""
7373
valid_versions = (4712, 4713)
7474
if version not in valid_versions:
@@ -79,6 +79,8 @@ def __init__(self, file, mode='w', version=4712, index_file=False):
7979
self._index_file_path = None
8080
self._file_mode = mode
8181
self._tdms_version = version
82+
self._root_written = False
83+
self._groups_written = set()
8284

8385
if hasattr(file, "read"):
8486
# Is a file
@@ -123,13 +125,37 @@ def write_segment(self, objects):
123125
124126
:param objects: A list of TdmsObject instances to write
125127
"""
128+
path_object_pairs = [(ObjectPath.from_string(o.path), o) for o in objects]
129+
130+
# Make sure a root object is included if this is the first segment,
131+
# and any groups used by channels have associated group objects
132+
add_root = (not self._root_written) and (not any(p[0].is_root for p in path_object_pairs))
133+
groups_included = set(p[0].group for p in path_object_pairs if p[0].is_group)
134+
groups_required = set(p[0].group for p in path_object_pairs if p[0].is_channel)
135+
groups_to_add = sorted(groups_required - groups_included - self._groups_written)
136+
137+
if add_root:
138+
path_object_pairs.append((ObjectPath(), RootObject()))
139+
if groups_to_add:
140+
path_object_pairs.extend((ObjectPath(g), GroupObject(g)) for g in groups_to_add)
141+
142+
# Ensure objects are ordered with root first, then groups, in case any readers depend
143+
# on parent objects being defined before their children.
144+
# Channel ordering will be unchanged as sorts are stable.
145+
path_object_pairs.sort(key=lambda p: _path_ordering_key(p[0]))
146+
147+
objects = [p[1] for p in path_object_pairs]
126148
segment = TdmsSegment(objects, version=self._tdms_version)
127149
segment.write(self._file)
128150

129151
if self._index_file is not None:
130152
segment = TdmsSegment(objects, is_index_file=True, version=self._tdms_version)
131153
segment.write(self._index_file)
132154

155+
self._root_written = True
156+
self._groups_written.update(groups_included)
157+
self._groups_written.update(groups_to_add)
158+
133159
def __enter__(self):
134160
self.open()
135161
return self
@@ -450,3 +476,12 @@ def _infer_dtype(data):
450476
else:
451477
return np.dtype('int8')
452478
return None
479+
480+
481+
def _path_ordering_key(path):
482+
if path.is_root:
483+
return 0
484+
if path.is_group:
485+
return 1
486+
if path.is_channel:
487+
return 2

0 commit comments

Comments
 (0)