Skip to content

Commit 4635aa5

Browse files
authored
(fix) CombinedDataset with more than 2 streaming datasets (#164)
1 parent 07f0483 commit 4635aa5

File tree

2 files changed: 43 additions and 12 deletions.

2 files changed

+43
-12
lines changed

src/litdata/streaming/combined.py

Lines changed: 17 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,6 @@
1515
from copy import deepcopy
1616
from typing import Any, Dict, Iterator, List, Optional, Sequence
1717

18-
import numpy as np
1918
from torch.utils.data import IterableDataset
2019

2120
from litdata.streaming.dataset import StreamingDataset
@@ -182,14 +181,14 @@ def __init__(
182181
self,
183182
datasets: List[StreamingDataset],
184183
seed: int,
185-
weights: Sequence[float],
184+
weights: Sequence[Optional[float]],
186185
use_streaming_dataloader: bool,
187186
num_samples_yielded: Any,
188187
iterate_over_all: bool = False,
189188
) -> None:
190189
self._datasets = datasets
191190
self._dataset_iters = [iter(dataset) for dataset in datasets]
192-
self._dataset_indexes = list(range(len(datasets)))
191+
self._dataset_indexes: List[Optional[int]] = list(range(len(datasets)))
193192
self._num_samples_yielded = num_samples_yielded or [0 for _ in range(len(datasets))]
194193
self._original_weights = deepcopy(weights)
195194
self._weights = deepcopy(weights)
@@ -200,7 +199,9 @@ def __init__(
200199
if num_samples_yielded is not None:
201200
self._num_samples_yielded = num_samples_yielded
202201
for _ in range(sum(num_samples_yielded)):
203-
self._rng.choices(self._dataset_indexes, weights=self._weights, k=1)
202+
choice_indexes: List[int] = [index for index in self._dataset_indexes if index is not None]
203+
choice_weights: List[float] = [w for w in self._weights if w is not None]
204+
self._rng.choices(choice_indexes, weights=choice_weights, k=1)
204205

205206
self._use_streaming_dataloader = use_streaming_dataloader
206207
self._is_done = False
@@ -209,27 +210,31 @@ def __next__(self) -> Any:
209210
if self._iterate_over_all:
210211
while True:
211212
try:
212-
if len(self._dataset_indexes) > 1:
213+
indexes_left = [index for index in self._dataset_indexes if index is not None]
214+
if len(indexes_left) > 1:
213215
dataset_index = self._get_dataset_index()
214-
elif len(self._dataset_indexes) == 1:
215-
dataset_index = self._dataset_indexes[0]
216+
elif len(indexes_left) == 1:
217+
dataset_index = indexes_left[0]
216218
return self._get_sample(dataset_index)
217219
except StopIteration as e:
218-
if len(self._dataset_indexes) == 1:
220+
if len(indexes_left) == 1:
219221
self._dataset_indexes = list(range(len(self._datasets)))
220222
self._weights = deepcopy(self._original_weights)
221223
raise e
222224

223-
self._dataset_indexes.pop(dataset_index)
224-
self._weights.pop(dataset_index)
225-
self._weights /= np.sum(self._weights)
225+
self._dataset_indexes[dataset_index] = None
226+
self._weights[dataset_index] = None # type: ignore
227+
new_sum = sum([w for w in self._weights if w is not None])
228+
self._weights = [None if w is None else w / new_sum for w in self._weights]
226229

227230
# stop on the first iteration
228231
return self._get_sample(self._get_dataset_index())
229232

230233
def _get_dataset_index(self) -> int:
231234
# randomly select a dataset index
232-
(dataset_index,) = self._rng.choices(self._dataset_indexes, weights=self._weights, k=1)
235+
indexes = [index for index in self._dataset_indexes if index is not None]
236+
weights = [w for w in self._weights if w is not None]
237+
(dataset_index,) = self._rng.choices(indexes, weights=weights, k=1)
233238
return dataset_index
234239

235240
def _get_sample(self, dataset_index: int) -> Any:

tests/streaming/test_combined.py

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
import sys
33
from unittest.mock import ANY, MagicMock
44

5+
import numpy as np
56
import pytest
67
import torch
78
from litdata.streaming.cache import Cache
@@ -46,6 +47,31 @@ def test_combined_dataset_num_samples_yield():
4647
assert dataset._iterator._num_samples_yielded == [2, 4]
4748

4849

50+
class Range:
51+
def __init__(self, start, end, step=1):
52+
self.values = list(range(start, end, step))
53+
54+
def set_epoch(self, epoch):
55+
self.values = np.random.RandomState(42 + epoch).permutation(self.values).tolist()
56+
57+
def __iter__(self):
58+
yield from self.values
59+
60+
61+
def test_combined_dataset_iterate_over_all_4_datasets():
62+
dataset = TestCombinedStreamingDataset(
63+
[Range(0, 10), Range(10, 20), Range(20, 30), Range(30, 40)], 42, iterate_over_all=True
64+
)
65+
data = []
66+
for i in range(2):
67+
dataset.set_epoch(i)
68+
data.append(list(dataset))
69+
70+
assert len(data[0]) == 40
71+
assert data[0][-3:] == [14, 13, 16]
72+
assert data[1][-3:] == [14, 18, 17]
73+
74+
4975
def test_combined_dataset_num_samples_yield_iterate_over_all():
5076
dataset = TestCombinedStreamingDataset([range(10), range(0, -10, -1)], 42, iterate_over_all=True)
5177
assert len(dataset) == 20

0 commit comments

Comments
 (0)