177 changes: 173 additions & 4 deletions src/fairseq2/datasets/_data_reader.py
@@ -8,18 +8,20 @@

from abc import ABC, abstractmethod
from collections.abc import Iterator, Mapping
from typing import TypeVar, final
from typing import Any, Callable, Dict, TypeVar, final

import numpy as np
import torch
from typing_extensions import Self, override

from fairseq2.data import DataPipeline, DataPipelineError
from fairseq2.gang import Gang, GangError, all_sum
from fairseq2.logging import log

# isort: split

from typing_extensions import Self, override

from fairseq2.datasets._config import DataReadOptions, SyncMode
from fairseq2.gang import Gang, GangError, all_sum
from fairseq2.logging import log

BatchT_co = TypeVar("BatchT_co", covariant=True)

@@ -197,6 +199,173 @@
return self._options.num_accumulate


@final
class ScheduledDataPipelineReader(DataReader[BatchT]):
"""Reads batches of examples from a dataset using a :class:`DataPipeline`."""

_dataset_name: str
_split: str
_pipelines: Dict[str, DataPipeline]
_weights_schedulers: Dict[str, Callable[[int], float]]
_pipeline_iters: Dict[str, Iterator[BatchT]]
_gang: Gang
_options: DataReadOptions
_eod: bool
_seed: int

def __init__(
self,
dataset_name: str,
split: str,
pipelines: Dict[str, DataPipeline],
weights_schedulers: Dict[str, Callable[[int], float]],
gang: Gang,
options: DataReadOptions,
*,
strict_state: bool = True,
) -> None:
self._dataset_name = dataset_name
self._split = split

self._pipelines = pipelines
self._weights_schedulers = weights_schedulers
self._check_keys()

self._pipeline_iters = {
name: iter(pipeline) for name, pipeline in pipelines.items()
}
self._names = np.array(list(self._pipelines.keys()), dtype=object)
self._gang = gang
self._options = options
self._eod = False
self._strict_state = strict_state
self._seed = options.seed
self._rng = np.random.default_rng(self._seed)  # persisted via state_dict()
self._step = 0  # persisted via state_dict()

def _check_keys(self) -> None:
if set(self._pipelines.keys()) != set(self._weights_schedulers.keys()):
raise ValueError(
"The keys of the pipelines and weights schedulers must be the same."
)

def _next_dataset_name(self) -> str:
weights = np.array(
[scheduler(self._step) for scheduler in self._weights_schedulers.values()],

Contributor: I like the setup of making this callable. Instead of having a scheduler for each data pipeline, we could have one master scheduler function that returns the list of weights. Either works, I suppose; the option you've implemented here gives us more flexibility.

dtype=np.float64,
)
weights = weights / weights.sum()
self._step += 1
return self._rng.choice(self._names, p=weights)

Check failure on line 259 in src/fairseq2/datasets/_data_reader.py (GitHub Actions / Lint Python / Lint): Returning Any from function declared to return "str"
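
To make the callable-scheduler design discussed above concrete, here is a minimal sketch of a weight scheduler matching the `Callable[[int], float]` signature that `ScheduledDataPipelineReader` expects; the `linear_ramp` helper, the pipeline names, and the weight values are illustrative assumptions, not part of this PR:

```python
from typing import Callable, Dict


def linear_ramp(start: float, end: float, num_steps: int) -> Callable[[int], float]:
    """Build a scheduler that moves linearly from ``start`` to ``end`` over ``num_steps``."""

    def scheduler(step: int) -> float:
        if step >= num_steps:
            return end
        return start + (end - start) * step / num_steps

    return scheduler


# One scheduler per pipeline; the reader normalizes the returned weights at every
# step, so they do not have to sum to 1.
weights_schedulers: Dict[str, Callable[[int], float]] = {
    "books": linear_ramp(start=0.8, end=0.5, num_steps=10_000),
    "web": linear_ramp(start=0.2, end=0.5, num_steps=10_000),
}
```

The alternative mentioned in the comment, a single master scheduler that returns the whole weight list, would centralize the mixture logic in one place, whereas per-pipeline callables keep each dataset's schedule independent.
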

@override
def __iter__(self) -> Self:
return self

@override
def __next__(self) -> list[BatchT]:
if self._eod:
raise StopIteration()

batches = []

num_accumulate = self._options.num_accumulate

for idx in range(num_accumulate):
next_name = self._next_dataset_name()
try:
batch = next(self._pipeline_iters[next_name])
except StopIteration:

Contributor: When would we expect this to occur? When the pipeline has epoched through all the data?

self._pipelines[next_name].reset() # restart the pipeline

try:
batch = next(self._pipeline_iters[next_name])
except StopIteration:
# XXX: normally the pipeline should always yield some data right after a reset
raise DataReadError(
self._dataset_name,
self._split,
f"No data provided for {next_name}",
)
except DataPipelineError as ex:
raise DataReadError(
self._dataset_name, self._split, f"The data pipeline has failed to read the next batch from the '{self._split}' split of the '{self._dataset_name}' dataset. See the nested exception for details." # fmt: skip
) from ex

batches.append(batch)

# If we read fewer than `num_accumulate` batches, it means we reached the end
# of data.
if self._options.drop_remainder and len(batches) != num_accumulate:
batches.clear()

Contributor (on lines +297 to +300): Hm, a bit confused here. Some clarifying questions:

1. When would this condition be True? Isn't that case handled by the exception handling above?

2. Dumb Q: what does batches.clear() do? (Just empty the batch list?)

3. At a conceptual level, it seems this logic determines whether we should halt training once we reach the end of data (or once any of the pipelines reaches its end of data) if self._options.drop_remainder is set to True, and if it is False we just continue epoching through the data. Is this understanding correct?

local_num_batches = len(batches)

if self._options.sync_batches and self._gang.size > 1:
try:
if self._options.sync_mode == SyncMode.UNTIL_LAST:
num_batches = _sum_num_batches(local_num_batches, self._gang)
else:
num_batches = _min_num_batches(local_num_batches, self._gang)

if num_batches != local_num_batches:
batches = batches[:num_batches]

Contributor: Does this imply that batches is a global list of batches across all ranks? I thought this was a local list...

Contributor (author): This is the general fairseq2 logic for handling a pipeline's end of epoch: for some ranks data will be returned while the others should skip their turn. We probably don't need that for this scheduled mixture, since we're dealing with an infinite loop here. (See the sketch after this method.)

except GangError as ex:
raise DataReadError(
self._dataset_name, self._split, f"The batch synchronization of the gang processes has failed while reading the '{self._split}' split of the '{self._dataset_name}' dataset. See the nested exception for details." # fmt: skip
) from ex
else:
num_batches = local_num_batches

self._eod = num_batches == 0

if self._eod:
raise StopIteration()

return batches
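
To make the batch-synchronization exchange above concrete: each rank keeps its own local `batches` list, and the gang synchronization only decides how many of those local batches every rank is allowed to keep. Below is a rough standalone analogue of the min-based mode, written against plain `torch.distributed` rather than the fairseq2 `Gang` API used by `_min_num_batches` further down, so it is purely illustrative:

```python
import torch
import torch.distributed as dist


def sync_num_batches(local_num_batches: int, device: torch.device) -> int:
    """Every rank reports how many batches it read; all ranks agree on the minimum."""
    count = torch.tensor(local_num_batches, device=device)
    dist.all_reduce(count, op=dist.ReduceOp.MIN)
    return int(count.item())


# Each rank truncates its *local* list to the agreed count, so no rank runs ahead
# when another rank has already exhausted its data:
# batches = batches[:sync_num_batches(len(batches), device)]
```
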

@override
def reset(self) -> None:
self._eod = False
for pipeline in self._pipelines.values():
pipeline.reset()
self._step = 0
self._rng = np.random.default_rng(self._seed)

@override
def state_dict(self) -> Dict[str, object]:
state: Dict[str, Any] = {
name: pipeline.state_dict(strict=self._strict_state)
for name, pipeline in self._pipelines.items()
}
state["step"] = self._step
state["rng"] = self._rng
return state

@override
def load_state_dict(self, state_dict: Mapping[str, object]) -> None:
self._eod = False
for name, pipeline in self._pipelines.items():
pipeline.load_state_dict(state_dict[name])

Check failure on line 349 in src/fairseq2/datasets/_data_reader.py (GitHub Actions / Lint Python / Lint): Argument 1 to "load_state_dict" of "DataPipeline" has incompatible type "object"; expected "Mapping[str, Any]"
self._step = state_dict["step"]

Check failure on line 350 in src/fairseq2/datasets/_data_reader.py (GitHub Actions / Lint Python / Lint): Incompatible types in assignment (expression has type "object", variable has type "int")
self._rng = state_dict["rng"]

Check failure on line 351 in src/fairseq2/datasets/_data_reader.py (GitHub Actions / Lint Python / Lint): Incompatible types in assignment (expression has type "object", variable has type "Generator")
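
One possible way to resolve the three mypy failures above, sketched under the assumption that the state dict keeps exactly the layout produced by `state_dict()` (this is not part of the PR): narrow the `object` values with `typing.cast` before assigning them.

```python
from typing import Any, Mapping, cast

import numpy as np


# Sketch of a revised load_state_dict body; `self._pipelines`, `self._step` and
# `self._rng` are the attributes defined on ScheduledDataPipelineReader above.
def load_state_dict(self, state_dict: Mapping[str, object]) -> None:
    self._eod = False
    for name, pipeline in self._pipelines.items():
        pipeline.load_state_dict(cast(Mapping[str, Any], state_dict[name]))
    self._step = cast(int, state_dict["step"])
    self._rng = cast(np.random.Generator, state_dict["rng"])
```

The "Returning Any" failure on line 259 could likewise be silenced by wrapping the sampled value, e.g. `return str(self._rng.choice(self._names))`, since `numpy.random.Generator.choice` is typed as returning `Any`.
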

@property
@override
def dataset_name(self) -> str:
return self._dataset_name

@property
@override
def split(self) -> str:
return self._split

@property
@override
def num_accumulate(self) -> int:
return self._options.num_accumulate


def _min_num_batches(num_batches: int, gang: Gang) -> int:
all_num_batches = torch.zeros((gang.size,), device=gang.device, dtype=torch.int64)

3 changes: 1 addition & 2 deletions src/fairseq2/metrics/recorders/_wandb.py
@@ -13,7 +13,6 @@

from fairseq2.logging import log
from fairseq2.metrics import MetricDescriptor

from fairseq2.metrics.recorders._handler import MetricRecorderHandler
from fairseq2.metrics.recorders._recorder import (
MetricRecorder,
@@ -22,7 +21,7 @@
)
from fairseq2.registry import Provider
from fairseq2.utils.structured import structure
from fairseq2.utils.validation import validate, ValidationError, ValidationResult
from fairseq2.utils.validation import ValidationError, ValidationResult, validate

# isort: split

3 changes: 1 addition & 2 deletions src/fairseq2/models/asr/_model.py
@@ -24,8 +24,7 @@ class AsrModel(Module, ABC):
"""Represents an Automatic Speech Recognition model."""

@abstractmethod
def forward(self, batch: SequenceBatch | Seq2SeqBatch) -> AsrModelOutput:
...
def forward(self, batch: SequenceBatch | Seq2SeqBatch) -> AsrModelOutput: ...


@dataclass
6 changes: 3 additions & 3 deletions src/fairseq2/models/wav2vec2/asr/_model.py
@@ -8,6 +8,9 @@

from typing import final

from torch.nn import Dropout
from typing_extensions import override

from fairseq2.models.asr import AsrModel, AsrModelOutput
from fairseq2.models.seq2seq import Seq2SeqBatch
from fairseq2.models.sequence import SequenceBatch
@@ -16,9 +19,6 @@
from fairseq2.nn import Projection
from fairseq2.typing import DataType, Device

from torch.nn import Dropout
from typing_extensions import override


@final
class Wav2Vec2AsrModel(AsrModel):
6 changes: 3 additions & 3 deletions src/fairseq2/recipes/asr/_common.py
@@ -7,9 +7,11 @@
from __future__ import annotations

import math
from typing import Any, Dict, final, TextIO
from typing import Any, Dict, TextIO, final

import torch
from torch import Tensor
from typing_extensions import override

from fairseq2.data.text.tokenizers import TextTokenDecoder, TextTokenizer
from fairseq2.gang import Gang
@@ -20,8 +22,6 @@
from fairseq2.models.sequence import SequenceBatch
from fairseq2.models.wav2vec2.asr import Wav2Vec2AsrModel
from fairseq2.recipes import BaseMetricBag, Model, UnitError
from torch import Tensor
from typing_extensions import override


@final
5 changes: 3 additions & 2 deletions src/fairseq2/recipes/asr/_eval.py
@@ -14,14 +14,16 @@

from fairseq2.context import RuntimeContext
from fairseq2.datasets import LengthBatching, SyncMode
from fairseq2.datasets.asr import AsrDataset, AsrReadOptions, GENERIC_ASR_DATASET_FAMILY
from fairseq2.datasets.asr import GENERIC_ASR_DATASET_FAMILY, AsrDataset, AsrReadOptions
from fairseq2.gang import Gangs
from fairseq2.models.asr import AsrModel
from fairseq2.models.seq2seq import Seq2SeqBatch
from fairseq2.recipes import Evaluator, EvalUnit, Model, RecipeError, UnitError

# isort: split

from typing_extensions import override

from fairseq2.recipes.asr._common import AsrCriterion, AsrMetricBag, AsrScorer
from fairseq2.recipes.common import (
create_evaluator,
@@ -45,7 +47,6 @@
from fairseq2.utils.rng import manual_seed
from fairseq2.utils.structured import structure
from fairseq2.utils.validation import validate
from typing_extensions import override


@dataclass(kw_only=True)
@@ -13,9 +13,7 @@
from fairseq2.context import RuntimeContext
from fairseq2.data import DataPipeline
from fairseq2.data.text.tokenizers import TextTokenizer
from fairseq2.datasets import (
DataPipelineReader,
)
from fairseq2.datasets import DataPipelineReader
from fairseq2.datasets.asr import AsrDataset, AsrReadOptions
from fairseq2.gang import Gang, Gangs
from fairseq2.models.seq2seq import Seq2SeqBatch