diff --git a/.github/workflows/test_with_pytest.yml b/.github/workflows/test_with_pytest.yml index 4febc7a..e537f5b 100644 --- a/.github/workflows/test_with_pytest.yml +++ b/.github/workflows/test_with_pytest.yml @@ -1,6 +1,6 @@ name: Test using pytest -on: [push, pull_request] +on: push jobs: test: diff --git a/imas/ids_struct_array.py b/imas/ids_struct_array.py index b176864..38e8165 100644 --- a/imas/ids_struct_array.py +++ b/imas/ids_struct_array.py @@ -4,6 +4,7 @@ """ import logging +import sys from copy import deepcopy from typing import Optional, Tuple @@ -121,12 +122,25 @@ def _element_structure(self): return struct def __getitem__(self, item): - # value is a list, so the given item should be convertable to integer - # TODO: perhaps we should allow slices as well? - list_idx = int(item) - if self._lazy: - self._load(item) - return self.value[list_idx] + # allow slices + if isinstance(item, slice): + if self._lazy: + start, stop, step = item.start, item.stop, item.step + if stop is None: + stop = sys.maxsize + + for i in range(start or 0, stop, step or 1): + try: + self._load(i) + except IndexError: + break + return self.value[item] + else: + # value is a list, so the given item should be convertable to integer + list_idx = int(item) + if self._lazy: + self._load(item) + return self.value[list_idx] def __setitem__(self, item, value): # value is a list, so the given item should be convertable to integer diff --git a/imas/test/test_ids_struct_array.py b/imas/test/test_ids_struct_array.py index ab128df..8c31f22 100644 --- a/imas/test/test_ids_struct_array.py +++ b/imas/test/test_ids_struct_array.py @@ -87,3 +87,15 @@ def test_struct_array_eq(): assert cp1.profiles_1d != cp2.profiles_1d cp2.profiles_1d[0].time = 1 assert cp1.profiles_1d == cp2.profiles_1d + + +def test_struct_array_slice(): + cp1 = IDSFactory("3.39.0").core_profiles() + cp1.profiles_1d.resize(20) + + assert len(cp1.profiles_1d) == 20 + assert len(cp1.profiles_1d[:]) == 20 + assert len(cp1.profiles_1d[5:10]) == 5 + assert len(cp1.profiles_1d[10:]) == 10 + assert len(cp1.profiles_1d[:5]) == 5 + assert len(cp1.profiles_1d[::2]) == 10