Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
54365f8
tests pass now
finbarrtimbers Nov 18, 2025
473b2d8
Updated code
finbarrtimbers Nov 18, 2025
e3da5e3
updated code
finbarrtimbers Nov 18, 2025
72777b0
Added data loader
finbarrtimbers Nov 18, 2025
6423648
cleaned up code
finbarrtimbers Nov 18, 2025
fbe386d
Added get_mock_batch impl.
finbarrtimbers Nov 18, 2025
bbed906
Updated to use proper iterator api
finbarrtimbers Nov 18, 2025
f27aab3
Fixed caching.
finbarrtimbers Nov 18, 2025
c8097fd
fixed dataset issue
finbarrtimbers Nov 18, 2025
0b01af2
added debugging
finbarrtimbers Nov 19, 2025
63253ba
fixed iter batches
finbarrtimbers Nov 19, 2025
5b71400
Added tests for data loader
finbarrtimbers Nov 19, 2025
bdb2b44
Added tests
finbarrtimbers Nov 19, 2025
6f43dde
Simplified filter logic.
finbarrtimbers Nov 19, 2025
1e3a6ac
Merge branch 'main' into oc-dataloader-dataset
finbarrtimbers Nov 24, 2025
2fe15a9
Removed dataset
finbarrtimbers Nov 24, 2025
07421df
updated cluster
finbarrtimbers Nov 24, 2025
d283c5a
Merge branch 'main' into oc-dataloader-dataset
finbarrtimbers Nov 24, 2025
0f6c467
Added docstrings, function signature type annotations.
finbarrtimbers Nov 24, 2025
c2ac156
Added docstrings + function signature type annotations
finbarrtimbers Nov 24, 2025
1390a5d
updated code
finbarrtimbers Nov 24, 2025
8da1b86
Cleans up docstrings.
finbarrtimbers Nov 24, 2025
ba7b09d
Cleaned up code.
finbarrtimbers Nov 24, 2025
6577e33
updated code
finbarrtimbers Nov 24, 2025
ceeee47
Cleaned up code
finbarrtimbers Nov 25, 2025
d023ea6
updated test
finbarrtimbers Nov 25, 2025
508ef2e
Merge branch 'main' into oc-dataloader-dataset
finbarrtimbers Nov 26, 2025
c277976
Fixed tests.
finbarrtimbers Nov 26, 2025
27301c1
updated docstring
finbarrtimbers Nov 26, 2025
2017ab5
Cleaned up code.
finbarrtimbers Nov 26, 2025
1d2c925
Fixed bugs.
finbarrtimbers Nov 26, 2025
86aa632
Fixed bug cursor pointed out.
finbarrtimbers Nov 26, 2025
3107302
now, everything should be fixed! addressed all of cursor's comments.
finbarrtimbers Nov 27, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions open_instruct/data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from collections.abc import Iterable
from typing import Any

from datasets import Dataset
from olmo_core.data import data_loader


class HFDataLoader(data_loader.DataLoaderBase):
    """A DataLoader that wraps a HuggingFace Dataset for use with olmo_core's Trainer.

    Implements the ``DataLoaderBase`` interface: iterates over a HuggingFace
    ``Dataset`` that has been sharded across distributed workers, with
    per-epoch shuffling, index exclusion, and checkpoint/restore support via
    ``state_dict``/``load_state_dict``. Each yielded example is augmented with
    a ``prompt_id`` of the form ``"{epoch}_{dataset_index}"``.
    """

def __init__(
self,
dataset: Dataset,
batch_size: int,
seed: int,
rank: int,
world_size: int,
work_dir: str,
automatic_reshuffle: bool = False,
) -> None:
"""Initialize the HFDataLoader.

Args:
dataset: The HuggingFace Dataset to load data from.
batch_size: The global batch size.
seed: Random seed for shuffling.
rank: The rank of the current process in the distributed setup.
world_size: Total number of processes in the distributed setup.
work_dir: Working directory for the data loader (required by DataLoaderBase).
automatic_reshuffle: If True, automatically reshuffle at epoch boundaries.
"""
super().__init__(
work_dir=work_dir, global_batch_size=batch_size, dp_world_size=world_size, dp_rank=rank, fs_local_rank=0
)

