datareader: sampling with moving weights #1174
base: main_w2v2_pretraining
Changes from all commits
@@ -8,18 +8,20 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterator, Mapping
-from typing import TypeVar, final
+from typing import Any, Callable, Dict, TypeVar, final

 import numpy as np
 import torch
-from typing_extensions import Self, override

 from fairseq2.data import DataPipeline, DataPipelineError
-from fairseq2.gang import Gang, GangError, all_sum
-from fairseq2.logging import log

 # isort: split

+from typing_extensions import Self, override
+
+from fairseq2.datasets._config import DataReadOptions, SyncMode
+from fairseq2.gang import Gang, GangError, all_sum
+from fairseq2.logging import log

 BatchT_co = TypeVar("BatchT_co", covariant=True)
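Editor's note: before the new reader itself, a minimal sketch (not part of the diff) of the moving-weight sampling idea it implements. Each data source gets a weight that is a function of the training step; at every draw the weights are re-evaluated, normalized, and passed to `numpy.random.Generator.choice`. The source names and schedules below are hypothetical.

```python
import numpy as np

# Hypothetical per-source schedulers: weight as a function of the step.
schedulers = {
    "books": lambda step: 1.0,                      # constant weight
    "web":   lambda step: min(1.0, step / 1000.0),  # ramps up over 1k steps
}

rng = np.random.default_rng(2)
names = np.array(list(schedulers), dtype=object)

def next_source(step: int) -> str:
    # Re-evaluate and normalize the weights at every draw.
    weights = np.array([schedulers[n](step) for n in names], dtype=np.float64)
    weights /= weights.sum()
    return rng.choice(names, p=weights)

print(next_source(0))     # always "books": "web" still has zero weight
print(next_source(5000))  # 50/50 once the ramp has saturated
```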
@@ -197,6 +199,173 @@
        return self._options.num_accumulate


@final
class ScheduledDataPipelineReader(DataReader[BatchT]):
    """Reads batches of examples from multiple :class:`DataPipeline` instances,
    sampling among them with step-dependent weights."""

    _dataset_name: str
    _split: str
    _pipelines: Dict[str, DataPipeline]
    _weights_schedulers: Dict[str, Callable[[int], float]]
    _pipeline_iters: Dict[str, Iterator[BatchT]]
    _names: np.ndarray
    _gang: Gang
    _options: DataReadOptions
    _eod: bool
    _strict_state: bool
    _seed: int
    _rng: np.random.Generator
    _step: int

    def __init__(
        self,
        dataset_name: str,
        split: str,
        pipelines: Dict[str, DataPipeline],
        weights_schedulers: Dict[str, Callable[[int], float]],
        gang: Gang,
        options: DataReadOptions,
        *,
        strict_state: bool = True,
    ) -> None:
        self._dataset_name = dataset_name
        self._split = split

        self._pipelines = pipelines
        self._weights_schedulers = weights_schedulers
        self._check_keys()

        self._pipeline_iters = {
            name: iter(pipeline) for name, pipeline in pipelines.items()
        }
        self._names = np.array(list(self._pipelines.keys()), dtype=object)
        self._gang = gang
        self._options = options
        self._eod = False
        self._strict_state = strict_state
        self._seed = options.seed
        self._rng = np.random.default_rng(self._seed)  # persisted via state_dict()
        self._step = 0  # persisted via state_dict()

    def _check_keys(self) -> None:
        if set(self._pipelines.keys()) != set(self._weights_schedulers.keys()):
            raise ValueError(
                "The keys of the pipelines and weights schedulers must be the same."
            )

    def _next_dataset_name(self) -> str:
        # Index the schedulers by name so the weight order always matches
        # `self._names`, regardless of dict insertion order.
        weights = np.array(
            [self._weights_schedulers[name](self._step) for name in self._names],
            dtype=np.float64,
        )
        weights = weights / weights.sum()
        self._step += 1
        return self._rng.choice(self._names, p=weights)

    @override
    def __iter__(self) -> Self:
        return self

    @override
    def __next__(self) -> list[BatchT]:
        if self._eod:
            raise StopIteration()

        batches = []

        num_accumulate = self._options.num_accumulate

        for _ in range(num_accumulate):
            next_name = self._next_dataset_name()
            try:
                batch = next(self._pipeline_iters[next_name])
            except StopIteration:
Contributor: when would we expect this to occur? when the pipeline has epoched through all data?
                self._pipelines[next_name].reset()  # restart the exhausted pipeline

                try:
                    batch = next(self._pipeline_iters[next_name])
                except StopIteration:
                    # XXX: normally a pipeline should always yield data after a reset.
                    raise DataReadError(
                        self._dataset_name,
                        self._split,
                        f"No data provided for {next_name}",
                    )
            except DataPipelineError as ex:
                raise DataReadError(
                    self._dataset_name, self._split, f"The data pipeline has failed to read the next batch from the '{self._split}' split of the '{self._dataset_name}' dataset. See the nested exception for details."  # fmt: skip
                ) from ex

            batches.append(batch)

        # If we read fewer than `num_accumulate` batches, we reached the end
        # of the data.
        if self._options.drop_remainder and len(batches) != num_accumulate:
            batches.clear()
Contributor (on lines +297 to +300): hm. a bit confused here. some clarifying questions: […]
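Editor's note, for context on the lines the questions refer to (a toy sketch, not from the PR): with `drop_remainder` set, a partial accumulation is discarded wholesale, so a rank either contributes a full group of `num_accumulate` batches or nothing.

```python
# Toy version of the drop_remainder branch above.
def finalize(batches: list, num_accumulate: int, drop_remainder: bool) -> list:
    # Reading fewer than `num_accumulate` batches means the data ran out
    # mid-accumulation.
    if drop_remainder and len(batches) != num_accumulate:
        return []  # drop the partial accumulation entirely
    return batches

assert finalize(["b0", "b1"], 4, drop_remainder=True) == []
assert finalize(["b0", "b1"], 4, drop_remainder=False) == ["b0", "b1"]
```

Note that in this reader the `except StopIteration` branch resets an exhausted pipeline and retries, so the short-read case should be rarer than in the base reader.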
        local_num_batches = len(batches)

        if self._options.sync_batches and self._gang.size > 1:
            try:
                if self._options.sync_mode == SyncMode.UNTIL_LAST:
                    num_batches = _sum_num_batches(local_num_batches, self._gang)
                else:
                    num_batches = _min_num_batches(local_num_batches, self._gang)

                if num_batches != local_num_batches:
                    batches = batches[:num_batches]
Contributor: does this imply that […]

Contributor (author): this is the general fairseq2 logic for handling the end of a pipeline's epoch: for some ranks data will still be returned, while the others skip their turn.
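Editor's note: a small simulation (not from the PR) of the behavior described in the reply, assuming the default min-style sync (the `_min_num_batches` path). The real implementation gathers the per-rank counts over the gang; here a plain dict stands in for it.

```python
# Per-rank batch lists after one __next__ call, with num_accumulate = 3.
local_batches = {
    0: ["a0", "a1", "a2"],
    1: ["b0", "b1", "b2"],
    2: ["c0"],              # this rank could only read one batch
    3: ["d0", "d1", "d2"],
}

# Min-style sync: every rank truncates to the smallest local count.
num_batches = min(len(b) for b in local_batches.values())  # -> 1

# Every rank keeps the same number of batches, so the gang stays in lockstep.
synced = {rank: b[:num_batches] for rank, b in local_batches.items()}
assert all(len(b) == 1 for b in synced.values())

# Had any rank read zero batches, num_batches would be 0, `_eod` would flip
# on every rank, and the whole gang would stop together.
```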
            except GangError as ex:
                raise DataReadError(
                    self._dataset_name, self._split, f"The batch synchronization of the gang processes has failed while reading the '{self._split}' split of the '{self._dataset_name}' dataset. See the nested exception for details."  # fmt: skip
                ) from ex
        else:
            num_batches = local_num_batches

        self._eod = num_batches == 0

        if self._eod:
            raise StopIteration()

        return batches

    @override
    def reset(self) -> None:
        self._eod = False
        for pipeline in self._pipelines.values():
            pipeline.reset()
        self._step = 0
        self._rng = np.random.default_rng(self._seed)

    @override
    def state_dict(self) -> Dict[str, object]:
        state: Dict[str, Any] = {
            name: pipeline.state_dict(strict=self._strict_state)
            for name, pipeline in self._pipelines.items()
        }
        state["step"] = self._step
        state["rng"] = self._rng
        return state

    @override
    def load_state_dict(self, state_dict: Mapping[str, object]) -> None:
        self._eod = False
        for name, pipeline in self._pipelines.items():
            pipeline.load_state_dict(state_dict[name])
        self._step = state_dict["step"]
        self._rng = state_dict["rng"]

    @property
    @override
    def dataset_name(self) -> str:
        return self._dataset_name

    @property
    @override
    def split(self) -> str:
        return self._split

    @property
    @override
    def num_accumulate(self) -> int:
        return self._options.num_accumulate


def _min_num_batches(num_batches: int, gang: Gang) -> int:
    all_num_batches = torch.zeros((gang.size,), device=gang.device, dtype=torch.int64)
Contributor: i like the setup of making this callable. instead of having schedulers for each data pipeline, we could have one master scheduler function that returns the list of weights. but either works, i suppose this option you've implemented here gives us more flexibility.
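Editor's note: a sketch of the alternative raised in this comment — one master scheduler that returns the full weight list — adapted into the per-pipeline `Dict[str, Callable[[int], float]]` form the reader's constructor takes. All names below are illustrative, not from the PR.

```python
from typing import Callable, Dict, List

# Hypothetical master scheduler: one function yields the whole weight vector,
# here shifting mass from source 0 to source 1 over a 10k-step warmup.
def master_scheduler(step: int) -> List[float]:
    warm = min(1.0, step / 10_000.0)
    return [1.0 - 0.5 * warm, 0.5 * warm]

# Adapter: derive per-pipeline schedulers from the single master function.
def as_per_pipeline(
    names: List[str], master: Callable[[int], List[float]]
) -> Dict[str, Callable[[int], float]]:
    # `i=i` binds the index at definition time (avoids late-binding bugs).
    return {name: (lambda step, i=i: master(step)[i]) for i, name in enumerate(names)}

schedulers = as_per_pipeline(["books", "web"], master_scheduler)
print(schedulers["web"](5_000))  # 0.25, halfway through the warmup
```

One wrinkle with the adapter: it calls `master(step)` once per pipeline rather than once per step, so a production version would want per-step caching; supporting a master function natively would avoid that.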