[wip][SFT Eval] Add eval to SFT script #536
base: main
@@ -27,6 +27,7 @@ | |
| from forge.data.datasets.packed import PackedDataset, TextPacker | ||
| from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset | ||
| from forge.data.tokenizer import HuggingFaceModelTokenizer | ||
| from forge.data.utils import StopAfterOneEpoch | ||
| from forge.observability import get_or_create_metric_logger, record_metric, Reduce | ||
| from forge.util.config import parse | ||
|
|
||
|
|
@@ -81,6 +82,7 @@ def __init__(self, config: DictConfig): | |
| self.gradient_accumulation_steps = 1 # Example value, adjust as needed | ||
| self._rank = current_rank().rank | ||
| self._size = math.prod(current_size().values()) | ||
|
|
||
| self._init_dist() | ||
| super().__init__(job_config) | ||
|
|
||
|
|
@@ -122,28 +124,67 @@ def record_batch_metrics(self, data_metrics: list): | |
|
|
||
| @endpoint | ||
| async def setup(self): | ||
| self.train_dataloader = self.setup_data() | ||
|
|
||
| # metric logger | ||
| self.mlogger = await self.setup_metric_logger() | ||
|
|
||
| # self.train_dataloader = self.setup_data( | ||
| # self.train_config.train_dataset_config, | ||
| # self.train_config.train_dataloader_config, | ||
| # self.train_config.packing_config, | ||
| # ) | ||
| # self.val_dataloader = self.setup_data( | ||
| # self.train_config.val_dataset_config, | ||
| # self.train_config.val_dataloader_config, | ||
| # self.train_config.packing_config, | ||
| # ) | ||
| # Load training datasets | ||
| logger.info("Setting training datasets") | ||
| train_datasets_config = self.job_config.training.datasets | ||
| self.train_dataloader = self.setup_data(train_datasets_config) | ||
|
|
||
| # Load eval datasets | ||
| eval_config = self.job_config.get("eval", {}) | ||
| self.val_dataloaders = {} | ||
| self.eval_every_n_steps = eval_config.get("eval_every_n_steps", None) | ||
| max_eval_steps = eval_config.get("max_eval_steps", None) | ||
| self.max_eval_steps = ( | ||
| max_eval_steps if max_eval_steps and max_eval_steps > 0 else None | ||
| ) | ||
| self.validation_enabled = ( | ||
| self.eval_every_n_steps is not None and self.eval_every_n_steps > 0 | ||
| ) | ||
| if self.validation_enabled: | ||
| logger.info("Setting eval datasets") | ||
| self.eval_datasets_config = eval_config.datasets | ||
|
|
||
| for i, dataset_config in enumerate(self.eval_datasets_config): | ||
| ds_name = dataset_config.get("dataset_name", i) | ||
|
|
||
| # TODO: Support separate eval batch size from config (eval.local_batch_size) | ||
| dataloader = self.setup_data([dataset_config]) | ||
| self.val_dataloaders[ds_name] = dataloader | ||
|
|
||
| # TODO: confirm that this is working properly | ||
| # Should also use load, not dcp_load | ||
| self.checkpointer.load(step=self.current_step) | ||
|
|
||
| # self.profiler = self.setup_profiler(self.train_config.profiler_config) | ||
| # self.logger = self.setup_logger(self.train_config.logger_config) | ||
|
|
||
| def setup_data(self): | ||
| print(os.path.join(self.job_config.model.hf_assets_path, "tokenizer.json")) | ||
| def setup_data(self, dataset_configs: list[dict]) -> StatefulDataLoader: | ||
| """Instantiates datasets and returns a StatefulDataLoader. | ||
|
|
||
| Args: | ||
| dataset_configs (list[dict]): List of dataset config dicts used as `sft_iterable_dataset(**dataset_configs[i])`. | ||
|
|
||
| Returns: | ||
| StatefulDataLoader | ||
|
|
||
| Raises: | ||
| ValueError: If multiple datasets provided (not yet supported) | ||
| """ | ||
| # TODO felipemello: Currently only support single dataset | ||
| if len(dataset_configs) > 1: | ||
| raise ValueError( | ||
| f"Multiple training datasets not supported yet. " | ||
| f"Got {len(dataset_configs)} datasets. " | ||
| ) | ||
|
|
||
| dataset_config = dataset_configs[0] | ||
|
|
||
| # TODO: Evaluate if tokenizers should be created once and shared for every dataset | ||
| # Load tokenizer | ||
| tokenizer = HuggingFaceModelTokenizer( | ||
| tokenizer_json_path=os.path.join( | ||
| self.job_config.model.hf_assets_path, "tokenizer.json" | ||
|
|
@@ -165,18 +206,26 @@ def setup_data(self): | |
| ), | ||
| ) | ||
|
|
||
| # Get DP mesh for data sharding | ||
| dp_mesh = None | ||
| if self.parallel_dims is not None and self.parallel_dims.dp_enabled: | ||
| dp_mesh = self.parallel_dims.world_mesh.get_group("dp") | ||
|
|
||
| # Pass config directly to dataset constructor | ||
| dataset = sft_iterable_dataset( | ||
| model_transform=tokenizer, | ||
| message_transform=AlpacaToMessages(), | ||
| path="yahma/alpaca-cleaned", | ||
| split="train", | ||
| dp_mesh=dp_mesh, | ||
| **dataset_config, | ||
| ) | ||
|
|
||
| packer = TextPacker(padding_idx=0) | ||
| dataset = PackedDataset( | ||
| dataset=dataset, | ||
| packer=packer, | ||
| target_tokens_per_pack=self.job_config.training.seq_len, # TODO: get this from model | ||
| ) | ||
|
|
||
| dataloader = StatefulDataLoader( | ||
| dataset=dataset, | ||
| batch_size=self.job_config.training.local_batch_size, | ||
|
Contributor
Why not use drop_last=True here? If the dataset size is not divisible by batch_size * world_size, some ranks will have fewer batches, which could lead to a potential deadlock. Thoughts?
Contributor (Author)
This is true for map-style datasets, but for iterable datasets it is a no-op, so I thought it was misleading to have it there, i.e. "why do we need the StopAfterOneEpoch utility if we already have drop_last=True?". One could argue: "what if the user implements their own dataset class as map-style?" Our PackedDataset and InterleavedDataset would still be iterable datasets, so the input to the dataloader would always be an iterable. Let me know if that makes sense.
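To make the no-op concrete, here is a minimal toy sketch (plain PyTorch, not the forge code; ToyInfiniteDataset and the epoch tag are invented for illustration): with an infinite IterableDataset the DataLoader never sees a "last" batch, so drop_last has nothing to drop, and evaluation has to stop on an explicit epoch signal instead — which is the role a StopAfterOneEpoch-style helper plays.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class ToyInfiniteDataset(IterableDataset):
    """Toy infinite iterable dataset that tags every sample with its epoch."""

    def __init__(self, samples_per_epoch: int = 8):
        self.samples_per_epoch = samples_per_epoch

    def __iter__(self):
        epoch = 0
        while True:  # re-yields the data forever, like an infinite SFT stream
            for i in range(self.samples_per_epoch):
                yield {"tokens": torch.tensor(float(i)), "epoch": epoch}
            epoch += 1


# drop_last=True never triggers here: an infinite iterable has no "last" batch.
loader = DataLoader(ToyInfiniteDataset(), batch_size=4, drop_last=True)

# Stop at the epoch boundary by watching the epoch tag carried in the batch,
# instead of relying on drop_last or StopIteration.
for batch in loader:
    if batch["epoch"].max().item() > 0:
        break
    print(batch["tokens"])
```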
||
|
|
@@ -192,8 +241,12 @@ def setup_data(self): | |
| return dataloader | ||
|
|
||
| def forward_backward( | ||
| self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor | ||
| self, | ||
| input_dict: dict[str, torch.Tensor], | ||
| labels: torch.Tensor, | ||
| skip_backward: bool = False, | ||
| ) -> torch.Tensor: | ||
| """Forward pass with optional backward.""" | ||
| model_parts = self.model_parts | ||
| parallel_dims = self.parallel_dims | ||
|
|
||
|
|
@@ -214,6 +267,7 @@ def forward_backward( | |
|
|
||
| if parallel_dims.pp_enabled: | ||
| # Pipeline Parallel forward / backward inside step() call | ||
| # Note: PP backward only happens if not in torch.no_grad() context | ||
| with self.train_context(optional_context_parallel_ctx): | ||
| targets, losses = ( | ||
| (labels, []) if self.pp_has_last_stage else (None, None) | ||
|
|
@@ -235,15 +289,18 @@ def forward_backward( | |
| else torch.tensor([-1.0], device=self.device) | ||
| ) | ||
| else: | ||
| # Non-PP forward / backward | ||
| # Non-PP forward / backward - must happen inside same context | ||
| with self.train_context(optional_context_parallel_ctx): | ||
| assert len(model_parts) == 1 | ||
| with self.maybe_enable_amp: | ||
| pred = model_parts[0](inputs) | ||
| loss = self.loss_fn(pred, labels) | ||
| # need to free to before bwd to avoid peaking memory | ||
| del pred | ||
| loss.backward() | ||
|
|
||
| # Only run backward if requested. Useful for eval. | ||
| if not skip_backward: | ||
| loss.backward() | ||
|
|
||
| return loss | ||
|
|
||
|
|
@@ -259,12 +316,108 @@ def train_step(self, batch) -> None: | |
| loss = loss.item() | ||
|
|
||
| record_metric("ForgeSFTRecipe/train_step/loss", loss, Reduce.MEAN) | ||
| logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}") | ||
| if self.current_step % 10 == 0: | ||
| logger.info( | ||
| f"step {self.current_step} / {self.num_training_steps} | Loss: {loss}" | ||
| ) | ||
|
|
||
| # self.pbar.set_description(f"{self.current_step}|Loss: {loss}") | ||
| # self.pbar.update(1) | ||
| self.optimizers.step() | ||
| self.lr_schedulers.step() | ||
|
|
||
| async def evaluate(self) -> None: | ||
| """Run evaluation on multiple datasets, one at a time. | ||
|
|
||
| 1. Set models to eval mode | ||
| 2. For each eval dataset: | ||
| - Create fresh iterator (starts from epoch 0) | ||
| - Use StopAfterOneEpoch to iterate until epoch boundary. This utility | ||
| is necessary for infinite iterable dataset, since epoch boundaries are not known. | ||
| - Respect max_eval_steps cap if configured | ||
| - Record loss and step metrics (on dp rank only) | ||
| 3. Restore models to train mode | ||
| """ | ||
|
|
||
| # Set models to eval mode | ||
| for model_part in self.model_parts: | ||
| model_part.eval() | ||
|
|
||
| # Get DP process group for epoch synchronization | ||
| dp_mesh = None | ||
| if self.parallel_dims is not None and self.parallel_dims.dp_enabled: | ||
| dp_mesh = self.parallel_dims.world_mesh.get_group("dp") | ||
|
|
||
| # Evaluate each dataset sequentially | ||
| for dataset_name, val_dataloader in self.val_dataloaders.items(): | ||
| logger.info(f"=====Evaluating dataset: {dataset_name}=====") | ||
|
|
||
| # Evaluation loop for this dataset | ||
| total_loss = torch.tensor(0.0, device=self.device) | ||
| num_steps = 0 | ||
|
|
||
| # NOTE: Assumes batch contains field "metrics" | ||
| batch_iter = StopAfterOneEpoch( | ||
| iter=iter(val_dataloader), # Fresh iterator from epoch 0, | ||
| device=self.device, | ||
| dp_mesh=dp_mesh, | ||
| ) | ||
|
|
||
| with torch.no_grad(): | ||
| for batch in batch_iter: | ||
| # Check max_eval_steps limit | ||
| if ( | ||
| self.max_eval_steps is not None | ||
| and num_steps >= self.max_eval_steps | ||
| ): | ||
| logger.info( | ||
| f"[{dataset_name}] Reached max_eval_steps cap of {self.max_eval_steps}" | ||
| ) | ||
| break | ||
|
|
||
| # Move tensors to device | ||
| for key, value in batch.items(): | ||
| if isinstance(value, torch.Tensor): | ||
| batch[key] = value.to(self.device) | ||
|
|
||
| # Process batch | ||
| labels = batch.pop("labels") | ||
| loss = self.forward_backward(batch, labels, skip_backward=True) | ||
| total_loss += loss | ||
| num_steps += 1 | ||
|
|
||
| # Log progress (rank 0 only) | ||
| if num_steps % 50 == 0: | ||
| loss_val = loss.item() | ||
| logger.info( | ||
| f"[dataset {dataset_name}] Step {num_steps} | Loss: {loss_val:.4f}" | ||
| ) | ||
|
|
||
| # Compute average loss | ||
| avg_loss = (total_loss / max(num_steps, 1)).item() | ||
| logger.info( | ||
| f"[dataset {dataset_name}] Final Step {num_steps} | Avg Loss: {avg_loss:.4f}" | ||
| ) | ||
| # Record metrics only on DP rank 0 to avoid double counting | ||
| # record_metric aggregates across all processes via monarch | ||
| should_record = True | ||
| if dp_mesh is not None: | ||
| dp_rank = torch.distributed.get_rank(group=dp_mesh) | ||
| should_record = dp_rank == 0 | ||
|
Comment on lines +403 to +406
Contributor (Author)
The logic here is wrong: we should record for every DP rank, and add checks for the other rank types. Will do it on Monday.
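The fix itself is deferred, but as a generic illustration of "record for every DP rank" (plain torch.distributed here rather than forge's record_metric/monarch aggregation; the helper name and signature are assumptions), one option is to all-reduce the per-rank average eval loss over the DP group so no rank's batches are dropped:

```python
import torch
import torch.distributed as dist


def dp_mean_loss(local_avg_loss: torch.Tensor, dp_group=None) -> float:
    """All-reduce the per-rank average eval loss across the DP process group."""
    if dist.is_available() and dist.is_initialized():
        reduced = local_avg_loss.clone()
        # Sum then divide keeps this backend-agnostic (gloo has no AVG op).
        dist.all_reduce(reduced, op=dist.ReduceOp.SUM, group=dp_group)
        reduced /= dist.get_world_size(group=dp_group)
        return reduced.item()
    # Single-process fallback when no process group is initialized.
    return local_avg_loss.item()
```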
||
|
|
||
| if should_record: | ||
| record_metric( | ||
| f"evaluate/dataset_{dataset_name}_loss", | ||
| avg_loss, | ||
| Reduce.MEAN, | ||
| ) | ||
|
|
||
| # Restore train mode | ||
| for model_part in self.model_parts: | ||
| model_part.train() | ||
|
|
||
| logger.info("==Evaluation complete==") | ||
|
|
||
| @endpoint | ||
| async def train(self) -> None: | ||
| dataloader = iter(self.train_dataloader) | ||
|
|
@@ -289,18 +442,28 @@ async def train(self) -> None: | |
| # self.profiler.step() | ||
| self.current_step += 1 | ||
|
|
||
| # Flush metrics | ||
| if self._rank == 0: | ||
| logger.debug(f"Flushing metrics at step {self.current_step}") | ||
| await self.mlogger.flush.call_one(global_step=self.current_step) | ||
| # Run evaluation periodically if enabled | ||
| if ( | ||
| self.validation_enabled | ||
| and self.current_step % self.eval_every_n_steps == 0 | ||
| ): | ||
| await self.evaluate() | ||
|
|
||
| self.checkpointer.save( | ||
| curr_step=self.current_step, | ||
| last_step=self.current_step == self.num_training_steps, | ||
| ) | ||
|
|
||
| # Flush metrics | ||
| if self._rank == 0: | ||
| await self.mlogger.flush.call_one(global_step=self.current_step) | ||
|
|
||
| # self.pbar.close() | ||
|
|
||
| if self.validation_enabled: | ||
| logger.info("Running final evaluation at end of training...") | ||
| await self.evaluate() | ||
|
|
||
| @endpoint | ||
| async def cleanup(self) -> None: | ||
| if self.checkpointer: | ||
|
|
||
For eval_every_n_steps, is there a check to break when we exhaust the steps? If we don't have the epoch metric, shouldn't that be the signal to break the eval loop?
Author
We do: max_eval_steps is checked inside the eval loop, so it is whichever comes first, one epoch or self.max_eval_steps.
Regarding what happens if there is no "num_epochs" metric: that would only happen if the user replaces our dataset class with a new one. This is entirely possible, but they can easily add the "num_epochs" metric if they have that level of expertise, or remove StopAfterOneEpoch from main.py.
Worst case, we can add checks if someone complains. I wanted to avoid adding complexity with more if/else here.
wdyt?
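A toy sketch of that "whichever comes first" rule (a hypothetical helper, not the actual StopAfterOneEpoch implementation; it only assumes each batch carries a "num_epochs" metric as described above):

```python
def bounded_eval_batches(batches, max_eval_steps=None):
    """Yield batches until the epoch boundary or the step cap, whichever comes first."""
    for num_steps, batch in enumerate(batches):
        if max_eval_steps is not None and num_steps >= max_eval_steps:
            break  # step cap reached before the epoch ended
        if batch.get("num_epochs", 0) > 0:
            break  # epoch boundary reached before the step cap
        yield batch


# Example: a 10-batch epoch capped at 4 eval steps -> only the first 4 batches run.
fake_epoch = [{"num_epochs": 0, "idx": i} for i in range(10)]
print([b["idx"] for b in bounded_eval_batches(fake_epoch, max_eval_steps=4)])
# -> [0, 1, 2, 3]
```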