dataset_with_indices = dataset.map(lambda example, idx: example | {"dataset_index": idx}, with_indices=True)
self._original_dataset = dataset_with_indices.shard(num_shards=world_size, index=rank)
self.dataset = self._original_dataset.shuffle(seed=seed)
self.seed = seed
self._batch_size = batch_size
self.effective_size = len(self.dataset) - (len(self.dataset) % batch_size)
self._automatic_reshuffle = automatic_reshuffle
self._excluded_indices: set[int] = set()
self._epoch = 0
self._current_iter: Iterable[dict[str, Any]] | None = None

def __next__(self) -> dict[str, Any]:
if self._current_iter is None:
self._current_iter = self._iter_batches()
try:
return next(self._current_iter)
except StopIteration:
if self._automatic_reshuffle:
self.reshuffle()
self._current_iter = self._iter_batches()
return next(self._current_iter)
raise
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Evaluation data loader exhausted after first use

When automatic_reshuffle=False (the default), after the first complete iteration through the dataset, _current_iter remains as an exhausted generator and is never reset. Subsequent calls to __next__ will immediately raise StopIteration without yielding any items. This breaks evaluation in grpo_fast.py where eval_data_loader is created without automatic_reshuffle=True and is iterated over multiple times during training. After the first evaluation completes, all subsequent evaluations will process zero examples because the iterator is exhausted.

Fix in Cursor Fix in Web


def _iter_batches(self) -> Iterable[dict[str, Any]]:
"""Return an iterable over all batches in the epoch."""
for i in range(self.batches_processed, self.effective_size):
example = self.dataset[i]
yield example | {"prompt_id": f"{self._epoch}_{example['dataset_index']}"}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: batches_processed never incremented during iteration

The _iter_batches method yields items but never increments self.batches_processed. This breaks checkpointing because state_dict will always save batches_processed at its initial value (0 or whatever was restored). When resuming from a checkpoint, the data loader will restart from the beginning of the epoch instead of continuing from where it left off. The loop variable i should be used to update batches_processed after each yield, or batches_processed should be incremented in the __next__ method.

Fix in Cursor Fix in Web


@property
def total_batches(self) -> int:
"""Return the total number of batches in an epoch."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this doc string right? Isn't effective size number of samples in the dataset, from line 38?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, you're right. Fixed!

return self.effective_size // self._batch_size

def state_dict(self) -> dict[str, Any]:
"""Return a state dictionary for checkpointing."""
return {
"epoch": self._epoch,
"batches_processed": self.batches_processed,
"excluded_indices": list(self._excluded_indices),
}

    def load_state_dict(self, state: dict[str, Any]) -> None:
        """Load a state dictionary to restore the data loader's state.

        The order here matters: exclusions are restored first so the
        reshuffle's filter sees them, and ``batches_processed`` is restored
        *after* ``reshuffle()`` (which resets it to 0).
        """
        self._excluded_indices = set(state.get("excluded_indices", []))
        # Set epoch to one less than target since reshuffle() increments it;
        # reshuffle also rebuilds self.dataset with the epoch-specific seed
        # and re-applies the exclusion filter.
        self._epoch = state["epoch"] - 1
        self.reshuffle()
        assert self._epoch == state["epoch"]
        self.batches_processed = state["batches_processed"]
        # Drop any live iterator; the next __next__ call rebuilds it from
        # the restored batches_processed position.
        self._current_iter = None

def exclude_index(self, index: int) -> None:
"""Exclude a dataset index from future iterations.

Args:
index: The dataset_index to exclude.
"""
self._excluded_indices.add(index)

def reshuffle(self, **kwargs: Any) -> None:
"""Reshuffle the dataset for a new epoch.

Args:
**kwargs: Additional keyword arguments (unused, for API compatibility).
"""
self._epoch += 1
self.batches_processed = 0
shuffled = self._original_dataset.shuffle(seed=self.seed + self._epoch)
if self._excluded_indices:
self.dataset = shuffled.filter(lambda x: x["dataset_index"] not in self._excluded_indices)
else:
self.dataset = shuffled
self.effective_size = len(self.dataset) - (len(self.dataset) % self._batch_size)

def get_mock_batch(self) -> dict[str, Any]:
"""Return a batch with arbitrary data for dry-run testing.

Used by the trainer to do a dry-run of the
forward and backward pass before training officially starts.

Returns:
The first item from the dataset.
"""
return self.dataset[0]
Loading