Conversation

@felipemello1 felipemello1 commented Nov 6, 2025

Continuation of the work by @HosseinKaviani-H in #404

TL;DR:

  • Adds eval loop to SFT
  • Adds a non-blocking equivalent of drop_last and stops eval after 1 epoch.
  • Adds config for multiple eval datasets and multiple train datasets (train does not support >1 dataset yet, but it will in a different PR)
  • [FIX] Fixes and adds new unit tests covering blind spots in our datasets when using TP/CP
  • [FIX] Enables the dataset to reset via iter(dataset)

Context:

In forge we use infinite iterable datasets (not map-style). There are advantages to this, such as:
a) Streaming / not holding the entire dataset in memory;
b) Easy manipulation of data, e.g. while True: do_something(next(iter));
c) The dataset can be generated on the fly, e.g. a replay buffer.

However, there are also challenges:
a) We do not know the size of the dataset in advance;
b) We don't know epoch boundaries: the dataset resets automatically on exhaustion, so we don't have to deal with potential hangs when different ranks stop getting samples at different times (sketched below).
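
A minimal sketch of this auto-reset behavior (illustrative only, not the actual HfIterableDataset):

from torch.utils.data import IterableDataset

class InfiniteDataset(IterableDataset):
    """Wraps a finite list and restarts it on exhaustion, so consumers never
    see StopIteration and therefore never see an explicit epoch boundary."""

    def __init__(self, samples: list[dict]):
        self.samples = samples
        self.num_epochs = 0

    def __iter__(self):
        while True:  # never exhausts
            for sample in self.samples:
                yield sample
            self.num_epochs += 1  # silently rolls into the next epoch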

Original problem:

For validation, we want to run only 1 epoch. With map-style datasets this is easy: i) the loop ends by itself after one pass over the dataset; ii) setting DataLoader(drop_last=True) avoids hangs (as sketched below).
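
For contrast, a minimal map-style sketch of that easy case (plain PyTorch, not forge code):

import torch
from torch.utils.data import DataLoader, TensorDataset

# A tiny map-style dataset: the length is known up front.
ds = TensorDataset(torch.arange(10).unsqueeze(1))

# drop_last=True guarantees every rank iterates the same number of full batches,
# and the for-loop ends by itself after exactly one epoch.
loader = DataLoader(ds, batch_size=4, drop_last=True)

for (batch,) in loader:   # 2 full batches; the partial 3rd batch is dropped
    print(batch.shape)    # torch.Size([4, 1])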

As discussed above, this is not possible with infinite iterable datasets.

To identify epoch boundaries, our dataset implementation returns a Metric named num_epochs. We can use it to easily check whether we started a new epoch and stop there:

for batch in dataloader:
    if any(metric.value > 0 for metric in batch["metrics"] if metric.key == "num_epochs"):
        break

However, in a distributed setting, we may have len(dataset) % num_ranks != 0. This means that some ranks may be on epoch 0 while others are already in epoch 1.

To avoid hangs, all ranks must stop at the same time. This means we need some sort of all_reduce to know whether at least one rank has reached epoch == 1, which introduces communication overhead and blocks the forward pass.
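
For illustration, a naive blocking version of that check could look like the sketch below (not code from this PR; it assumes an initialized process group and a CUDA device):

import torch
import torch.distributed as dist

def any_rank_started_new_epoch(batch) -> bool:
    """Blocking version: every rank must sit in the collective before it can
    run its next forward pass."""
    flag = torch.tensor(
        [any(m.value > 0 for m in batch["metrics"] if m.key == "num_epochs")],
        dtype=torch.int64,
        device="cuda",
    )
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)  # blocks until all ranks arrive
    return bool(flag.item())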

Solution:

This PR implements StopAfterOneEpoch(dataloader), which fetches one batch in advance and runs the all_reduce asynchronously, overlapping communication with compute. The utility abstracts this away from the user.
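
A minimal sketch of the idea (illustrative only, not the exact implementation in this PR); it assumes an initialized process group, a CUDA device, and the num_epochs metric described above:

import torch
import torch.distributed as dist

class StopAfterOneEpoch:
    """Yield batches until any rank crosses an epoch boundary. The epoch flag of
    the next batch is all-reduced asynchronously while the caller processes the
    current batch, so the collective overlaps with compute."""

    def __init__(self, dataloader, device: str = "cuda"):
        self.iterator = iter(dataloader)
        self.device = device

    def _epoch_flag(self, batch) -> torch.Tensor:
        # 1 if this rank already started a new epoch, else 0.
        new_epoch = any(m.value > 0 for m in batch["metrics"] if m.key == "num_epochs")
        return torch.tensor([int(new_epoch)], device=self.device)

    def __iter__(self):
        batch = next(self.iterator)
        flag = self._epoch_flag(batch)
        work = dist.all_reduce(flag, op=dist.ReduceOp.MAX, async_op=True)

        while True:
            next_batch = next(self.iterator)  # prefetch while the reduce is in flight
            work.wait()
            if flag.item() > 0:  # some rank finished its epoch -> everyone stops here
                return
            # Launch the reduce for the next batch before yielding the current one,
            # so it overlaps with the caller's forward pass.
            next_flag = self._epoch_flag(next_batch)
            next_work = dist.all_reduce(next_flag, op=dist.ReduceOp.MAX, async_op=True)
            yield batch
            batch, flag, work = next_batch, next_flag, next_work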

Issues found/fixed during this implementation:

  • HfIterableDataset sharded the data across all ranks, not only DP ranks. This meant that TP ranks were getting different batches instead of the same one.
  • The datasets had to be reset after every eval. Some changes were needed so that calling iter(dataset) provides a fresh iterator with the original state. This is much faster than creating a new dataset on every eval loop.

Hossein Kavianihamedani and others added 5 commits October 27, 2025 11:51
- Add eval_utils.py with run_evaluation() function for multi-dataset evaluation
- Update main.py to support multi-dataset configuration and evaluation
  - Add validation config settings (enabled, eval_interval, eval_steps)
  - Refactor setup() to support dataset_val.datasets structure
  - Add unified forward() method with compute_gradients flag
  - Add evaluate() method that calls run_evaluation()
- Update llama3_8b.yaml with multi-dataset configuration
- Fix extract_epoch_from_batch() to use 'key' attribute instead of 'metric_name'
- Simplify epoch tracking: compare consecutive batches instead of tracking from start
- Remove starting_epoch variable - no longer needed
- Update start_epoch_sync() to use boolean epoch_changed instead of epoch_increment
- Add better logging for epoch changes and tracking status
- Epoch sync now works correctly with the actual metric structure
@meta-cla bot added the "CLA Signed" label (managed by the Meta Open Source bot) Nov 6, 2025
@felipemello1 felipemello1 marked this pull request as draft November 6, 2025 21:43
@felipemello1 felipemello1 marked this pull request as ready for review November 7, 2025 19:16
@felipemello1 felipemello1 changed the title from "[WIP][SFT Eval ] Add eval to SFT script" to "[SFT Eval ] Add eval to SFT script" Nov 7, 2025
@felipemello1 felipemello1 changed the title from "[SFT Eval ] Add eval to SFT script" to "[wip][SFT Eval ] Add eval to SFT script" Nov 7, 2025
dataset_name: str | None = None,
filter_fn: Callable | None = None,
filter_kwargs: dict[str, Any] | None = None,
dp_mesh: dist.ProcessGroup | None = None,
Contributor Author (felipemello1) commented:

[Explaining changes]

With iterable datasets, we shard the dataset and split it across ranks. This is done by calling:

ds = split_dataset_by_node(ds, rank=rank, world_size=world_size)

Before: I was using the global rank

Problem: when using TP/CP/PP etc., each rank would get a different data point, which is wrong. We need to split per DP rank: all ranks within the same DP group should get the same data point.

Solution: pass dp_mesh
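
A sketch of the fix (simplified; names taken from the diff above):

import torch.distributed as dist
from datasets.distributed import split_dataset_by_node

# Shard by data-parallel rank, not global rank, so all TP/CP ranks that belong
# to the same DP group receive identical batches.
if dp_mesh is not None:
    rank = dist.get_rank(group=dp_mesh)
    world_size = dist.get_world_size(group=dp_mesh)
else:
    rank = dist.get_rank()
    world_size = dist.get_world_size()

ds = split_dataset_by_node(ds, rank=rank, world_size=world_size)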

# Internal state for resumption
# _start_epoch: The epoch to start from. Updated on resume from ckpt.
# useful when doing iter(ds), which restarts dataset from original state.
self._start_epoch = 0
Contributor Author (felipemello1) commented:

[Explaining changes]

We want to restart the dataset on every eval. We can do this by calling iter(ds), which is cheap.

In def __iter__ we then do:

self._num_epochs = self._start_epoch
self._ds.set_epoch(self._num_epochs)  # used to preserve the shuffle order

In other words: we need self._start_epoch so we know where to reset to. You may ask "why not just always set it to 0?" Because for training, we may have resumed from a checkpoint and want iter(ds) to start from elsewhere. That's why we do:

def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    self._start_epoch = state_dict["num_epochs"]
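
A simplified sketch of how these pieces fit together (not the full class; names abbreviated for illustration):

class HfIterableDatasetSketch:
    def __init__(self, hf_ds):
        self._ds = hf_ds          # underlying HuggingFace streaming dataset
        self._start_epoch = 0     # only changes when resuming from a checkpoint
        self._num_epochs = 0

    def load_state_dict(self, state_dict: dict) -> None:
        # After resuming, iter(ds) should restart from the checkpointed epoch, not 0.
        self._start_epoch = state_dict["num_epochs"]

    def __iter__(self):
        # Every call to iter(ds) rewinds to the resume point.
        self._num_epochs = self._start_epoch
        self._ds.set_epoch(self._num_epochs)  # preserves the shuffle order
        for sample in self._ds:
            yield sample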

Comment on lines +452 to +453
self._reset_packer_state()
self._iterator = iter(self.dataset)
Contributor Author (felipemello1) commented:

[Explaining changes]

Similar to the hf_dataset.py changes: when we call iter(ds), we always want it to restart from its original state. There is no need for the extra checks that were here.

@HosseinKaviani-H (Contributor) left a comment:

This looks great to me. Thanks for the PR. Just had a few minor comments/questions.

P.S. Do you have test results for TP/CP, where ranks were getting different samples?

self.max_eval_steps = (
max_eval_steps if max_eval_steps and max_eval_steps > 0 else None
)
self.validation_enabled = (
Contributor commented:

For eval_every_n_steps, is there a check to break when we exhaust the steps? If we don't have the epoch metric, shouldn't this be the metric to break the eval loop?

Contributor Author (felipemello1) replied:

we do:


for batch in StopAfterOneEpoch(val_dataloader):
  # Check max_eval_steps limit
  if (
      self.max_eval_steps is not None
      and num_steps >= self.max_eval_steps
  ):
      break

So it's whichever comes first: one epoch or self.max_eval_steps.

Regarding what happens if there is no "num_epochs" metric: this would only happen if the user replaces our dataset class with a new one. That is completely possible, but they can easily add the "num_epochs" metric if they have that level of expertise, or delete "StopAfterOneEpoch" from main.py.

Worst case, we can add checks if someone complains.

I wanted to avoid adding complexity with more if/else here.

wdyt?


dataloader = StatefulDataLoader(
dataset=dataset,
batch_size=self.job_config.training.local_batch_size,
Contributor commented:

Why not use drop_last=True here? If the dataset size is not divisible by batch_size * world_size, some ranks will have fewer batches, which could lead to a potential deadlock. Thoughts?

Contributor Author (felipemello1) replied:

This is true for map-style datasets, but for iterable datasets it is a no-op, so I thought it was misleading to have it there, i.e. "why do we need the StopAfterOneEpoch utility if we already have drop_last=True?"

One can make the argument: "what if the user implements their own dataset class as map-style?" Our PackedDataset and InterleavedDataset would still be iterable datasets, so the input to the dataloader would always be an iterable.

Let me know if that makes sense.
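
A tiny illustration of that point (plain PyTorch, not forge code): with an infinite IterableDataset there is never a final partial batch for drop_last to drop, so the flag cannot end the loop.

import itertools
from torch.utils.data import DataLoader, IterableDataset

class Infinite(IterableDataset):
    def __iter__(self):
        return itertools.count()  # never exhausts

# drop_last=True is accepted but never triggers: there is no "last" batch.
loader = DataLoader(Infinite(), batch_size=4, drop_last=True)

for i, batch in enumerate(loader):
    if i == 3:
        break  # without an explicit break, this loop would run forever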

Comment on lines +403 to +406
should_record = True
if dp_mesh is not None:
dp_rank = torch.distributed.get_rank(group=dp_mesh)
should_record = dp_rank == 0
Contributor Author (felipemello1) commented:

The logic here is wrong: we should record for every DP rank and add checks for the other parallelism types. Will fix it on Monday.
