@@ -18,19 +18,25 @@ def _check_datasets(self, datasets) -> None:
1818
1919
2020def test_combined_dataset_num_samples_yield ():
21- dataset = TestCombinedStreamingDataset ([range (10 ), range (0 , - 10 , - 1 )], 42 , weights = (0.5 , 0.5 ))
21+ dataset = TestCombinedStreamingDataset (
22+ [range (10 ), range (0 , - 10 , - 1 )], 42 , weights = (0.5 , 0.5 ), iterate_over_all = False
23+ )
2224 dataset_iter = iter (dataset )
2325
2426 data = list (dataset_iter )
2527 assert data == [0 , 0 , 1 , 2 , - 1 , - 2 , - 3 , 3 , 4 , 5 , 6 , - 4 , 7 , 8 , - 5 , - 6 , 9 , - 7 , - 8 ]
2628
27- dataset = TestCombinedStreamingDataset ([range (10 ), range (0 , - 10 , - 1 )], 37 , weights = (0.5 , 0.5 ))
29+ dataset = TestCombinedStreamingDataset (
30+ [range (10 ), range (0 , - 10 , - 1 )], 37 , weights = (0.5 , 0.5 ), iterate_over_all = False
31+ )
2832 dataset_iter = iter (dataset )
2933
3034 data = list (dataset_iter )
3135 assert data == [0 , 0 , - 1 , - 2 , - 3 , - 4 , - 5 , 1 , - 6 , 2 , - 7 , - 8 , 3 , 4 , - 9 , 5 ]
3236
33- dataset = TestCombinedStreamingDataset ([range (10 ), range (0 , - 10 , - 1 )], 23 , weights = (0.5 , 0.5 ))
37+ dataset = TestCombinedStreamingDataset (
38+ [range (10 ), range (0 , - 10 , - 1 )], 23 , weights = (0.5 , 0.5 ), iterate_over_all = False
39+ )
3440 dataset_iter = iter (dataset )
3541
3642 data = [next (dataset_iter ) for _ in range (5 )]
@@ -40,6 +46,13 @@ def test_combined_dataset_num_samples_yield():
4046 assert dataset ._iterator ._num_samples_yielded == [2 , 4 ]
4147
4248
49+ def test_combined_dataset_num_samples_yield_iterate_over_all ():
50+ dataset = TestCombinedStreamingDataset ([range (10 ), range (0 , - 10 , - 1 )], 42 , iterate_over_all = True )
51+ assert len (dataset ) == 20
52+ samples = list (dataset )
53+ assert len (samples ) == 20
54+
55+
4356class TestStatefulDataset :
4457 def __init__ (self , size , step ):
4558 self .size = size
@@ -69,14 +82,20 @@ def load_state_dict(self, state_dict):
6982
7083def test_combined_dataset_state_dict ():
7184 dataset = TestCombinedStreamingDataset (
72- [TestStatefulDataset (10 , 1 ), TestStatefulDataset (10 , - 1 )], 42 , weights = (0.5 , 0.5 )
85+ [TestStatefulDataset (10 , 1 ), TestStatefulDataset (10 , - 1 )],
86+ 42 ,
87+ weights = (0.5 , 0.5 ),
88+ iterate_over_all = False ,
7389 )
7490 assert dataset .state_dict (0 , 1 ) == {}
7591 dataset_iter = iter (dataset )
7692 assert dataset .state_dict (0 , 1 ) == {"0" : {"counter" : 0 }, "1" : {"counter" : 0 }}
7793
7894 dataset2 = TestCombinedStreamingDataset (
79- [TestStatefulDataset (10 , 1 ), TestStatefulDataset (10 , - 1 )], 42 , weights = (0.5 , 0.5 )
95+ [TestStatefulDataset (10 , 1 ), TestStatefulDataset (10 , - 1 )],
96+ 42 ,
97+ weights = (0.5 , 0.5 ),
98+ iterate_over_all = False ,
8099 )
81100 assert dataset2 .state_dict (0 , 1 ) == {}
82101
@@ -111,7 +130,10 @@ def test_combined_dataset_state_dict():
111130 ]
112131
113132 dataset2 = TestCombinedStreamingDataset (
114- [TestStatefulDataset (10 , 1 ), TestStatefulDataset (10 , - 1 )], 42 , weights = (0.5 , 0.5 )
133+ [TestStatefulDataset (10 , 1 ), TestStatefulDataset (10 , - 1 )],
134+ 42 ,
135+ weights = (0.5 , 0.5 ),
136+ iterate_over_all = False ,
115137 )
116138 assert dataset2 .state_dict (0 , 1 ) == {}
117139 dataset2_iter = iter (dataset2 )
@@ -136,7 +158,7 @@ def test_combined_dataset_state_dict():
136158 ],
137159)
138160def test_combined_dataset_normalizes_weights (weights , expected ):
139- combined_dataset = TestCombinedStreamingDataset ([[1 ], [2 , 3 ]], weights = weights , seed = 1 )
161+ combined_dataset = TestCombinedStreamingDataset ([[1 ], [2 , 3 ]], weights = weights , iterate_over_all = False , seed = 1 )
140162 assert combined_dataset ._weights == expected
141163
142164
@@ -159,21 +181,27 @@ def set_epoch(self, current_epoch):
159181def test_combined_dataset ():
160182 dataset1 = SimpleDataset (0 , 10 )
161183 dataset2 = SimpleDataset (10 , 20 )
162- dataset = TestCombinedStreamingDataset (datasets = [dataset1 , dataset2 ], weights = [1.0 , 0.0 ], seed = 12345 )
184+ dataset = TestCombinedStreamingDataset (
185+ datasets = [dataset1 , dataset2 ], weights = [1.0 , 0.0 ], iterate_over_all = False , seed = 12345
186+ )
163187
164188 res = list (dataset )
165189 assert res == list (range (0 , 10 ))
166190
167191 dataset1 = SimpleDataset (0 , 10 )
168192 dataset2 = SimpleDataset (10 , 20 )
169- dataset = TestCombinedStreamingDataset (datasets = [dataset1 , dataset2 ], weights = [0.0 , 1.0 ], seed = 12345 )
193+ dataset = TestCombinedStreamingDataset (
194+ datasets = [dataset1 , dataset2 ], weights = [0.0 , 1.0 ], iterate_over_all = False , seed = 12345
195+ )
170196
171197 res = list (dataset )
172198 assert res == list (range (10 , 20 ))
173199
174200 dataset1 = SimpleDataset (0 , 10 )
175201 dataset2 = SimpleDataset (10 , 20 )
176- dataset = TestCombinedStreamingDataset (datasets = [dataset1 , dataset2 ], weights = [0.5 , 0.5 ], seed = 12345 )
202+ dataset = TestCombinedStreamingDataset (
203+ datasets = [dataset1 , dataset2 ], weights = [0.5 , 0.5 ], iterate_over_all = False , seed = 12345
204+ )
177205
178206 res = list (dataset )
179207 assert 9 in res or 19 in res
@@ -183,7 +211,9 @@ def test_combined_dataset():
183211
184212 dataset1 = SimpleDataset (0 , 10 )
185213 dataset2 = SimpleDataset (10 , 20 )
186- dataset = TestCombinedStreamingDataset (datasets = [dataset1 , dataset2 ], weights = [0.5 , 0.5 ], seed = 12345 )
214+ dataset = TestCombinedStreamingDataset (
215+ datasets = [dataset1 , dataset2 ], weights = [0.5 , 0.5 ], iterate_over_all = False , seed = 12345
216+ )
187217 dataloader = DataLoader (dataset , batch_size = 2 , num_workers = 1 )
188218 dataloader_iter = iter (dataloader )
189219 assert torch .equal (next (dataloader_iter ), torch .Tensor ([0 , 1 ]))
@@ -193,7 +223,9 @@ def test_combined_dataset():
193223def test_combined_dataset_with_dataloader_and_one_worker (batch_size ):
194224 dataset1 = SimpleDataset (0 , 10 )
195225 dataset2 = SimpleDataset (10 , 20 )
196- dataset = TestCombinedStreamingDataset (datasets = [dataset1 , dataset2 ], weights = [0.5 , 0.5 ], seed = 12345 )
226+ dataset = TestCombinedStreamingDataset (
227+ datasets = [dataset1 , dataset2 ], weights = [0.5 , 0.5 ], iterate_over_all = False , seed = 12345
228+ )
197229 dataloader = StreamingDataLoader (dataset , num_workers = 1 , batch_size = batch_size , prefetch_factor = 1 )
198230 dataloader_iter = iter (dataloader )
199231
@@ -260,7 +292,9 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
260292
261293 dataset1 = StreamingDataset (input_dir = Dir (cache_dir_1 , data_dir_1 ), shuffle = True )
262294 dataset2 = StreamingDataset (input_dir = Dir (cache_dir_2 , data_dir_2 ), shuffle = True )
263- dataset = CombinedStreamingDataset (datasets = [dataset1 , dataset2 ], weights = [0.5 , 0.5 ], seed = 12345 )
295+ dataset = CombinedStreamingDataset (
296+ datasets = [dataset1 , dataset2 ], weights = [0.5 , 0.5 ], iterate_over_all = False , seed = 12345
297+ )
264298 dataloader = StreamingDataLoader (dataset , num_workers = 3 , batch_size = 2 )
265299
266300 assert dataset1 .current_epoch == 1
@@ -454,7 +488,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
454488 {
455489 "dataset" : {
456490 "0" : {
457- "num_samples_yielded" : 9 ,
491+ "num_samples_yielded" : 8 ,
458492 "num_workers" : 3 ,
459493 "batch_size" : 2 ,
460494 "current_epoch" : 1 ,
@@ -482,12 +516,12 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
482516 },
483517 "current_epoch" : 0 ,
484518 "latest_worker_idx" : 2 ,
485- "num_samples_yielded" : {0 : [3 , 1 ], 1 : [3 , 1 ], 2 : [3 , 1 ]},
519+ "num_samples_yielded" : {0 : [3 , 1 ], 1 : [3 , 1 ], 2 : [2 , 1 ]},
486520 },
487521 {
488522 "dataset" : {
489523 "0" : {
490- "num_samples_yielded" : 11 ,
524+ "num_samples_yielded" : 9 ,
491525 "num_workers" : 3 ,
492526 "batch_size" : 2 ,
493527 "current_epoch" : 1 ,
@@ -515,12 +549,12 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
515549 },
516550 "current_epoch" : 0 ,
517551 "latest_worker_idx" : 0 ,
518- "num_samples_yielded" : {0 : [5 , 1 ], 1 : [3 , 1 ], 2 : [3 , 1 ]},
552+ "num_samples_yielded" : {0 : [4 , 1 ], 1 : [3 , 1 ], 2 : [2 , 1 ]},
519553 },
520554 {
521555 "dataset" : {
522556 "0" : {
523- "num_samples_yielded" : 13 ,
557+ "num_samples_yielded" : 10 ,
524558 "num_workers" : 3 ,
525559 "batch_size" : 2 ,
526560 "current_epoch" : 1 ,
@@ -548,7 +582,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
548582 },
549583 "current_epoch" : 0 ,
550584 "latest_worker_idx" : 1 ,
551- "num_samples_yielded" : {0 : [5 , 1 ], 1 : [5 , 1 ], 2 : [3 , 1 ]},
585+ "num_samples_yielded" : {0 : [4 , 1 ], 1 : [4 , 1 ], 2 : [2 , 1 ]},
552586 },
553587 ]
554588
@@ -721,7 +755,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
721755 {
722756 "dataset" : {
723757 "0" : {
724- "num_samples_yielded" : 9 ,
758+ "num_samples_yielded" : 8 ,
725759 "num_workers" : 3 ,
726760 "batch_size" : 2 ,
727761 "current_epoch" : 2 ,
@@ -749,12 +783,12 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
749783 },
750784 "current_epoch" : 1 ,
751785 "latest_worker_idx" : 2 ,
752- "num_samples_yielded" : {0 : [3 , 1 ], 1 : [3 , 1 ], 2 : [3 , 1 ]},
786+ "num_samples_yielded" : {0 : [3 , 1 ], 1 : [3 , 1 ], 2 : [2 , 1 ]},
753787 },
754788 {
755789 "dataset" : {
756790 "0" : {
757- "num_samples_yielded" : 11 ,
791+ "num_samples_yielded" : 9 ,
758792 "num_workers" : 3 ,
759793 "batch_size" : 2 ,
760794 "current_epoch" : 2 ,
@@ -782,12 +816,12 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
782816 },
783817 "current_epoch" : 1 ,
784818 "latest_worker_idx" : 0 ,
785- "num_samples_yielded" : {0 : [5 , 1 ], 1 : [3 , 1 ], 2 : [3 , 1 ]},
819+ "num_samples_yielded" : {0 : [4 , 1 ], 1 : [3 , 1 ], 2 : [2 , 1 ]},
786820 },
787821 {
788822 "dataset" : {
789823 "0" : {
790- "num_samples_yielded" : 13 ,
824+ "num_samples_yielded" : 10 ,
791825 "num_workers" : 3 ,
792826 "batch_size" : 2 ,
793827 "current_epoch" : 2 ,
@@ -815,7 +849,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
815849 },
816850 "current_epoch" : 1 ,
817851 "latest_worker_idx" : 1 ,
818- "num_samples_yielded" : {0 : [5 , 1 ], 1 : [5 , 1 ], 2 : [3 , 1 ]},
852+ "num_samples_yielded" : {0 : [4 , 1 ], 1 : [4 , 1 ], 2 : [2 , 1 ]},
819853 },
820854 ]
821855
0 commit comments