Skip to content

Commit 015f21c

Browse files
authored
Add support for iterate_over_all for the CombinedDataset (#122)
1 parent bc0366d commit 015f21c

File tree

3 files changed

+140
-35
lines changed

3 files changed

+140
-35
lines changed

src/litdata/streaming/combined.py

Lines changed: 71 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
# limitations under the License.
1313

1414
import random
15+
from copy import deepcopy
1516
from typing import Any, Dict, Iterator, List, Optional, Sequence
1617

18+
import numpy as np
1719
from torch.utils.data import IterableDataset
1820

1921
from litdata.streaming.dataset import StreamingDataset
@@ -36,15 +38,38 @@ class CombinedStreamingDataset(IterableDataset):
3638
"""
3739

3840
def __init__(
39-
self, datasets: List[StreamingDataset], seed: int = 42, weights: Optional[Sequence[float]] = None
41+
self,
42+
datasets: List[StreamingDataset],
43+
seed: int = 42,
44+
weights: Optional[Sequence[float]] = None,
45+
iterate_over_all: bool = True,
4046
) -> None:
47+
""" "
48+
Arguments:
49+
datasets: The list of the StreamingDataset to use.
50+
seed: The random seed to initialize the sampler
51+
weights: The sampling ratio for the datasets
52+
iterate_over_all: When iterate_over_all is True, the combined dataset iterates over all the datasets.
53+
Otherwise, it stops as soon as one raises a StopIteration.
54+
"""
55+
4156
self._check_datasets(datasets)
4257

4358
self._seed = seed
4459
self._datasets = datasets
4560
self._weights = weights
61+
self._iterate_over_all = iterate_over_all
62+
4663
num_datasets = len(datasets)
4764

65+
if iterate_over_all and weights:
66+
raise ValueError(
67+
"When `iterate_over_all` is set to True, the weights argument shouldn't be provided.",
68+
" Instead, it will be computed from the inverse of the dataset length.",
69+
)
70+
71+
self._iterate_over_all = iterate_over_all
72+
4873
if weights is None:
4974
# Inversely weighted based on length
5075
self._weights = [1 / float(num_datasets)] * num_datasets
@@ -56,6 +81,15 @@ def __init__(
5681
self._num_samples_yielded: Optional[List[int]] = None
5782
self._current_epoch = 0
5883

84+
def __len__(self) -> Optional[int]:
85+
if self._iterate_over_all:
86+
return self._get_total_length()
87+
return None
88+
89+
# total length of the datasets
90+
def _get_total_length(self) -> int:
91+
return sum(len(d) for d in self._datasets)
92+
5993
def set_epoch(self, current_epoch: int) -> None:
6094
"""Set the current epoch to the datasets on epoch starts.
6195
@@ -95,6 +129,7 @@ def __iter__(self) -> Iterator[Any]:
95129
self._weights,
96130
self._use_streaming_dataloader,
97131
num_samples_yielded,
132+
self._iterate_over_all,
98133
)
99134
return self._iterator
100135

@@ -132,31 +167,61 @@ def __init__(
132167
seed: int,
133168
weights: Sequence[float],
134169
use_streaming_dataloader: bool,
135-
num_samples_yielded: Optional[Any] = None,
170+
num_samples_yielded: Any,
171+
iterate_over_all: bool = False,
136172
) -> None:
137173
self._datasets = datasets
138174
self._dataset_iters = [iter(dataset) for dataset in datasets]
139175
self._dataset_indexes = list(range(len(datasets)))
140-
self._num_samples_yielded = [0 for _ in range(len(datasets))]
141-
self._weights = weights
176+
self._num_samples_yielded = num_samples_yielded or [0 for _ in range(len(datasets))]
177+
self._original_weights = deepcopy(weights)
178+
self._weights = deepcopy(weights)
142179
self._rng = random.Random(seed)
180+
self._iterate_over_all = iterate_over_all
181+
self._is_done = False
143182

144183
if num_samples_yielded is not None:
145184
self._num_samples_yielded = num_samples_yielded
146185
for _ in range(sum(num_samples_yielded)):
147186
self._rng.choices(self._dataset_indexes, weights=self._weights, k=1)
148187

