[wip][SFT Eval ] Add eval to SFT script #536
Conversation
- Add eval_utils.py with run_evaluation() function for multi-dataset evaluation
- Update main.py to support multi-dataset configuration and evaluation
- Add validation config settings (enabled, eval_interval, eval_steps)
- Refactor setup() to support dataset_val.datasets structure
- Add unified forward() method with compute_gradients flag
- Add evaluate() method that calls run_evaluation()
- Update llama3_8b.yaml with multi-dataset configuration
- Fix extract_epoch_from_batch() to use 'key' attribute instead of 'metric_name'
- Simplify epoch tracking: compare consecutive batches instead of tracking from start
- Remove starting_epoch variable - no longer needed
- Update start_epoch_sync() to use boolean epoch_changed instead of epoch_increment
- Add better logging for epoch changes and tracking status
- Epoch sync now works correctly with the actual metric structure
dataset_name: str | None = None,
filter_fn: Callable | None = None,
filter_kwargs: dict[str, Any] | None = None,
dp_mesh: dist.ProcessGroup | None = None,
[Explaining changes]
With an iterable dataset, we shard the data and split it across ranks. This is done by calling:
ds = split_dataset_by_node(ds, rank=rank, world_size=world_size)
Before: I was using the global rank
Problem: if using TP/CP/PP etc., each rank would get a different data point. This is wrong: we need to split per DP rank, so that all ranks within the same DP group get the same data point.
Solution: pass dp_mesh
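For reference, a minimal sketch of this sharding logic, assuming dp_mesh is the dist.ProcessGroup from the signature above (the helper name is illustrative):

```python
import torch.distributed as dist
from datasets.distributed import split_dataset_by_node


def shard_for_data_parallel(ds, dp_mesh: dist.ProcessGroup | None = None):
    """Hypothetical helper: shard by DP rank when a dp_mesh is given,
    otherwise fall back to the global rank (the old behavior)."""
    if dp_mesh is not None:
        rank = dist.get_rank(group=dp_mesh)             # rank within the DP group
        world_size = dist.get_world_size(group=dp_mesh)
    else:
        rank = dist.get_rank()                          # global rank
        world_size = dist.get_world_size()
    # All ranks in the same DP group share (rank, world_size), so TP/CP/PP
    # replicas of a given DP rank now see identical batches.
    return split_dataset_by_node(ds, rank=rank, world_size=world_size)
```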
# Internal state for resumption
# _start_epoch: The epoch to start from. Updated on resume from ckpt.
# useful when doing iter(ds), which restarts dataset from original state.
self._start_epoch = 0
[Explaining changes]
We want to restart the dataset on every eval. We can do this by calling iter(ds), which is cheap.
In def __iter__ we then do:
self._num_epochs = self._start_epoch
self._ds.set_epoch(self._num_epochs)  # used to preserve shuffle order
In other words, we need self._start_epoch so we know where to reset to. You may ask "why not just always set it to 0?". Because for training we may have resumed from a checkpoint and want iter(ds) to start from elsewhere. That's why we do:
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    self._start_epoch = state_dict["num_epochs"]
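A minimal sketch of how _start_epoch, __iter__ and load_state_dict fit together; only the attribute and method names are taken from the snippets above, the rest is illustrative:

```python
from typing import Any


class HfIterableDatasetSketch:
    """Illustrative only: shows the _start_epoch / iter(ds) interplay."""

    def __init__(self, ds):
        self._ds = ds              # underlying (shuffled, sharded) HF dataset
        self._start_epoch = 0      # updated on resume from checkpoint
        self._num_epochs = 0

    def __iter__(self):
        # iter(ds) always rewinds to _start_epoch: 0 for a fresh run (or eval),
        # or the checkpointed epoch after load_state_dict().
        self._num_epochs = self._start_epoch
        self._ds.set_epoch(self._num_epochs)     # preserves shuffle order
        while True:                               # infinite: wrap around on exhaustion
            yield from self._ds
            self._num_epochs += 1
            self._ds.set_epoch(self._num_epochs)

    def state_dict(self) -> dict[str, Any]:
        return {"num_epochs": self._num_epochs}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self._start_epoch = state_dict["num_epochs"]
```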
self._reset_packer_state()
self._iterator = iter(self.dataset)
[Explaining changes]
Similar to the hf_dataset.py changes: when we call iter(ds), we always want it to restart from its original state, so the extra checks that were here are no longer needed.
This looks great to me. Thanks for the PR. Just had a few minor comments/questions.
P.S. Do you have test results on TP/CP where ranks get different samples?
self.max_eval_steps = (
    max_eval_steps if max_eval_steps and max_eval_steps > 0 else None
)
self.validation_enabled = (
For eval_every_n_steps, is there a check to break when we exhaust the steps? If we don't have the epoch metric, shouldn't this be the condition that breaks the eval loop?
we do:
for batch in StopAfterOneEpoch(val_dataloader):
    # Check max_eval_steps limit
    if (
        self.max_eval_steps is not None
        and num_steps >= self.max_eval_steps
    ):
        break
So it's whichever comes first: one epoch or self.max_eval_steps.
Regarding what happens if there is no "num_epochs" metric: this would only occur if the user replaces our dataset class with a new one. That is entirely possible, but someone with that level of expertise can easily add the "num_epochs" metric, or delete StopAfterOneEpoch from main.py.
Worst case, we can add checks if someone complains.
I wanted to avoid the complexity of adding more if/else here.
wdyt?
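For context, a hedged sketch of the surrounding eval loop; StopAfterOneEpoch, val_dataloader and max_eval_steps come from the excerpt above, forward(compute_gradients=False) from the commit messages, and the loss bookkeeping is illustrative:

```python
num_steps = 0
total_loss = 0.0
for batch in StopAfterOneEpoch(val_dataloader):
    # Stop at whichever comes first: one epoch (handled by the wrapper)
    # or self.max_eval_steps.
    if self.max_eval_steps is not None and num_steps >= self.max_eval_steps:
        break
    loss = self.forward(batch, compute_gradients=False)  # eval-only forward pass
    total_loss += loss.item()
    num_steps += 1
avg_loss = total_loss / max(num_steps, 1)
```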
dataloader = StatefulDataLoader(
    dataset=dataset,
    batch_size=self.job_config.training.local_batch_size,
Why not use drop_last=True here? If the dataset size is not divisible by batch_size * world_size, some ranks will have fewer batches, which could lead to a potential deadlock. Thoughts?
This is true for map-style datasets, but for our (infinite) iterable datasets it's a no-op, so I thought it was misleading to have it there, i.e. "why do we need the StopAfterOneEpoch utility if we already have drop_last=True?"
One can make the argument: "what if the user implements their own dataset class as map style?". Our PackedDataset and InterleavedDataset would still be iterable datasets, so the input to the dataloader would always be an iterable.
Let me know if that makes sense.
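A toy repro of that point, assuming the infinite-iterable setup described in the PR description (the class name is made up):

```python
from torch.utils.data import DataLoader, IterableDataset


class InfiniteCounter(IterableDataset):
    """Toy stand-in for an infinite iterable dataset: it never exhausts,
    so there is never a smaller "last" batch for drop_last to drop."""

    def __iter__(self):
        i = 0
        while True:
            yield i
            i += 1


loader = DataLoader(InfiniteCounter(), batch_size=4, drop_last=True)
print(next(iter(loader)))  # tensor([0, 1, 2, 3]); every batch is full, so drop_last never fires
```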
should_record = True
if dp_mesh is not None:
    dp_rank = torch.distributed.get_rank(group=dp_mesh)
    should_record = dp_rank == 0
The logic here is wrong: we should record on every DP rank, with the checks done on the other parallelism types instead. Will do it on Monday.
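For the record, a rough sketch of what that fix might look like, assuming the full DeviceMesh (not just the DP process group) is available and that the mesh dim names include "dp"; this is speculative, not the follow-up change itself:

```python
from torch.distributed.device_mesh import DeviceMesh


def should_record_metrics(mesh: DeviceMesh | None) -> bool:
    # Record on every DP rank, but only once per DP group: i.e. only on the
    # rank that sits at index 0 along every non-DP dim (TP/CP/PP, ...).
    if mesh is None:
        return True
    non_dp_dims = [d for d in (mesh.mesh_dim_names or ()) if d != "dp"]
    return all(mesh.get_local_rank(mesh_dim=d) == 0 for d in non_dp_dims)
```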
Continuation of the work by @HosseinKaviani-H in #404
TL;DR: implements a StopAfterOneEpoch utility that replaces drop_last and stops eval after 1 epoch.
Context:
In forge we use infinite iterable datasets (not map-style). There are advantages to this, such as:
a) Streaming / not holding the entire dataset in memory;
b) Easy manipulation of data, e.g. while True: do_something(next(iter));
c) The dataset can be generated on the fly, e.g. a replay buffer.
However, there are also challenges:
a) We do not know the size of the dataset in advance;
b) We don't know epoch boundaries (the dataset resets automatically on exhaustion. This is done so that we don't have to deal with potential hangs from different ranks not getting enough samples when the dataset is exhausted)
Original problem:
For validation, we want to run only 1 epoch. With map-style datasets, this is easy: i) break after one full pass over the dataset; ii) set drop_last=True in the dataloader to avoid hangs.
As discussed above, this is not possible with infinite iterable datasets.
To identify epoch boundaries, our dataset implementation returns a Metric num_epochs. We can use it to easily verify whether we started a new epoch, and stop there.
However, in a distributed setting, we may have len(dataset) % num_ranks != 0. This means that some ranks may be on epoch 0 while others are already on epoch 1. To avoid hangs, all ranks must stop at the same time, so we need some sort of all_reduce to know if at least one rank has seen epoch == 1, introducing communication overhead and blocking the forward pass.
Solution:
This PR implements StopAfterOneEpoch(dataloader), which fetches one batch in advance and performs the all_reduce asynchronously, overlapping the communication. The utility abstracts this away from the user.
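A hedged sketch of the idea (not the PR's exact implementation; the dp_group and extract_epoch arguments are assumptions):

```python
import torch
import torch.distributed as dist


class StopAfterOneEpochSketch:
    """Wrap an infinite dataloader, prefetch one batch, and overlap an async
    all_reduce of the epoch counter so that every rank stops at the same step."""

    def __init__(self, dataloader, dp_group=None, extract_epoch=None):
        self.dataloader = dataloader
        self.dp_group = dp_group
        # extract_epoch(batch) -> int is assumed to read the "num_epochs" metric.
        self.extract_epoch = extract_epoch or (lambda b: b["num_epochs"])

    def __iter__(self):
        it = iter(self.dataloader)
        batch = next(it)                      # prefetch one batch
        start_epoch = self.extract_epoch(batch)
        while True:
            # Launch the epoch check for the prefetched batch asynchronously
            # (with NCCL this tensor would need to live on the GPU).
            epoch = torch.tensor([self.extract_epoch(batch)])
            work = dist.all_reduce(
                epoch, op=dist.ReduceOp.MAX, group=self.dp_group, async_op=True
            )
            next_batch = next(it)             # overlap comms with the next fetch
            work.wait()
            if epoch.item() > start_epoch:    # some rank crossed the epoch boundary
                break                         # MAX is identical everywhere, so all ranks stop together
            yield batch
            batch = next_batch
```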
Issues found/fixed during this implementation:
HfIterableDataset sharded the data across all ranks, not only DP ranks. This meant that TP ranks were getting different batches instead of repeated ones.