diff --git a/apps/sft/llama3_8b.yaml b/apps/sft/llama3_8b.yaml index e9ddc625a..e626648db 100644 --- a/apps/sft/llama3_8b.yaml +++ b/apps/sft/llama3_8b.yaml @@ -32,7 +32,16 @@ training: max_norm: 1.0 steps: 1000 compile: false - dataset: "c4" + datasets: + - path: "yahma/alpaca-cleaned" + split: "train[:95%]" + +eval: + eval_every_n_steps: 50 # null = disabled + max_eval_steps: null # null = run until epoch completes + datasets: + - path: "yahma/alpaca-cleaned" + split: "train[95%:]" parallelism: data_parallel_replicate_degree: 1 @@ -62,6 +71,7 @@ metric_logging: group: sft_exp_${oc.env:USER} logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce + # profiling: # enable_profiling: false diff --git a/apps/sft/main.py b/apps/sft/main.py index 93ba05eed..dc9d0e181 100644 --- a/apps/sft/main.py +++ b/apps/sft/main.py @@ -27,6 +27,7 @@ from forge.data.datasets.packed import PackedDataset, TextPacker from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset from forge.data.tokenizer import HuggingFaceModelTokenizer +from forge.data.utils import StopAfterOneEpoch from forge.observability import get_or_create_metric_logger, record_metric, Reduce from forge.util.config import parse @@ -81,6 +82,7 @@ def __init__(self, config: DictConfig): self.gradient_accumulation_steps = 1 # Example value, adjust as needed self._rank = current_rank().rank self._size = math.prod(current_size().values()) + self._init_dist() super().__init__(job_config) @@ -122,28 +124,67 @@ def record_batch_metrics(self, data_metrics: list): @endpoint async def setup(self): - self.train_dataloader = self.setup_data() + + # metric logger self.mlogger = await self.setup_metric_logger() - # self.train_dataloader = self.setup_data( - # self.train_config.train_dataset_config, - # self.train_config.train_dataloader_config, - # self.train_config.packing_config, - # ) - # self.val_dataloader = self.setup_data( - # self.train_config.val_dataset_config, - # self.train_config.val_dataloader_config, - # self.train_config.packing_config, - # ) + # Load training datasets + logger.info("Setting training datasets") + train_datasets_config = self.job_config.training.datasets + self.train_dataloader = self.setup_data(train_datasets_config) + + # Load eval datasets + eval_config = self.job_config.get("eval", {}) + self.val_dataloaders = {} + self.eval_every_n_steps = eval_config.get("eval_every_n_steps", None) + max_eval_steps = eval_config.get("max_eval_steps", None) + self.max_eval_steps = ( + max_eval_steps if max_eval_steps and max_eval_steps > 0 else None + ) + self.validation_enabled = ( + self.eval_every_n_steps is not None and self.eval_every_n_steps > 0 + ) + if self.validation_enabled: + logger.info("Setting eval datasets") + self.eval_datasets_config = eval_config.datasets + + for i, dataset_config in enumerate(self.eval_datasets_config): + ds_name = dataset_config.get("dataset_name", i) + + # TODO: Support separate eval batch size from config (eval.local_batch_size) + dataloader = self.setup_data([dataset_config]) + self.val_dataloaders[ds_name] = dataloader # TODO: confirm that this is working properly # Should also use load, not dcp_load self.checkpointer.load(step=self.current_step) + # self.profiler = self.setup_profiler(self.train_config.profiler_config) # self.logger = self.setup_logger(self.train_config.logger_config) - def setup_data(self): - print(os.path.join(self.job_config.model.hf_assets_path, "tokenizer.json")) + def setup_data(self, dataset_configs: list[dict]) -> 
StatefulDataLoader: + """Instantiates datasets and returns a StatefulDataLoader. + + Args: + dataset_configs (list[dict]): List of dataset config dicts used as `sft_iterable_dataset(**dataset_configs[i])`. + + Returns: + StatefulDataLoader + + Raises: + ValueError: If multiple datasets provided (not yet supported) + """ + # TODO felipemello: Currently only support single dataset + if len(dataset_configs) > 1: + raise ValueError( + f"Multiple training datasets not supported yet. " + f"Got {len(dataset_configs)} datasets. " + ) + + dataset_config = dataset_configs[0] + + # TODO: Evaluate if tokenizers should be created once and shared for every dataset + # Load tokenizer tokenizer = HuggingFaceModelTokenizer( tokenizer_json_path=os.path.join( self.job_config.model.hf_assets_path, "tokenizer.json" @@ -165,18 +206,26 @@ def setup_data(self): ), ) + # Get DP mesh for data sharding + dp_mesh = None + if self.parallel_dims is not None and self.parallel_dims.dp_enabled: + dp_mesh = self.parallel_dims.world_mesh.get_group("dp") + + # Pass config directly to dataset constructor dataset = sft_iterable_dataset( model_transform=tokenizer, message_transform=AlpacaToMessages(), - path="yahma/alpaca-cleaned", - split="train", + dp_mesh=dp_mesh, + **dataset_config, ) + packer = TextPacker(padding_idx=0) dataset = PackedDataset( dataset=dataset, packer=packer, target_tokens_per_pack=self.job_config.training.seq_len, # TODO: get this from model ) + dataloader = StatefulDataLoader( dataset=dataset, batch_size=self.job_config.training.local_batch_size, @@ -192,8 +241,12 @@ def setup_data(self): return dataloader def forward_backward( - self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor + self, + input_dict: dict[str, torch.Tensor], + labels: torch.Tensor, + skip_backward: bool = False, ) -> torch.Tensor: + """Forward pass with optional backward.""" model_parts = self.model_parts parallel_dims = self.parallel_dims @@ -214,6 +267,7 @@ def forward_backward( if parallel_dims.pp_enabled: # Pipeline Parallel forward / backward inside step() call + # Note: PP backward only happens if not in torch.no_grad() context with self.train_context(optional_context_parallel_ctx): targets, losses = ( (labels, []) if self.pp_has_last_stage else (None, None) @@ -235,7 +289,7 @@ def forward_backward( else torch.tensor([-1.0], device=self.device) ) else: - # Non-PP forward / backward + # Non-PP forward / backward - must happen inside same context with self.train_context(optional_context_parallel_ctx): assert len(model_parts) == 1 with self.maybe_enable_amp: @@ -243,7 +297,10 @@ def forward_backward( loss = self.loss_fn(pred, labels) # need to free to before bwd to avoid peaking memory del pred - loss.backward() + + # Only run backward if requested. Useful for eval. + if not skip_backward: + loss.backward() return loss @@ -259,12 +316,108 @@ def train_step(self, batch) -> None: loss = loss.item() record_metric("ForgeSFTRecipe/train_step/loss", loss, Reduce.MEAN) - logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}") + if self.current_step % 10 == 0: + logger.info( + f"step {self.current_step} / {self.num_training_steps} | Loss: {loss}" + ) + # self.pbar.set_description(f"{self.current_step}|Loss: {loss}") # self.pbar.update(1) self.optimizers.step() self.lr_schedulers.step() + async def evaluate(self) -> None: + """Run evaluation on multiple datasets, one at a time. + + 1. Set models to eval mode + 2. 
For each eval dataset: + - Create fresh iterator (starts from epoch 0) + - Use StopAfterOneEpoch to iterate until epoch boundary. This utility + is necessary for infinite iterable dataset, since epoch boundaries are not known. + - Respect max_eval_steps cap if configured + - Record loss and step metrics (on dp rank only) + 3. Restore models to train mode + """ + + # Set models to eval mode + for model_part in self.model_parts: + model_part.eval() + + # Get DP process group for epoch synchronization + dp_mesh = None + if self.parallel_dims is not None and self.parallel_dims.dp_enabled: + dp_mesh = self.parallel_dims.world_mesh.get_group("dp") + + # Evaluate each dataset sequentially + for dataset_name, val_dataloader in self.val_dataloaders.items(): + logger.info(f"=====Evaluating dataset: {dataset_name}=====") + + # Evaluation loop for this dataset + total_loss = torch.tensor(0.0, device=self.device) + num_steps = 0 + + # NOTE: Assumes batch contains field "metrics" + batch_iter = StopAfterOneEpoch( + iter=iter(val_dataloader), # Fresh iterator from epoch 0, + device=self.device, + dp_mesh=dp_mesh, + ) + + with torch.no_grad(): + for batch in batch_iter: + # Check max_eval_steps limit + if ( + self.max_eval_steps is not None + and num_steps >= self.max_eval_steps + ): + logger.info( + f"[{dataset_name}] Reached max_eval_steps cap of {self.max_eval_steps}" + ) + break + + # Move tensors to device + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + batch[key] = value.to(self.device) + + # Process batch + labels = batch.pop("labels") + loss = self.forward_backward(batch, labels, skip_backward=True) + total_loss += loss + num_steps += 1 + + # Log progress (rank 0 only) + if num_steps % 50 == 0: + loss_val = loss.item() + logger.info( + f"[dataset {dataset_name}] Step {num_steps} | Loss: {loss_val:.4f}" + ) + + # Compute average loss + avg_loss = (total_loss / max(num_steps, 1)).item() + logger.info( + f"[dataset {dataset_name}] Final Step {num_steps} | Avg Loss: {avg_loss:.4f}" + ) + # Record metrics only on DP rank 0 to avoid double counting + # record_metric aggregates across all processes via monarch + should_record = True + if dp_mesh is not None: + dp_rank = torch.distributed.get_rank(group=dp_mesh) + should_record = dp_rank == 0 + + if should_record: + record_metric( + f"evaluate/dataset_{dataset_name}_loss", + avg_loss, + Reduce.MEAN, + ) + + # Restore train mode + for model_part in self.model_parts: + model_part.train() + + logger.info("==Evaluation complete==") + @endpoint async def train(self) -> None: dataloader = iter(self.train_dataloader) @@ -289,18 +442,28 @@ async def train(self) -> None: # self.profiler.step() self.current_step += 1 - # Flush metrics - if self._rank == 0: - logger.debug(f"Flushing metrics at step {self.current_step}") - await self.mlogger.flush.call_one(global_step=self.current_step) + # Run evaluation periodically if enabled + if ( + self.validation_enabled + and self.current_step % self.eval_every_n_steps == 0 + ): + await self.evaluate() self.checkpointer.save( curr_step=self.current_step, last_step=self.current_step == self.num_training_steps, ) + # Flush metrics + if self._rank == 0: + await self.mlogger.flush.call_one(global_step=self.current_step) + # self.pbar.close() + if self.validation_enabled: + logger.info("Running final evaluation at end of training...") + await self.evaluate() + @endpoint async def cleanup(self) -> None: if self.checkpointer: diff --git a/apps/sft/qwen3_8b.yaml b/apps/sft/qwen3_8b.yaml index 
f7c4999bb..fb39856ae 100644
--- a/apps/sft/qwen3_8b.yaml
+++ b/apps/sft/qwen3_8b.yaml
@@ -31,7 +31,16 @@ training:
   max_norm: 1.0
   steps: 1000
   compile: false
-  dataset: "c4"
+  datasets:
+    - path: "yahma/alpaca-cleaned"
+      split: "train[:95%]"
+
+eval:
+  eval_every_n_steps: 50 # null = disabled
+  max_eval_steps: null # null = run until epoch completes
+  datasets:
+    - path: "yahma/alpaca-cleaned"
+      split: "train[95%:]"
 
 parallelism:
   data_parallel_replicate_degree: 1
diff --git a/src/forge/data/datasets/hf_dataset.py b/src/forge/data/datasets/hf_dataset.py
index d7b36fe68..75b484607 100644
--- a/src/forge/data/datasets/hf_dataset.py
+++ b/src/forge/data/datasets/hf_dataset.py
@@ -70,6 +70,7 @@ def __init__(
         dataset_name: str | None = None,
         filter_fn: Callable | None = None,
         filter_kwargs: dict[str, Any] | None = None,
+        dp_mesh: dist.ProcessGroup | None = None,
         **load_dataset_kwargs,
     ):
         # Store configuration
@@ -79,6 +80,7 @@ def __init__(
         self._model_transform = model_transform
         self._output_transform = output_transform
         self._weight = weight if weight is not None else 1.0
+        self._dp_mesh = dp_mesh
 
         # Create default transform if not provided
         self._metric_transform = metric_transform or DefaultDatasetMetricTransform()
@@ -102,6 +104,10 @@ def __init__(
             self._metric_transform.set_source(dataset_name)
 
         # Internal state for resumption
+        # _start_epoch: The epoch to start from. Updated on resume from ckpt.
+        # Useful when doing iter(ds), which restarts the dataset from its original state.
+        self._start_epoch = 0
+        # _num_epochs: updated on every dataset exhaustion
         self._num_epochs = 0
 
         # Load and setup HF dataset
@@ -138,12 +144,25 @@ def _setup_hf_dataset(
         shuffle configuration, and filtering. Called once during __init__.
         """