149188
self._use_streaming_dataloader = use_streaming_dataloader
189+
self._is_done = False
150190

151191
def __next__(self) -> Any:
192+
if self._iterate_over_all:
193+
while True:
194+
try:
195+
if len(self._dataset_indexes) > 1:
196+
dataset_index = self._get_dataset_index()
197+
elif len(self._dataset_indexes) == 1:
198+
dataset_index = self._dataset_indexes[0]
199+
return self._get_sample(dataset_index)
200+
except StopIteration as e:
201+
if len(self._dataset_indexes) == 1:
202+
self._dataset_indexes = list(range(len(self._datasets)))
203+
self._weights = deepcopy(self._original_weights)
204+
raise e
205+
206+
self._dataset_indexes.pop(dataset_index)
207+
self._weights.pop(dataset_index)
208+
self._weights /= np.sum(self._weights)
209+
210+
# stop on the first iteration
211+
return self._get_sample(self._get_dataset_index())
212+
213+
def _get_dataset_index(self) -> int:
152214
# randomly select a dataset index
153215
(dataset_index,) = self._rng.choices(self._dataset_indexes, weights=self._weights, k=1)
216+
return dataset_index
217+
218+
def _get_sample(self, dataset_index: int) -> Any:
219+
# get the sample
220+
sample = next(self._dataset_iters[dataset_index])
154221

155222
# keep track the sample was fetched
156223
self._num_samples_yielded[dataset_index] += 1
157224

158-
sample = next(self._dataset_iters[dataset_index])
159-
160225
# return a new sample
161226
if self._use_streaming_dataloader:
162227
return {

tests/streaming/test_combined.py

Lines changed: 59 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,25 @@ def _check_datasets(self, datasets) -> None:
1818

1919

2020
def test_combined_dataset_num_samples_yield():
21-
dataset = TestCombinedStreamingDataset([range(10), range(0, -10, -1)], 42, weights=(0.5, 0.5))
21+
dataset = TestCombinedStreamingDataset(
22+
[range(10), range(0, -10, -1)], 42, weights=(0.5, 0.5), iterate_over_all=False
23+
)
2224
dataset_iter = iter(dataset)
2325

2426
data = list(dataset_iter)
2527
assert data == [0, 0, 1, 2, -1, -2, -3, 3, 4, 5, 6, -4, 7, 8, -5, -6, 9, -7, -8]
2628

27-
dataset = TestCombinedStreamingDataset([range(10), range(0, -10, -1)], 37, weights=(0.5, 0.5))
29+
dataset = TestCombinedStreamingDataset(
30+
[range(10), range(0, -10, -1)], 37, weights=(0.5, 0.5), iterate_over_all=False
31+
)
2832
dataset_iter = iter(dataset)
2933

3034
data = list(dataset_iter)
3135
assert data == [0, 0, -1, -2, -3, -4, -5, 1, -6, 2, -7, -8, 3, 4, -9, 5]
3236

33-
dataset = TestCombinedStreamingDataset([range(10), range(0, -10, -1)], 23, weights=(0.5, 0.5))
37+
dataset = TestCombinedStreamingDataset(
38+
[range(10), range(0, -10, -1)], 23, weights=(0.5, 0.5), iterate_over_all=False
39+
)
3440
dataset_iter = iter(dataset)
3541

3642
data = [next(dataset_iter) for _ in range(5)]
@@ -40,6 +46,13 @@ def test_combined_dataset_num_samples_yield():
4046
assert dataset._iterator._num_samples_yielded == [2, 4]
4147

4248

49+
def test_combined_dataset_num_samples_yield_iterate_over_all():
50+
dataset = TestCombinedStreamingDataset([range(10), range(0, -10, -1)], 42, iterate_over_all=True)
51+
assert len(dataset) == 20
52+
samples = list(dataset)
53+
assert len(samples) == 20
54+
55+
4356
class TestStatefulDataset:
4457
def __init__(self, size, step):
4558
self.size = size
@@ -69,14 +82,20 @@ def load_state_dict(self, state_dict):
6982

7083
def test_combined_dataset_state_dict():
7184
dataset = TestCombinedStreamingDataset(
72-
[TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5)
85+
[TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)],
86+
42,
87+
weights=(0.5, 0.5),
88+
iterate_over_all=False,
7389
)
7490
assert dataset.state_dict(0, 1) == {}
7591
dataset_iter = iter(dataset)
7692
assert dataset.state_dict(0, 1) == {"0": {"counter": 0}, "1": {"counter": 0}}
7793

