1515from copy import deepcopy
1616from typing import Any , Dict , Iterator , List , Optional , Sequence
1717
18- import numpy as np
1918from torch .utils .data import IterableDataset
2019
2120from litdata .streaming .dataset import StreamingDataset
@@ -182,14 +181,14 @@ def __init__(
182181 self ,
183182 datasets : List [StreamingDataset ],
184183 seed : int ,
185- weights : Sequence [float ],
184+ weights : Sequence [Optional [ float ] ],
186185 use_streaming_dataloader : bool ,
187186 num_samples_yielded : Any ,
188187 iterate_over_all : bool = False ,
189188 ) -> None :
190189 self ._datasets = datasets
191190 self ._dataset_iters = [iter (dataset ) for dataset in datasets ]
192- self ._dataset_indexes = list (range (len (datasets )))
191+ self ._dataset_indexes : List [ Optional [ int ]] = list (range (len (datasets )))
193192 self ._num_samples_yielded = num_samples_yielded or [0 for _ in range (len (datasets ))]
194193 self ._original_weights = deepcopy (weights )
195194 self ._weights = deepcopy (weights )
@@ -200,7 +199,9 @@ def __init__(
200199 if num_samples_yielded is not None :
201200 self ._num_samples_yielded = num_samples_yielded
202201 for _ in range (sum (num_samples_yielded )):
203- self ._rng .choices (self ._dataset_indexes , weights = self ._weights , k = 1 )
202+ choice_indexes : List [int ] = [index for index in self ._dataset_indexes if index is not None ]
203+ choice_weights : List [float ] = [w for w in self ._weights if w is not None ]
204+ self ._rng .choices (choice_indexes , weights = choice_weights , k = 1 )
204205
205206 self ._use_streaming_dataloader = use_streaming_dataloader
206207 self ._is_done = False
@@ -209,27 +210,31 @@ def __next__(self) -> Any:
209210 if self ._iterate_over_all :
210211 while True :
211212 try :
212- if len (self ._dataset_indexes ) > 1 :
213+ indexes_left = [index for index in self ._dataset_indexes if index is not None ]
214+ if len (indexes_left ) > 1 :
213215 dataset_index = self ._get_dataset_index ()
214- elif len (self . _dataset_indexes ) == 1 :
215- dataset_index = self . _dataset_indexes [0 ]
216+ elif len (indexes_left ) == 1 :
217+ dataset_index = indexes_left [0 ]
216218 return self ._get_sample (dataset_index )
217219 except StopIteration as e :
218- if len (self . _dataset_indexes ) == 1 :
220+ if len (indexes_left ) == 1 :
219221 self ._dataset_indexes = list (range (len (self ._datasets )))
220222 self ._weights = deepcopy (self ._original_weights )
221223 raise e
222224
223- self ._dataset_indexes .pop (dataset_index )
224- self ._weights .pop (dataset_index )
225- self ._weights /= np .sum (self ._weights )
225+ self ._dataset_indexes [dataset_index ] = None
226+ self ._weights [dataset_index ] = None # type: ignore
227+ new_sum = sum ([w for w in self ._weights if w is not None ])
228+ self ._weights = [None if w is None else w / new_sum for w in self ._weights ]
226229
227230 # stop on the first iteration
228231 return self ._get_sample (self ._get_dataset_index ())
229232
230233 def _get_dataset_index (self ) -> int :
231234 # randomly select a dataset index
232- (dataset_index ,) = self ._rng .choices (self ._dataset_indexes , weights = self ._weights , k = 1 )
235+ indexes = [index for index in self ._dataset_indexes if index is not None ]
236+ weights = [w for w in self ._weights if w is not None ]
237+ (dataset_index ,) = self ._rng .choices (indexes , weights = weights , k = 1 )
233238 return dataset_index
234239
235240 def _get_sample (self , dataset_index : int ) -> Any :
0 commit comments