-
Notifications
You must be signed in to change notification settings - Fork 465
Implements an OLMo-core compatible data loader to load HF datasets from get_cached_dataset_tulu.
#1208
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Implements an OLMo-core compatible data loader to load HF datasets from get_cached_dataset_tulu.
#1208
Changes from all commits
54365f8
473b2d8
e3da5e3
72777b0
6423648
fbe386d
bbed906
f27aab3
c8097fd
0b01af2
63253ba
5b71400
bdb2b44
6f43dde
1e3a6ac
2fe15a9
07421df
d283c5a
0f6c467
c2ac156
1390a5d
8da1b86
ba7b09d
6577e33
ceeee47
d023ea6
508ef2e
c277976
27301c1
2017ab5
1d2c925
86aa632
3107302
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,125 @@ | ||
| from collections.abc import Iterable | ||
| from typing import Any | ||
|
|
||
| from datasets import Dataset | ||
| from olmo_core.data import data_loader | ||
|
|
||
|
|
||
class HFDataLoader(data_loader.DataLoaderBase):
    """A DataLoader that wraps a HuggingFace Dataset for use with olmo_core's Trainer.

    This class implements the DataLoaderBase interface, providing iteration over
    a HuggingFace Dataset with support for sharding across distributed workers,
    shuffling, and checkpointing.
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        seed: int,
        rank: int,
        world_size: int,
        work_dir: str,
        automatic_reshuffle: bool = False,
    ) -> None:
        """Initialize the HFDataLoader.

        Args:
            dataset: The HuggingFace Dataset to load data from.
            batch_size: The global batch size.
            seed: Random seed for shuffling.
            rank: The rank of the current process in the distributed setup.
            world_size: Total number of processes in the distributed setup.
            work_dir: Working directory for the data loader (required by DataLoaderBase).
            automatic_reshuffle: If True, automatically reshuffle at epoch boundaries.
        """
        super().__init__(
            work_dir=work_dir, global_batch_size=batch_size, dp_world_size=world_size, dp_rank=rank, fs_local_rank=0
        )

        # Tag every example with its original position before sharding so that
        # exclusions and prompt ids remain stable across shuffles and epochs.
        dataset_with_indices = dataset.map(lambda example, idx: example | {"dataset_index": idx}, with_indices=True)
        self._original_dataset = dataset_with_indices.shard(num_shards=world_size, index=rank)
        self.dataset = self._original_dataset.shuffle(seed=seed)
        self.seed = seed
        self._batch_size = batch_size
        # Drop the trailing partial batch so every epoch yields full batches only.
        self.effective_size = len(self.dataset) - (len(self.dataset) % batch_size)
        self._automatic_reshuffle = automatic_reshuffle
        self._excluded_indices: set[int] = set()
        self._epoch = 0
        self._current_iter: Iterable[dict[str, Any]] | None = None

    def __next__(self) -> dict[str, Any]:
        """Return the next example, handling epoch boundaries.

        With ``automatic_reshuffle=True``, exhausting the dataset triggers a
        reshuffle and iteration continues seamlessly into the next epoch.
        Otherwise ``StopIteration`` is raised, and internal state is reset so
        the loader can be iterated again from the start (e.g. an eval loader
        that is consumed once per evaluation round).
        """
        if self._current_iter is None:
            self._current_iter = self._iter_batches()
        try:
            return next(self._current_iter)
        except StopIteration:
            if self._automatic_reshuffle:
                self.reshuffle()
                self._current_iter = self._iter_batches()
                return next(self._current_iter)
            # Fix: previously _current_iter was left as an exhausted generator,
            # so every subsequent __next__ raised StopIteration immediately and
            # repeated iterations (e.g. periodic evals) processed zero examples.
            self.batches_processed = 0
            self._current_iter = None
            raise

    def _iter_batches(self) -> Iterable[dict[str, Any]]:
        """Yield all remaining examples in the epoch, starting at batches_processed.

        NOTE(review): batches_processed is used here as an *example* index into
        the sharded dataset, while total_batches counts groups of _batch_size
        examples — confirm the intended unit against DataLoaderBase's contract.
        """
        for i in range(self.batches_processed, self.effective_size):
            example = self.dataset[i]
            # Fix: track progress as we go; previously this counter was never
            # updated, so state_dict() checkpointed a stale value and resuming
            # mid-epoch restarted from index 0.
            self.batches_processed = i + 1
            yield example | {"prompt_id": f"{self._epoch}_{example['dataset_index']}"}

    @property
    def total_batches(self) -> int:
        """Return the total number of full batches in an epoch."""
        return self.effective_size // self._batch_size

    def state_dict(self) -> dict[str, Any]:
        """Return a state dictionary for checkpointing."""
        return {
            "epoch": self._epoch,
            "batches_processed": self.batches_processed,
            "excluded_indices": list(self._excluded_indices),
        }

    def load_state_dict(self, state: dict[str, Any]) -> None:
        """Load a state dictionary to restore the data loader's state.

        Args:
            state: A dict previously produced by :meth:`state_dict`.
        """
        self._excluded_indices = set(state.get("excluded_indices", []))
        # Set epoch to one less than target since reshuffle() increments it;
        # reshuffle() also rebuilds self.dataset with the epoch-specific seed
        # and applies the restored exclusions.
        self._epoch = state["epoch"] - 1
        self.reshuffle()
        assert self._epoch == state["epoch"]
        self.batches_processed = state["batches_processed"]
        self._current_iter = None

    def exclude_index(self, index: int) -> None:
        """Exclude a dataset index from future iterations.

        Takes effect on the next reshuffle(), which filters excluded rows out.

        Args:
            index: The dataset_index to exclude.
        """
        self._excluded_indices.add(index)

    def reshuffle(self, **kwargs: Any) -> None:
        """Reshuffle the dataset for a new epoch.

        Increments the epoch counter, resets progress, reshuffles with an
        epoch-dependent seed, and drops any excluded indices.

        Args:
            **kwargs: Additional keyword arguments (unused, for API compatibility).
        """
        self._epoch += 1
        self.batches_processed = 0
        shuffled = self._original_dataset.shuffle(seed=self.seed + self._epoch)
        if self._excluded_indices:
            self.dataset = shuffled.filter(lambda x: x["dataset_index"] not in self._excluded_indices)
        else:
            self.dataset = shuffled
        # Recompute since exclusions may have shrunk the dataset.
        self.effective_size = len(self.dataset) - (len(self.dataset) % self._batch_size)

    def get_mock_batch(self) -> dict[str, Any]:
        """Return a batch with arbitrary data for dry-run testing.

        Used by the trainer to do a dry-run of the
        forward and backward pass before training officially starts.

        Returns:
            The first item from the dataset.
        """
        return self.dataset[0]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Bug: Evaluation data loader exhausted after first use
When `automatic_reshuffle=False` (the default), after the first complete iteration through the dataset, `_current_iter` remains an exhausted generator and is never reset. Subsequent calls to `__next__` will immediately raise `StopIteration` without yielding any items. This breaks evaluation in `grpo_fast.py`, where `eval_data_loader` is created without `automatic_reshuffle=True` and is iterated over multiple times during training. After the first evaluation completes, all subsequent evaluations will process zero examples because the iterator is exhausted.