7894
dataset2 = TestCombinedStreamingDataset(
79-
[TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5)
95+
[TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)],
96+
42,
97+
weights=(0.5, 0.5),
98+
iterate_over_all=False,
8099
)
81100
assert dataset2.state_dict(0, 1) == {}
82101

@@ -111,7 +130,10 @@ def test_combined_dataset_state_dict():
111130
]
112131

113132
dataset2 = TestCombinedStreamingDataset(
114-
[TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5)
133+
[TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)],
134+
42,
135+
weights=(0.5, 0.5),
136+
iterate_over_all=False,
115137
)
116138
assert dataset2.state_dict(0, 1) == {}
117139
dataset2_iter = iter(dataset2)
@@ -136,7 +158,7 @@ def test_combined_dataset_state_dict():
136158
],
137159
)
138160
def test_combined_dataset_normalizes_weights(weights, expected):
139-
combined_dataset = TestCombinedStreamingDataset([[1], [2, 3]], weights=weights, seed=1)
161+
combined_dataset = TestCombinedStreamingDataset([[1], [2, 3]], weights=weights, iterate_over_all=False, seed=1)
140162
assert combined_dataset._weights == expected
141163

142164

@@ -159,21 +181,27 @@ def set_epoch(self, current_epoch):
159181
def test_combined_dataset():
160182
dataset1 = SimpleDataset(0, 10)
161183
dataset2 = SimpleDataset(10, 20)
162-
dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[1.0, 0.0], seed=12345)
184+
dataset = TestCombinedStreamingDataset(
185+
datasets=[dataset1, dataset2], weights=[1.0, 0.0], iterate_over_all=False, seed=12345
186+
)
163187

164188
res = list(dataset)
165189
assert res == list(range(0, 10))
166190

167191
dataset1 = SimpleDataset(0, 10)
168192
dataset2 = SimpleDataset(10, 20)
169-
dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.0, 1.0], seed=12345)
193+
dataset = TestCombinedStreamingDataset(
194+
datasets=[dataset1, dataset2], weights=[0.0, 1.0], iterate_over_all=False, seed=12345
195+
)
170196

171197
res = list(dataset)
172198
assert res == list(range(10, 20))
173199

174200
dataset1 = SimpleDataset(0, 10)
175201
dataset2 = SimpleDataset(10, 20)
176-
dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345)
202+
dataset = TestCombinedStreamingDataset(
203+
datasets=[dataset1, dataset2], weights=[0.5, 0.5], iterate_over_all=False, seed=12345
204+
)
177205

178206
res = list(dataset)
179207
assert 9 in res or 19 in res
@@ -183,7 +211,9 @@ def test_combined_dataset():
183211

184212
dataset1 = SimpleDataset(0, 10)
185213
dataset2 = SimpleDataset(10, 20)
186-
dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345)
214+
dataset = TestCombinedStreamingDataset(
215+
datasets=[dataset1, dataset2], weights=[0.5, 0.5], iterate_over_all=False, seed=12345
216+
)
187217
dataloader = DataLoader(dataset, batch_size=2, num_workers=1)
188218
dataloader_iter = iter(dataloader)
189219
assert torch.equal(next(dataloader_iter), torch.Tensor([0, 1]))
@@ -193,7 +223,9 @@ def test_combined_dataset():
193223
def test_combined_dataset_with_dataloader_and_one_worker(batch_size):
194224
dataset1 = SimpleDataset(0, 10)
195225
dataset2 = SimpleDataset(10, 20)
196-
dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345)
226+
dataset = TestCombinedStreamingDataset(
227+
datasets=[dataset1, dataset2], weights=[0.5, 0.5], iterate_over_all=False, seed=12345
228+
)
197229
dataloader = StreamingDataLoader(dataset, num_workers=1, batch_size=batch_size, prefetch_factor=1)
198230
dataloader_iter = iter(dataloader)
199231

