Skip to content

Commit d2802bd

Browse files
authored
Fix: Resolve drop_last not passed down from the StreamingDataLoader to the datasets (#147)
1 parent bb362a0 commit d2802bd

File tree

6 files changed

+36
-2
lines changed

6 files changed

+36
-2
lines changed

src/litdata/streaming/combined.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,11 @@ def set_shuffle(self, shuffle: bool) -> None:
117117
for dataset in self._datasets:
118118
dataset.set_shuffle(shuffle)
119119

120+
def set_drop_last(self, drop_last: bool) -> None:
121+
"""Set the current drop_last to the datasets."""
122+
for dataset in self._datasets:
123+
dataset.set_drop_last(drop_last)
124+
120125
def _check_datasets(self, datasets: List[StreamingDataset]) -> None:
121126
if any(not isinstance(d, StreamingDataset) for d in datasets):
122127
raise RuntimeError("The provided datasets should be instances of the StreamingDataset.")

src/litdata/streaming/dataloader.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,7 @@ def __init__(
540540
profile_dir: Optional[str] = None,
541541
prefetch_factor: Optional[int] = None,
542542
shuffle: Optional[bool] = None,
543+
drop_last: Optional[bool] = False,
543544
**kwargs: Any,
544545
) -> None: # pyright: ignore
545546
if not isinstance(dataset, (StreamingDataset, CombinedStreamingDataset)):
@@ -551,6 +552,9 @@ def __init__(
551552
if shuffle is not None:
552553
dataset.set_shuffle(shuffle)
553554

555+
if drop_last is not None:
556+
dataset.set_drop_last(drop_last)
557+
554558
shuffle = None
555559

556560
if profile_batches and not _VIZ_TRACKER_AVAILABLE:

src/litdata/streaming/dataset.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ def __init__(
113113
def set_shuffle(self, shuffle: bool) -> None:
114114
self.shuffle = shuffle
115115

116+
def set_drop_last(self, drop_last: bool) -> None:
117+
self.drop_last = drop_last
118+
116119
def set_epoch(self, current_epoch: int) -> None:
117120
"""Set the current epoch to the dataset on epoch starts.
118121

status.json

Lines changed: 0 additions & 1 deletion
This file was deleted.

tests/streaming/test_combined.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
import sys
3-
from unittest.mock import ANY
3+
from unittest.mock import ANY, MagicMock
44

55
import pytest
66
import torch
@@ -53,6 +53,19 @@ def test_combined_dataset_num_samples_yield_iterate_over_all():
5353
assert len(samples) == 20
5454

5555

56+
def test_drop_last_and_shuffle():
57+
dataset_mock_1 = MagicMock()
58+
dataset_mock_2 = MagicMock()
59+
60+
dataset = TestCombinedStreamingDataset([dataset_mock_1, dataset_mock_2], 42, iterate_over_all=True)
61+
StreamingDataLoader(dataset, shuffle=True, drop_last=True)
62+
63+
dataset_mock_1.set_shuffle.assert_called()
64+
dataset_mock_2.set_shuffle.assert_called()
65+
dataset_mock_1.set_drop_last.assert_called()
66+
dataset_mock_2.set_drop_last.assert_called()
67+
68+
5669
class TestStatefulDataset:
5770
def __init__(self, size, step):
5871
self.size = size
@@ -177,6 +190,12 @@ def state_dict(self, **kwargs):
177190
def set_epoch(self, current_epoch):
178191
pass
179192

193+
def set_shuffle(self, _):
194+
pass
195+
196+
def set_drop_last(self, _):
197+
pass
198+
180199

181200
def test_combined_dataset():
182201
dataset1 = SimpleDataset(0, 10)

tests/streaming/test_dataloader.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def __init__(self, size, step):
1414
self.step = step
1515
self.counter = 0
1616
self.shuffle = None
17+
self.drop_last = None
1718

1819
def set_shuffle(self, shuffle):
1920
self.shuffle = shuffle
@@ -41,6 +42,9 @@ def load_state_dict(self, state_dict):
4142
def set_epoch(self, current_epoch):
4243
pass
4344

45+
def set_drop_last(self, drop_last):
46+
self.drop_last = drop_last
47+
4448

4549
class TestCombinedStreamingDataset(CombinedStreamingDataset):
4650
def _check_datasets(self, datasets) -> None:

0 commit comments

Comments
 (0)