-        # Distributed setup
+        # Extract rank/world_size from DP mesh
         world_size, rank = 1, 0
-        if dist.is_initialized():
+        if self._dp_mesh is not None:
+            world_size = dist.get_world_size(group=self._dp_mesh)
+            rank = dist.get_rank(group=self._dp_mesh)
+            logger.debug(
+                f"Using DP mesh for sharding: rank={rank}, world_size={world_size}"
+            )
+        elif dist.is_initialized():
+            # Fallback to global rank (may not respect TP/PP)
             world_size = dist.get_world_size()
             rank = dist.get_rank()
+            # TODO: is there a way to detect this and raise an error instead?
+            logger.warning(
+                f"Using global rank for sharding: rank={rank}, world_size={world_size}. "
+                f"If using other types of parallelism (CP/TP/PP), pass dp_mesh for correct sharding."
+            )
+
         # Load and shard dataset
         ds = load_dataset(**load_dataset_kwargs)
 
@@ -152,7 +171,6 @@ def _setup_hf_dataset(
         if is_streaming:
             logger.warning(
                 f"Streaming datasets were not yet tested for distributed training. "
-                f"split_dataset_by_node is applied, but no resharding was done manually. "
                 f"Dataset '{self.info.name}' has "
                 f"{getattr(ds, 'num_shards', 'unknown')} shards, and your training has {world_size} ranks."
                 f"See: https://huggingface.co/docs/datasets/en/package_reference/main_classes?#datasets.IterableDataset.shard"
             )
@@ -187,7 +205,7 @@ def _setup_hf_dataset(
             if num_shards > dataset_size:
                 raise ValueError(
                     f"Number of shards ({num_shards}) is greater than the dataset size ({dataset_size})."
-                    f"Please decrease one of {num_shards_per_rank=} or {num_dataloader_workers=} or {world_size=}."
+ f"Please decrease one of {num_shards_per_rank=} or dataloader.num_workers={num_dataloader_workers}" ) ds = ds.to_iterable_dataset(num_shards=num_shards) @@ -218,6 +236,7 @@ def __iter__(self) -> Iterator[dict[str, Any]]: - Adds 'num_epochs' metric to track dataset progress - Yields samples indefinitely for continuous training """ + self._num_epochs = self._start_epoch while True: # Infinite iteration self._ds.set_epoch(self._num_epochs) @@ -276,7 +295,7 @@ def state_dict(self) -> dict[str, Any]: return state def load_state_dict(self, state_dict: dict[str, Any]) -> None: - self._num_epochs = state_dict["num_epochs"] + self._start_epoch = state_dict["num_epochs"] hf_state = state_dict["hf_dataset_state"] # HF is responsible for resuming the dataset state diff --git a/src/forge/data/datasets/packed.py b/src/forge/data/datasets/packed.py index 93a21b85e..69c3aa3a5 100644 --- a/src/forge/data/datasets/packed.py +++ b/src/forge/data/datasets/packed.py @@ -343,9 +343,6 @@ def _reset_packer_state(self) -> None: # exhausted: whether the dataset is exhausted self._exhausted: bool = False - # resuming: whether the packer is resuming from a checkpoint - self._resuming: bool = False - def _fill_buffer(self, iterator: Iterator[SampleType]) -> None: """ Fills the buffer with samples from the dataset. @@ -452,15 +449,8 @@ def __iter__(self) -> Iterator[SampleDict]: if not isinstance(self.dataset, Iterable): raise TypeError("Dataset is not an iterable") - if not self._resuming: - self._reset_packer_state() - self._iterator = iter(self.dataset) - - # If resuming, the iterator must be recreated from the loaded state - if self._iterator is None: - self._iterator = iter(self.dataset) - - self._resuming = False # Consume the resume flag + self._reset_packer_state() + self._iterator = iter(self.dataset) # Main packing loop while True: @@ -502,7 +492,6 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: raise ValueError("Dataset is not stateful.") self._reset_packer_state() - self._resuming = True class TextPacker(Packer[SampleDict]): diff --git a/src/forge/data/datasets/sft_dataset.py b/src/forge/data/datasets/sft_dataset.py index 00278c1e5..3820ecc97 100644 --- a/src/forge/data/datasets/sft_dataset.py +++ b/src/forge/data/datasets/sft_dataset.py @@ -7,6 +7,7 @@ from typing import Any, Callable import torch +import torch.distributed as dist from forge.data import CROSS_ENTROPY_IGNORE_IDX from forge.data.metric_transform import DefaultDatasetMetricTransform @@ -162,6 +163,7 @@ def sft_iterable_dataset( dataset_name: str | None = None, filter_fn: Callable | None = None, filter_kwargs: dict[str, Any] | None = None, + dp_mesh: dist.ProcessGroup | None = None, **load_dataset_kwargs: dict[str, Any], ) -> HfIterableDataset: """ @@ -177,6 +179,7 @@ def sft_iterable_dataset( dataset_name (str | None): Name for metrics namespacing filter_fn (Callable | None): Filter function filter_kwargs (dict[str, Any] | None): Filter function kwargs + dp_mesh (dist.ProcessGroup | None): Data parallel process group for sharding (None for single process) **load_dataset_kwargs (dict[str, Any]): Args passed to load_dataset Returns: @@ -206,5 +209,6 @@ def sft_iterable_dataset( dataset_name=dataset_name, filter_fn=filter_fn, filter_kwargs=filter_kwargs, + dp_mesh=dp_mesh, **load_dataset_kwargs, ) diff --git a/src/forge/data/utils.py b/src/forge/data/utils.py index be8c13857..e335c23e4 100644 --- a/src/forge/data/utils.py +++ b/src/forge/data/utils.py @@ -4,13 +4,17 @@ # This source code is licensed under the BSD-style 
license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 from enum import Enum
-from typing import Any, Literal, Union
+from typing import Any, Iterator, Literal, Union
 
 import torch
+import torch.distributed as dist
 from torch.nn.attention.flex_attention import BlockMask
 
+logger = logging.getLogger(__name__)
+
 CROSS_ENTROPY_IGNORE_IDX = -100
 
 Role = Literal[
@@ -213,3 +217,118 @@ def batch_to_device(batch: dict, device: torch.device) -> None:
                 f"Tensor, or BlockMask with flexattention enabled. "
                 f'Got key "{k}" with value of type {type(v)}'
             )
+
+
+class StopAfterOneEpoch:
+    """Wraps an iterator (e.g. a dataloader) and stops iterating once any rank has completed an epoch.
+
+    In distributed eval, we may have len(dataset) % num_ranks != 0. This means that some ranks may be on epoch 0
+    while others are already in epoch 1. To avoid hangs, all ranks *must* stop at the same time, requiring communication.
+
+    This class minimizes the synchronization overhead by fetching one batch in advance and performing an overlapping async all_reduce.
+
+    Assumes batch contains field "metrics" with at least one Metric containing "num_epochs" in its key, as it is done in
+    `forge.src.data.datasets.HfIterableDataset`.
+
+    Args:
+        iter (Iterator): Iterator over dataloader batches
+        device (torch.device): Device for synchronizing tensors
+        dp_mesh (dist.ProcessGroup | None): Data parallel process group (None for single process)
+    """
+
+    def __init__(
+        self,
+        iter: Iterator,
+        device: torch.device,
+        dp_mesh: dist.ProcessGroup | None = None,
+    ):
+        self.iter = iter
+        self.device = device
+        self.dp_mesh = dp_mesh
+
+        # Prefetch first batch for pipeline-style execution
+        self._next_batch = next(iter)
+
+        # Track pending async epoch sync
+        self._epoch_tensor: torch.Tensor | None = None
+        self._pending_work: Any = None
+        self._should_stop = False
+
+    def __iter__(self):
+        return self
+
+    def __next__(self) -> dict:
+        """Get next batch from current epoch.
+
+        Returns:
+            Batch dict guaranteed to be from current epoch
+
+        Raises:
+            StopIteration: When epoch completes across all ranks
+        """
+        # Check if previous epoch sync completed
+        if self._pending_work is not None:
+            self._pending_work.wait()
+            if self._epoch_tensor.item() > 0:
+                self._should_stop = True
+            self._pending_work = None
+            self._epoch_tensor = None
+
+        if self._should_stop:
+            logger.debug("Eval epoch completed. Stopping data iterator.")
+            raise StopIteration
+
+        # Get current batch
+        current_batch = self._next_batch
+        current_epoch = extract_epoch_from_batch(current_batch)
+
+        # Prefetch next batch and check for epoch change
+        self._next_batch = next(self.iter)
+        next_epoch = extract_epoch_from_batch(self._next_batch)
+        epoch_changed = next_epoch > current_epoch
+
+        # Start async epoch sync
+        if torch.distributed.is_initialized():
+            self._epoch_tensor = torch.tensor([int(epoch_changed)], device=self.device)
+            self._pending_work = torch.distributed.all_reduce(
+                self._epoch_tensor,
+                op=torch.distributed.ReduceOp.MAX,
+                group=self.dp_mesh,
+                async_op=True,
+            )
+        elif epoch_changed:
+            # If not distributed, just update the flag directly
+            self._should_stop = True
+
+        return current_batch
+
+
+def extract_epoch_from_batch(batch: dict) -> int:
+    """Extract epoch number from batch metrics. Useful to detect epoch changes during validation.
+
+    Assumes batch contains field "metrics" with at least one Metric containing "num_epochs" in its key, as it is done in
+    `forge.src.data.datasets.HfIterableDataset`.
+ + Args: + batch (dict): Batch dictionary with 'metrics' field + + Returns: + int: Max epoch number from metrics + + Raises: + ValueError: If metrics key is missing or no metric with 'num_epochs' found + """ + if "metrics" not in batch: + raise ValueError( + "Batch missing 'metrics' field. Cannot extract epoch from batch." + ) + + # Match metrics where 'num_epochs' appears in the key (handles prefixed keys like 'dataset/name/num_epochs') + epochs = [metric.value for metric in batch["metrics"] if "num_epochs" in metric.key] + if epochs: + return max(epochs) + + raise ValueError( + f"No 'num_epochs' metric found in batch. Got metrics: " + f"{[m.key for m in batch['metrics']]}" + ) diff --git a/tests/unit_tests/datasets/test_hf.py b/tests/unit_tests/datasets/test_hf.py index 8298bf1a8..76aa79142 100644 --- a/tests/unit_tests/datasets/test_hf.py +++ b/tests/unit_tests/datasets/test_hf.py @@ -270,8 +270,76 @@ def test_epoch_tracking(self, dataset_factory, small_dataset_file): epoch_value == 1 for epoch_value in epoch_values ), f"Epoch values should be 1, got {epoch_values}" + def test_multiple_iter_calls_after_resume( + self, dataset_factory, small_dataset_file + ): + """Test that calling iter() multiple times after resuming restarts from checkpoint epoch. + + 1. Resume from checkpoint at epoch 2 + 2. Consume one epoch (now at epoch 3) + 3. Call iter(ds) again to create a new iterator + 4. The new iterator should restart from epoch 2 (checkpoint epoch), not 0 or 3 + + This ensures datasets can be re-iterated from their checkpoint state. + """ + dataset = dataset_factory(small_dataset_file, shuffle=False) + + # consume 2 epochs + it1 = iter(dataset) + samples = list(islice(it1, SMALL_DATASET_SIZE * 2)) + + # Save checkpoint after 2 epochs + state = dataset.state_dict() + + # Continue training for 1 more epoch on the same iterator + more_samples = list(islice(it1, SMALL_DATASET_SIZE)) + + # Create a new dataset instance and load the checkpoint + dataset2 = dataset_factory(small_dataset_file, shuffle=False) + dataset2.load_state_dict(state) + + # First iter() call should start from epoch 2 (the checkpoint epoch) + it2 = iter(dataset2) + first_iter_samples = list(islice(it2, SMALL_DATASET_SIZE)) + first_iter_epochs = [ + metric.value + for sample in first_iter_samples + for metric in sample["metrics"] + if "num_epochs" in metric.key + ] + assert all( + epoch == 2 for epoch in first_iter_epochs + ), f"First iter() should start at checkpoint epoch 2, got {set(first_iter_epochs)}" + + # Consume one more epoch from the same iterator (now at epoch 3) + second_epoch_samples = list(islice(it2, SMALL_DATASET_SIZE)) + second_epoch_epochs = [ + metric.value + for sample in second_epoch_samples + for metric in sample["metrics"] + if "num_epochs" in metric.key + ] + assert all( + epoch == 3 for epoch in second_epoch_epochs + ), f"Second epoch should be 3, got {set(second_epoch_epochs)}" + + # Call iter() again - it should restart from epoch 2, not continue from 4 + it3 = iter(dataset2) + new_iter_samples = list(islice(it3, SMALL_DATASET_SIZE)) + new_iter_epochs = [ + metric.value + for sample in new_iter_samples + for metric in sample["metrics"] + if "num_epochs" in metric.key + ] + assert all( + epoch == 2 for epoch in new_iter_epochs + ), f"New iter() should restart from checkpoint epoch 2, got {set(new_iter_epochs)}" + class TestDistributedHfIterableDataset(FSDPTest): + """Test HfIterableDataset with 2-GPU distributed setup.""" + @property def world_size(self) -> int: return 2 @@ -364,3 +432,124 @@ def 
create_loader(): finally: shutil.rmtree(temp_dir) + + +class TestDPShardingWithTP(FSDPTest): + """Test DP sharding with TP replication (4-GPU setup).""" + + @property + def world_size(self) -> int: + return 4 + + @gpu_test(gpu_count=4) + def test_dp_sharding_with_tp_replication(self): + """Verify DP sharding works correctly with TP/CP replication. + + This is a CRITICAL test that validates the core bug fix: + - Previously: Each rank got different batches (incorrect) + - Now: TP/CP ranks within same DP group get identical batches (correct) + + Setup: DP=2, TP=2 (4 GPUs total) + - DP group 0: ranks [0, 1] - should see SAME batches (TP replication) + - DP group 1: ranks [2, 3] - should see SAME batches (TP replication) + - DP group 0 vs 1: should see DIFFERENT batches (DP sharding) + + Mesh structure: + - TP rank 0 DP replicas: [0, 2] - shard across these + - TP rank 1 DP replicas: [1, 3] - shard across these + """ + import hashlib + + rank = dist.get_rank() + world_size = dist.get_world_size() + temp_dir = tempfile.mkdtemp(prefix=f"dp_tp_test_rank{rank}_") + + try: + data_file = Path(temp_dir) / "data.json" + # Create dataset with enough samples for clear sharding + # 40 samples / 2 DP groups = 20 samples per DP group + create_test_json_file(data_file, MEDIUM_DATASET_SIZE, offset=0) + + # Create DP mesh for sharding + # Key insight: Create groups across DP replicas for each TP rank + # TP rank = rank % 2, so: + # - TP rank 0: ranks [0, 2] (one from each DP group) + # - TP rank 1: ranks [1, 3] (one from each DP group) + tp_rank = rank % 2 + tp_world_size = 2 + dp_world_size = world_size // tp_world_size + + # Create DP groups for each TP rank + dp_groups = [] + for tp_r in range(tp_world_size): + # Ranks for this TP rank across DP groups + ranks = [tp_r + i * tp_world_size for i in range(dp_world_size)] + group = dist.new_group(ranks=ranks) + dp_groups.append(group) + + dp_mesh = dp_groups[tp_rank] + + # - Rank 0 (tp_rank=0) uses group [0, 2], gets rank=0 → shard 0 + # - Rank 1 (tp_rank=1) uses group [1, 3], gets rank=0 → shard 0 + # - Rank 2 (tp_rank=0) uses group [0, 2], gets rank=1 → shard 1 + # - Rank 3 (tp_rank=1) uses group [1, 3], gets rank=1 → shard 1 + + dataset = HfIterableDataset( + path="json", + data_files=str(data_file), + split="train", + dataset_name="dp_tp_test", + shuffle_buffer_size=0, + metric_transform=DefaultDatasetMetricTransform(), + num_shards_per_rank=2, + dp_mesh=dp_mesh, # CRITICAL: Pass dp_mesh for correct sharding + ) + + dataloader = StatefulDataLoader( + dataset, + batch_size=BATCH_SIZE, + collate_fn=collate_with_metrics, + num_workers=0, + ) + + # Collect batches and compute hashes + batches = list(islice(iter(dataloader), 5)) + batch_hashes = [] + for batch in batches: + # Hash the batch IDs to verify identity/difference + batch_ids = batch["id"].cpu().tolist() + batch_hash = hashlib.md5(str(batch_ids).encode()).hexdigest() + batch_hashes.append(batch_hash) + + # Gather hashes from all ranks for comparison + gathered_hashes = [None] * world_size + dist.all_gather_object(gathered_hashes, batch_hashes) + + if rank == 0: + # Verify TP replication within DP groups + # Ranks 0 and 1 should have identical hashes (same DP group) + assert gathered_hashes[0] == gathered_hashes[1], ( + f"Ranks 0 and 1 (same DP group) should see identical batches!\n" + f"Rank 0 hashes: {gathered_hashes[0][:3]}...\n" + f"Rank 1 hashes: {gathered_hashes[1][:3]}..." 
+ ) + + # Ranks 2 and 3 should have identical hashes (same DP group) + assert gathered_hashes[2] == gathered_hashes[3], ( + f"Ranks 2 and 3 (same DP group) should see identical batches!\n" + f"Rank 2 hashes: {gathered_hashes[2][:3]}...\n" + f"Rank 3 hashes: {gathered_hashes[3][:3]}..." + ) + + # Verify DP sharding across groups + # Ranks 0/1 should see DIFFERENT batches from ranks 2/3 + assert gathered_hashes[0] != gathered_hashes[2], ( + f"Ranks 0 and 2 (different DP groups) should see different batches!\n" + f"DP group 0 hashes: {gathered_hashes[0][:3]}...\n" + f"DP group 1 hashes: {gathered_hashes[2][:3]}..." + ) + + dist.barrier() + + finally: + shutil.rmtree(temp_dir) diff --git a/tests/unit_tests/datasets/test_packed.py b/tests/unit_tests/datasets/test_packed.py index 56cd5ff02..1c6c4906f 100644 --- a/tests/unit_tests/datasets/test_packed.py +++ b/tests/unit_tests/datasets/test_packed.py @@ -949,3 +949,49 @@ def create_loader(): # Verify that checkpointing and resumption work assert len(result["post_checkpoint_batches"]) == steps_after_checkpoint assert len(result["resumed_batches"]) == steps_after_checkpoint + + def test_iter_restart_determinism(self, dataset_factory): + """Test that calling iter() multiple times produces deterministic results. + + This is critical for evaluation: each eval run should start from the + same state (epoch 0, step 0) regardless of previous iterations. + """ + samples = [ + {"tokens": [0] * 3}, + {"tokens": [1] * 2}, + {"tokens": [2] * 4}, + ] + target_tokens_per_pack = 6 + + # Create packed dataset + dataset = dataset_factory(samples) + packer = TextPacker(padding_idx=999, ignore_idx=-100) + packed_dataset = PackedDataset( + dataset=dataset, + packer=packer, + target_tokens_per_pack=target_tokens_per_pack, + buffer_size=1, + ) + + # First iteration - get first 2 packs + iter1 = iter(packed_dataset) + packs_iter1 = list(islice(iter1, 2)) + + # Second iteration - should get same first 2 packs + iter2 = iter(packed_dataset) + packs_iter2 = list(islice(iter2, 2)) + + # Verify both iterations produce identical packs + assert len(packs_iter1) == len(packs_iter2) == 2 + + for i, (pack1, pack2) in enumerate(zip(packs_iter1, packs_iter2)): + torch.testing.assert_close( + pack1["tokens"], + pack2["tokens"], + msg=f"Pack {i}: tokens mismatch between iterations", + ) + torch.testing.assert_close( + pack1["document_ids"], + pack2["document_ids"], + msg=f"Pack {i}: document_ids mismatch between iterations", + ) diff --git a/tests/unit_tests/datasets/test_stop_after_one_epoch.py b/tests/unit_tests/datasets/test_stop_after_one_epoch.py new file mode 100644 index 000000000..d0deaf86d --- /dev/null +++ b/tests/unit_tests/datasets/test_stop_after_one_epoch.py @@ -0,0 +1,184 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Tests for StopAfterOneEpoch iterator and extract_epoch_from_batch helper.""" +from pathlib import Path + +import pytest +import torch +import torch.distributed as dist +from forge.data.datasets import HfIterableDataset + +from forge.data.utils import extract_epoch_from_batch, StopAfterOneEpoch +from forge.observability.metrics import Metric, Reduce +from torch.testing._internal.common_fsdp import FSDPTest +from torchdata.stateful_dataloader import StatefulDataLoader + +from tests.test_utils import gpu_test + + +def create_test_json_file(path: Path, num_samples: int) -> None: + """Create test data file with simple samples.""" + with open(path, "w") as f: + for i in range(num_samples): + f.write(f'{{"id": {i}, "tokens": [{i}, {i+1}]}}\n') + + +def simple_collate(batch): + """Simple collate function that mimics collate_packed behavior. + + Stacks tensors, extends metrics list, keeps other fields as lists. + """ + collated = {} + for key in batch[0].keys(): + if isinstance(batch[0][key], torch.Tensor): + collated[key] = torch.stack([sample[key] for sample in batch], dim=0) + elif key == "metrics": + # Extend all metrics into a single list + collated[key] = [] + for sample in batch: + collated[key].extend(sample[key]) + else: + collated[key] = [sample[key] for sample in batch] + return collated + + +class TestExtractEpochFromBatch: + """Test extract_epoch_from_batch helper function.""" + + def test_extract_epoch_from_batch_success(self): + """Test extracting epoch from valid batch with metrics.""" + batch = { + "tokens": torch.tensor([1, 2, 3]), + "metrics": [ + Metric(key="dataset/test/num_epochs", value=2, reduction=Reduce.MAX), + Metric( + key="dataset/test/other_metric", value=42, reduction=Reduce.MEAN + ), + ], + } + epoch = extract_epoch_from_batch(batch) + assert epoch == 2 + + def test_extract_epoch_missing_metrics_field(self): + """Test error when batch has no 'metrics' field.""" + batch = {"tokens": torch.tensor([1, 2, 3])} + with pytest.raises(ValueError, match="Batch missing 'metrics' field"): + extract_epoch_from_batch(batch) + + def test_extract_epoch_no_num_epochs_metric(self): + """Test error when no num_epochs metric found.""" + batch = { + "metrics": [ + Metric( + key="dataset/test/other_metric", value=42, reduction=Reduce.MEAN + ), + ] + } + with pytest.raises(ValueError, match="No 'num_epochs' metric found"): + extract_epoch_from_batch(batch) + + +class TestStopAfterOneEpochSingleProcess: + """Test StopAfterOneEpoch in single-process mode (no distributed).""" + + def test_stop_after_one_epoch(self, tmp_path): + """Verify iterator stops after exactly one epoch completes.""" + # Create small dataset (10 samples) + data_file = tmp_path / "data.json" + create_test_json_file(data_file, num_samples=10) + + dataset = HfIterableDataset( + path="json", + data_files=str(data_file), + split="train", + shuffle_buffer_size=0, + num_shards_per_rank=1, + ) + + dataloader = StatefulDataLoader( + dataset, batch_size=2, collate_fn=simple_collate + ) + + # Wrap with StopAfterOneEpoch + batch_iter = StopAfterOneEpoch( + iter=iter(dataloader), + device=torch.device("cpu"), + dp_mesh=None, + ) + + # Collect all batches until StopIteration + batches = [] + for batch in batch_iter: + batches.append(batch) + # Verify all batches are from epoch 0 + epoch = extract_epoch_from_batch(batch) + assert epoch == 0, f"Expected epoch 0, got {epoch}" + + # Should have consumed exactly one epoch (5 batches of size 2) + assert len(batches) == 5 + + +class TestStopAfterOneEpochDistributed(FSDPTest): + 
"""Test StopAfterOneEpoch with distributed synchronization.""" + + @property + def world_size(self) -> int: + return 2 + + @gpu_test(gpu_count=2) + def test_epoch_sync_across_ranks(self): + """Verify all ranks stop when any rank detects epoch change.""" + import shutil + import tempfile + + rank = dist.get_rank() + temp_dir = tempfile.mkdtemp(prefix=f"stop_epoch_test_rank{rank}_") + + try: + data_file = Path(temp_dir) / "data.json" + # Create dataset with 20 samples, split across 2 ranks (10 each) + create_test_json_file(data_file, num_samples=20) + + dataset = HfIterableDataset( + path="json", + data_files=str(data_file), + split="train", + shuffle_buffer_size=0, + num_shards_per_rank=1, + ) + + dataloader = StatefulDataLoader( + dataset, batch_size=2, collate_fn=simple_collate + ) + + # Get DP process group (use global group for this test) + dp_mesh = dist.group.WORLD + + batch_iter = StopAfterOneEpoch( + iter=iter(dataloader), + device=torch.device("cuda"), + dp_mesh=dp_mesh, + ) + + # Collect batches + batches = [] + for batch in batch_iter: + batches.append(batch) + # All should be epoch 0 + assert extract_epoch_from_batch(batch) == 0 + + # All ranks should have processed exactly one epoch + # Since dataset is split across ranks, each rank gets 10 samples = 5 batches + assert ( + len(batches) == 5 + ), f"Rank {rank} expected 5 batches, got {len(batches)}" + + # Synchronize to ensure both ranks completed + dist.barrier() + + finally: + shutil.rmtree(temp_dir)