@@ -260,7 +292,9 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
260292

261293
dataset1 = StreamingDataset(input_dir=Dir(cache_dir_1, data_dir_1), shuffle=True)
262294
dataset2 = StreamingDataset(input_dir=Dir(cache_dir_2, data_dir_2), shuffle=True)
263-
dataset = CombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345)
295+
dataset = CombinedStreamingDataset(
296+
datasets=[dataset1, dataset2], weights=[0.5, 0.5], iterate_over_all=False, seed=12345
297+
)
264298
dataloader = StreamingDataLoader(dataset, num_workers=3, batch_size=2)
265299

266300
assert dataset1.current_epoch == 1
@@ -454,7 +488,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
454488
{
455489
"dataset": {
456490
"0": {
457-
"num_samples_yielded": 9,
491+
"num_samples_yielded": 8,
458492
"num_workers": 3,
459493
"batch_size": 2,
460494
"current_epoch": 1,
@@ -482,12 +516,12 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
482516
},
483517
"current_epoch": 0,
484518
"latest_worker_idx": 2,
485-
"num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [3, 1]},
519+
"num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [2, 1]},
486520
},
487521
{
488522
"dataset": {
489523
"0": {
490-
"num_samples_yielded": 11,
524+
"num_samples_yielded": 9,
491525
"num_workers": 3,
492526
"batch_size": 2,
493527
"current_epoch": 1,
@@ -515,12 +549,12 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
515549
},
516550
"current_epoch": 0,
517551
"latest_worker_idx": 0,
518-
"num_samples_yielded": {0: [5, 1], 1: [3, 1], 2: [3, 1]},
552+
"num_samples_yielded": {0: [4, 1], 1: [3, 1], 2: [2, 1]},
519553
},
520554
{
521555
"dataset": {
522556
"0": {
523-
"num_samples_yielded": 13,
557+
"num_samples_yielded": 10,
524558
"num_workers": 3,
525559
"batch_size": 2,
526560
"current_epoch": 1,
@@ -548,7 +582,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
548582
},
549583
"current_epoch": 0,
550584
"latest_worker_idx": 1,
551-
"num_samples_yielded": {0: [5, 1], 1: [5, 1], 2: [3, 1]},
585+
"num_samples_yielded": {0: [4, 1], 1: [4, 1], 2: [2, 1]},
552586
},
553587
]
554588

@@ -721,7 +755,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
721755
{
722756
"dataset": {
723757
"0": {
724-
"num_samples_yielded": 9,
758+
"num_samples_yielded": 8,
725759
"num_workers": 3,
726760
"batch_size": 2,
727761
"current_epoch": 2,
@@ -749,12 +783,12 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
749783
},
750784
"current_epoch": 1,
751785
"latest_worker_idx": 2,
752-
"num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [3, 1]},
786+
"num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [2, 1]},
753787
},
754788
{
755789
"dataset": {
756790
"0": {
757-
"num_samples_yielded": 11,
791+
"num_samples_yielded": 9,
758792
"num_workers": 3,
759793
"batch_size": 2,
760794
"current_epoch": 2,
@@ -782,12 +816,12 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
782816
},
783817
"current_epoch": 1,
784818
"latest_worker_idx": 0,
785-
"num_samples_yielded": {0: [5, 1], 1: [3, 1], 2: [3, 1]},
819+
"num_samples_yielded": {0: [4, 1], 1: [3, 1], 2: [2, 1]},
786820
},
787821
{
788822
"dataset": {
789823
"0": {
790-
"num_samples_yielded": 13,
824+
"num_samples_yielded": 10,
791825
"num_workers": 3,
792826
"batch_size": 2,
793827
"current_epoch": 2,
@@ -815,7 +849,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
815849
},
816850
"current_epoch": 1,
817851
"latest_worker_idx": 1,
818-
"num_samples_yielded": {0: [5, 1], 1: [5, 1], 2: [3, 1]},
852+
"num_samples_yielded": {0: [4, 1], 1: [4, 1], 2: [2, 1]},
819853
},
820854
]
821855

0 commit comments

Comments (0)