[wip][SFT Eval] Add eval to SFT script #536
base: main
Changes from all commits
```diff
@@ -11,7 +11,7 @@
 """

-import asyncio
+import contextlib
 import logging
 import math
 import os
```
```diff
@@ -27,6 +27,7 @@
 from forge.data.datasets.packed import PackedDataset, TextPacker
 from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
 from forge.data.tokenizer import HuggingFaceModelTokenizer
+from forge.data.utils import StopAfterOneEpoch
 from forge.observability import get_or_create_metric_logger, record_metric, Reduce
 from forge.util.config import parse
```
```diff
@@ -81,9 +82,19 @@ def __init__(self, config: DictConfig):
         self.gradient_accumulation_steps = 1  # Example value, adjust as needed
         self._rank = current_rank().rank
         self._size = math.prod(current_size().values())

         self._init_dist()
         super().__init__(job_config)

+        # For Pipeline Parallelism (PP): only the last PP stage computes the actual loss.
+        # For non-PP setups: all ranks compute loss.
+        self.rank_should_record_loss = True
+        if hasattr(self, "pp_has_last_stage") and not self.pp_has_last_stage:
+            self.rank_should_record_loss = False
+
+        # Logging frequency
+        self.log_every_n_steps = self.job_config.get("log_every_n_steps", 10)

     def _init_dist(self):
         """Initializes torch distributed.
```
```diff
@@ -122,28 +133,67 @@ def record_batch_metrics(self, data_metrics: list):

     @endpoint
     async def setup(self):
-        self.train_dataloader = self.setup_data()

         # metric logger
         self.mlogger = await self.setup_metric_logger()

-        # self.train_dataloader = self.setup_data(
-        #     self.train_config.train_dataset_config,
-        #     self.train_config.train_dataloader_config,
-        #     self.train_config.packing_config,
-        # )
-        # self.val_dataloader = self.setup_data(
-        #     self.train_config.val_dataset_config,
-        #     self.train_config.val_dataloader_config,
-        #     self.train_config.packing_config,
-        # )
+        # Load training datasets
+        logger.info("Setting training datasets")
+        train_datasets_config = self.job_config.training.datasets
+        self.train_dataloader = self.setup_data(train_datasets_config)

+        # Load eval datasets
+        eval_config = self.job_config.get("eval", {})
+        self.val_dataloaders = {}
+        self.eval_every_n_steps = eval_config.get("eval_every_n_steps", None)
```
> **Member:** No getter
```diff
+        max_eval_steps = eval_config.get("max_eval_steps", None)
```

> **Member:** No getter
```diff
+        self.max_eval_steps = (
+            max_eval_steps if max_eval_steps and max_eval_steps > 0 else None
+        )
+        self.validation_enabled = (
+            self.eval_every_n_steps is not None and self.eval_every_n_steps > 0
+        )
+        if self.validation_enabled:
+            logger.info("Setting eval datasets")
+            self.eval_datasets_config = eval_config.datasets
+
+            for i, dataset_config in enumerate(self.eval_datasets_config):
+                ds_name = dataset_config.get("dataset_name", i)
+
+                # TODO: Support separate eval batch size from config (eval.local_batch_size)
+                dataloader = self.setup_data([dataset_config])
+                self.val_dataloaders[ds_name] = dataloader

         # TODO: confirm that this is working properly
         # Should also use load, not dcp_load
         self.checkpointer.load(step=self.current_step)

         # self.profiler = self.setup_profiler(self.train_config.profiler_config)
         # self.logger = self.setup_logger(self.train_config.logger_config)

-    def setup_data(self):
-        print(os.path.join(self.job_config.model.hf_assets_path, "tokenizer.json"))
+    def setup_data(self, dataset_configs: list[dict]) -> StatefulDataLoader:
+        """Instantiates datasets and returns a StatefulDataLoader.
+
+        Args:
+            dataset_configs (list[dict]): List of dataset config dicts used as
+                `sft_iterable_dataset(**dataset_configs[i])`.
+
+        Returns:
+            StatefulDataLoader
+
+        Raises:
+            ValueError: If multiple datasets are provided (not yet supported).
+        """
+        # TODO felipemello: Currently only support single dataset
+        if len(dataset_configs) > 1:
+            raise ValueError(
+                f"Multiple training datasets not supported yet. "
+                f"Got {len(dataset_configs)} datasets. "
+            )
+
+        dataset_config = dataset_configs[0]
+
+        # TODO: Evaluate if tokenizers should be created once and shared for every dataset
+        # Load tokenizer
         tokenizer = HuggingFaceModelTokenizer(
             tokenizer_json_path=os.path.join(
                 self.job_config.model.hf_assets_path, "tokenizer.json"
```
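For reference, the keys read above imply a job config shaped roughly like the fragment below. This is a hypothetical sketch: the `eval` section layout and the per-dataset fields (`dataset_name`, `path`, `split`) are inferred from how `setup()` and `setup_data()` consume them, not copied from a shipped config.

```python
# Hypothetical config fragment matching the keys read in setup().
# Dataset fields are assumptions: setup_data() forwards each dict
# to sft_iterable_dataset(**dataset_config).
from omegaconf import OmegaConf

job_config = OmegaConf.create(
    {
        "training": {
            "datasets": [{"path": "yahma/alpaca-cleaned", "split": "train"}],
        },
        "eval": {
            "eval_every_n_steps": 100,  # None or 0 disables validation
            "max_eval_steps": 50,  # None or 0 means "run one full epoch"
            "datasets": [
                {
                    "dataset_name": "alpaca_val",
                    "path": "yahma/alpaca-cleaned",
                    "split": "test",
                }
            ],
        },
    }
)
```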
```diff
@@ -165,18 +215,26 @@ def setup_data(self):
             ),
         )

+        # Get DP mesh for data sharding
+        dp_mesh = None
+        if self.parallel_dims is not None and self.parallel_dims.dp_enabled:
+            dp_mesh = self.parallel_dims.world_mesh.get_group("dp")

+        # Pass config directly to dataset constructor
         dataset = sft_iterable_dataset(
             model_transform=tokenizer,
             message_transform=AlpacaToMessages(),
-            path="yahma/alpaca-cleaned",
-            split="train",
+            dp_mesh=dp_mesh,
+            **dataset_config,
         )

         packer = TextPacker(padding_idx=0)
         dataset = PackedDataset(
             dataset=dataset,
             packer=packer,
             target_tokens_per_pack=self.job_config.training.seq_len,  # TODO: get this from model
         )

         dataloader = StatefulDataLoader(
             dataset=dataset,
             batch_size=self.job_config.training.local_batch_size,
```
```diff
@@ -192,8 +250,12 @@ def setup_data(self):
         return dataloader

     def forward_backward(
-        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
+        self,
+        input_dict: dict[str, torch.Tensor],
+        labels: torch.Tensor,
+        skip_backward: bool = False,
     ) -> torch.Tensor:
+        """Forward pass with optional backward."""
```
> **Member:** nit: No need for this comment
```diff
         model_parts = self.model_parts
         parallel_dims = self.parallel_dims
```
```diff
@@ -230,10 +292,15 @@ def forward_backward(
             # accumulate losses across pipeline microbatches
             # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
             loss = (
-                torch.mean(torch.stack(losses)).to(self.device)
+                torch.sum(torch.stack(losses)).to(self.device)
```
> **Member:** Why did you change this?
```diff
                 if self.pp_has_last_stage
-                else torch.tensor([-1.0], device=self.device)
+                else torch.tensor(-1.0, device=self.device)
             )

+            # TODO: PP requires gradients enabled and can't be deactivated with no_grad
+            if skip_backward:
+                loss = loss.detach()

         else:
             # Non-PP forward / backward
             with self.train_context(optional_context_parallel_ctx):
```
```diff
@@ -243,7 +310,10 @@ def forward_backward(
                 loss = self.loss_fn(pred, labels)
                 # need to free pred before bwd to avoid peaking memory
                 del pred
-                loss.backward()

+                # Only run backward if requested. Useful for eval.
+                if not skip_backward:
+                    loss.backward()

         return loss
```
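Taken together, the two branches give `forward_backward` a single eval-friendly contract. A minimal sketch of the call patterns this enables (`trainer` is a stand-in for the recipe instance; these are not the PR's actual call sites):

```python
import torch

# Training: the autograd graph is built and backward runs inside forward_backward.
loss = trainer.forward_backward(batch, labels)

# Eval, non-PP: wrap in no_grad so no autograd graph is built at all.
with torch.no_grad():
    eval_loss = trainer.forward_backward(batch, labels, skip_backward=True)

# Eval, PP: pipeline schedules need grad mode, so the graph is built,
# backward is skipped, and the loss is detached to release the graph.
eval_loss = trainer.forward_backward(batch, labels, skip_backward=True)
```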
```diff
@@ -256,15 +326,142 @@ def train_step(self, batch) -> None:
         # ) as grad_acc:
         labels = batch.pop("labels")
         loss = self.forward_backward(batch, labels)
-        loss = loss.item()

-        record_metric("ForgeSFTRecipe/train_step/loss", loss, Reduce.MEAN)
-        logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}")
+        if self.rank_should_record_loss:
+            loss_val = loss.item()
+            record_metric("ForgeSFTRecipe/train_step/loss", loss_val, Reduce.MEAN)
+            if self.current_step % self.log_every_n_steps == 0:
+                logger.info(
```
> **Member:** I thought this was handled by the MetricLogger?
| f"step {self.current_step} / {self.num_training_steps} | Loss: {loss_val}" | ||
| ) | ||
|
|
||
| # self.pbar.set_description(f"{self.current_step}|Loss: {loss}") | ||
| # self.pbar.update(1) | ||
| self.optimizers.step() | ||
| self.lr_schedulers.step() | ||
|
|
||
```diff
+    async def evaluate(self) -> None:
+        """Run evaluation on multiple datasets, one at a time.
+
+        1. Set models to eval mode
+        2. For each eval dataset:
+           - Create a fresh iterator (starts from epoch 0)
+           - Use StopAfterOneEpoch to iterate until the epoch boundary. This utility
+             is necessary for infinite iterable datasets, since epoch boundaries are not known.
+           - Respect the max_eval_steps cap if configured
+           - Record loss and step metrics (on dp rank only)
+        3. Restore models to train mode
+        """
+        # Set models to eval mode
+        for model_part in self.model_parts:
+            model_part.eval()
+
+        # Get DP process group for epoch synchronization
+        dp_mesh = None
+        if self.parallel_dims is not None and self.parallel_dims.dp_enabled:
+            dp_mesh = self.parallel_dims.world_mesh.get_group("dp")
+
+        # For non-PP: disable gradients to save memory
+        # TODO: for PP, disabling gradients throws an error
+        maybe_no_grad = (
+            contextlib.nullcontext()
+            if self.parallel_dims.pp_enabled
+            else torch.no_grad()
+        )
+
+        # Evaluate each dataset sequentially
+        all_dataset_losses = []
+        all_dataset_steps = []
+        for dataset_name, val_dataloader in self.val_dataloaders.items():
+            logger.info(f"=====Evaluating dataset: {dataset_name}=====")
+
+            # Evaluation loop for this dataset
+            total_loss = torch.tensor(0.0, device=self.device)
+            num_steps = 0
+
+            # NOTE: Assumes batch contains field "metrics"
+            batch_iter = StopAfterOneEpoch(
+                iter=iter(val_dataloader),  # Fresh iterator from epoch 0
+                device=self.device,
+                dp_mesh=dp_mesh,
+            )
+
+            with maybe_no_grad:
+                for batch in batch_iter:
+                    # If max_eval_steps > len(dataset), iteration is stopped
+                    # earlier by StopAfterOneEpoch.
+                    if (
+                        self.max_eval_steps is not None
```
> **Member:** What if max_eval_steps > num steps per epoch?

> **Author (Contributor):** batch_iter will stop first; I can add a comment to clarify.
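For readers unfamiliar with the utility: `StopAfterOneEpoch` is imported from `forge.data.utils` and its internals are not shown in this diff. An illustrative stand-in for the behavior the author describes, yielding until every DP rank has crossed the first epoch boundary, might look like the sketch below. The epoch extraction from the batch's `"metrics"` field is an assumption; the real implementation may differ.

```python
import torch
import torch.distributed as dist

def stop_after_one_epoch(dataloader, device, dp_mesh=None, epoch_of=lambda b: 0):
    """Illustrative stand-in for forge.data.utils.StopAfterOneEpoch."""
    it = iter(dataloader)
    start_epoch = None
    while True:
        batch = next(it)  # infinite iterable dataset: never raises StopIteration
        epoch = epoch_of(batch)  # assumed to be derived from batch["metrics"]
        if start_epoch is None:
            start_epoch = epoch
        # All DP ranks must agree on when to stop; otherwise collectives
        # issued inside the eval loop would deadlock.
        done = torch.tensor(float(epoch > start_epoch), device=device)
        if dp_mesh is not None:
            dist.all_reduce(done, op=dist.ReduceOp.MAX, group=dp_mesh)
        if done.item() > 0:
            break
        yield batch
```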
```diff
+                        and num_steps >= self.max_eval_steps
+                    ):
+                        logger.info(
+                            f"[{dataset_name}] Reached max_eval_steps cap of {self.max_eval_steps}"
+                        )
+                        break
+
+                    # Move tensors to device
+                    for key, value in batch.items():
```
> **Member:** We have a helper function to do this.
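The helper the reviewer mentions isn't named in the diff; a minimal stand-in that covers this loop would be:

```python
import torch

def to_device(batch: dict, device: torch.device) -> dict:
    # Hypothetical equivalent of the helper referenced above: move every
    # tensor value in a flat batch dict onto the target device.
    return {
        k: v.to(device) if isinstance(v, torch.Tensor) else v
        for k, v in batch.items()
    }
```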
```diff
+                        if isinstance(value, torch.Tensor):
+                            batch[key] = value.to(self.device)
+
+                    # Process batch
+                    labels = batch.pop("labels")
+                    loss = self.forward_backward(batch, labels, skip_backward=True)
+                    total_loss += loss
+                    num_steps += 1
+
+                    # Log progress
+                    if (
+                        self.rank_should_record_loss
+                        and num_steps % self.log_every_n_steps == 0
+                    ):
+                        loss_val = loss.item()
+                        logger.info(
+                            f"[dataset {dataset_name}] Step {num_steps} | Loss: {loss_val:.4f}"
+                        )
+
+            # Log loss
+            avg_loss = (total_loss / max(num_steps, 1)).item()
+            all_dataset_losses.append(avg_loss)
+            all_dataset_steps.append(num_steps)
+            logger.info(
+                f"[dataset {dataset_name}] Final Step {num_steps} | Avg Loss: {avg_loss:.4f}"
+            )
+            if self.rank_should_record_loss:
+                record_metric(
+                    f"evaluate/dataset_{dataset_name}_avg_loss",
+                    avg_loss,
+                    Reduce.MEAN,
+                )
+
+        # Record macro and micro average losses across datasets (only if multiple datasets)
+        if self.rank_should_record_loss and len(all_dataset_losses) > 1:
+            # Macro: same weight for all datasets
+            macro_avg_loss = sum(all_dataset_losses) / len(all_dataset_losses)
+            record_metric("evaluate/macro_avg_loss", macro_avg_loss, Reduce.MEAN)
+
+            # Micro: weighted mean by dataset size
+            total_steps = sum(all_dataset_steps)
+            micro_avg_loss = (
+                sum(
+                    loss * steps
+                    for loss, steps in zip(all_dataset_losses, all_dataset_steps)
+                )
+                / total_steps
+            )
+            record_metric("evaluate/micro_avg_loss", micro_avg_loss, Reduce.MEAN)
+
+            logger.info(
+                f"Macro avg loss (unweighted): {macro_avg_loss:.4f}, "
+                f"Micro avg loss (weighted): {micro_avg_loss:.4f}"
+            )
+
+        # Restore train mode
+        for model_part in self.model_parts:
+            model_part.train()
+
+        logger.info("==Evaluation complete==")
```
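To make the macro/micro distinction concrete, a small worked example with invented numbers (note the code weights by `num_steps`, i.e. batches seen, as a proxy for dataset size):

```python
losses = [2.0, 4.0]  # per-dataset average losses
steps = [10, 90]     # batches evaluated per dataset

macro = sum(losses) / len(losses)
# (2.0 + 4.0) / 2 = 3.0: every dataset counts equally

micro = sum(l * s for l, s in zip(losses, steps)) / sum(steps)
# (2.0 * 10 + 4.0 * 90) / 100 = 3.8: dominated by the larger dataset
```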
```diff
     @endpoint
     async def train(self) -> None:
         dataloader = iter(self.train_dataloader)
```
```diff
@@ -289,18 +486,28 @@ async def train(self) -> None:
                 # self.profiler.step()
                 self.current_step += 1

-                # Flush metrics
-                if self._rank == 0:
-                    logger.debug(f"Flushing metrics at step {self.current_step}")
-                    await self.mlogger.flush.call_one(global_step=self.current_step)
+                # Run evaluation periodically if enabled
+                if (
+                    self.validation_enabled
+                    and self.current_step % self.eval_every_n_steps == 0
+                ):
+                    await self.evaluate()

                 self.checkpointer.save(
                     curr_step=self.current_step,
                     last_step=self.current_step == self.num_training_steps,
                 )

+                # Flush metrics
+                if self._rank == 0:
+                    await self.mlogger.flush.call_one(global_step=self.current_step)

         # self.pbar.close()

+        if self.validation_enabled:
+            logger.info("Running final evaluation at end of training...")
+            await self.evaluate()

     @endpoint
     async def cleanup(self) -> None:
         if self.checkpointer:
```
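For a sense of the cadence this produces: with `eval_every_n_steps = 100` and `num_training_steps = 250`, `evaluate()` fires at steps 100 and 200 inside the loop, and once more after the loop exits, so training ends with a final evaluation whenever validation is enabled.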
> **Member:** Can we make this mandatory, no getter?