12 changes: 11 additions & 1 deletion apps/sft/llama3_8b.yaml
@@ -32,7 +32,16 @@ training:
max_norm: 1.0
steps: 1000
compile: false
dataset: "c4"
datasets:
- path: "yahma/alpaca-cleaned"
split: "train[:95%]"

eval:
eval_every_n_steps: 50 # null = disabled
max_eval_steps: null # null = run until epoch completes
datasets:
- path: "yahma/alpaca-cleaned"
split: "train[95%:]"

parallelism:
data_parallel_replicate_degree: 1
@@ -62,6 +71,7 @@ metric_logging:
group: sft_exp_${oc.env:USER}
logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce


# profiling:
# enable_profiling: false

209 changes: 186 additions & 23 deletions apps/sft/main.py
@@ -27,6 +27,7 @@
from forge.data.datasets.packed import PackedDataset, TextPacker
from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
from forge.data.tokenizer import HuggingFaceModelTokenizer
from forge.data.utils import StopAfterOneEpoch
from forge.observability import get_or_create_metric_logger, record_metric, Reduce
from forge.util.config import parse

@@ -81,6 +82,7 @@ def __init__(self, config: DictConfig):
self.gradient_accumulation_steps = 1 # Example value, adjust as needed
self._rank = current_rank().rank
self._size = math.prod(current_size().values())

self._init_dist()
super().__init__(job_config)

@@ -122,28 +124,67 @@ def record_batch_metrics(self, data_metrics: list):

@endpoint
async def setup(self):
self.train_dataloader = self.setup_data()

# metric logger
self.mlogger = await self.setup_metric_logger()

# self.train_dataloader = self.setup_data(
# self.train_config.train_dataset_config,
# self.train_config.train_dataloader_config,
# self.train_config.packing_config,
# )
# self.val_dataloader = self.setup_data(
# self.train_config.val_dataset_config,
# self.train_config.val_dataloader_config,
# self.train_config.packing_config,
# )
# Load training datasets
logger.info("Setting training datasets")
train_datasets_config = self.job_config.training.datasets
self.train_dataloader = self.setup_data(train_datasets_config)

# Load eval datasets
eval_config = self.job_config.get("eval", {})
self.val_dataloaders = {}
self.eval_every_n_steps = eval_config.get("eval_every_n_steps", None)
max_eval_steps = eval_config.get("max_eval_steps", None)
self.max_eval_steps = (
max_eval_steps if max_eval_steps and max_eval_steps > 0 else None
)
self.validation_enabled = (
Contributor

For eval_every_n_steps, is there a check to break when we exhaust the steps? If we don't have the epoch metric, shouldn't this be the metric to break the eval loop?

Contributor Author

We do:

for batch in StopAfterOneEpoch(val_dataloader):
    # Check max_eval_steps limit
    if (
        self.max_eval_steps is not None
        and num_steps >= self.max_eval_steps
    ):
        break

So it's whichever comes first: one epoch or self.max_eval_steps.

Regarding what happens if there is no "num_epochs" metric: this would only happen if the user replaces our dataset class with a new one. That is entirely possible, but they can easily add the "num_epochs" metric if they have that level of expertise, or delete StopAfterOneEpoch from main.py.

Worst case, we can add checks if someone complains.

I wanted to avoid adding complexity with more if/else here.

wdyt?
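
For readers skimming the thread: below is a minimal sketch of what such an epoch-boundary wrapper could look like. This is an assumption, not the actual forge.data.utils.StopAfterOneEpoch — the constructor arguments mirror the call site in evaluate(), while the metric lookup (objects with metric_name/value attributes) is a hypothetical placeholder.

# Hypothetical sketch; the real forge.data.utils.StopAfterOneEpoch may differ.
from typing import Iterator

import torch
import torch.distributed as dist


class StopAfterOneEpoch:
    """Yield batches until any DP rank crosses an epoch boundary.

    Assumes each batch carries a "metrics" field that includes a
    "num_epochs" counter, and synchronizes the stop decision across the
    DP group so that every rank breaks on the same step.
    """

    def __init__(self, iter: Iterator, device: torch.device, dp_mesh=None):
        self._iter = iter
        self._device = device
        self._dp_mesh = dp_mesh  # DP process group, or None when not sharded

    def __iter__(self):
        start_epoch = None
        for batch in self._iter:
            epoch = self._get_num_epochs(batch)
            if start_epoch is None:
                start_epoch = epoch
            # 1.0 on any rank that crossed the epoch boundary, else 0.0
            crossed = torch.tensor(
                [1.0 if epoch > start_epoch else 0.0], device=self._device
            )
            if self._dp_mesh is not None:
                dist.all_reduce(crossed, op=dist.ReduceOp.MAX, group=self._dp_mesh)
            if crossed.item() > 0:
                break  # at least one rank finished its epoch
            yield batch

    @staticmethod
    def _get_num_epochs(batch) -> int:
        # Placeholder: adapt to however the dataset exposes its epoch counter
        # (assumed here: metric objects with .metric_name and .value).
        return max(
            (m.value for m in batch.get("metrics", []) if m.metric_name == "num_epochs"),
            default=0,
        )

The MAX all_reduce is what keeps the DP ranks in lockstep: if only one rank hits the epoch boundary, every rank stops on the same step, avoiding the hang a partially consumed final epoch could otherwise cause.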

self.eval_every_n_steps is not None and self.eval_every_n_steps > 0
)
if self.validation_enabled:
logger.info("Setting eval datasets")
self.eval_datasets_config = eval_config.datasets

for i, dataset_config in enumerate(self.eval_datasets_config):
ds_name = dataset_config.get("dataset_name", i)

# TODO: Support separate eval batch size from config (eval.local_batch_size)
dataloader = self.setup_data([dataset_config])
self.val_dataloaders[ds_name] = dataloader

# TODO: confirm that this is working properly
# Should also use load, not dcp_load
self.checkpointer.load(step=self.current_step)

# self.profiler = self.setup_profiler(self.train_config.profiler_config)
# self.logger = self.setup_logger(self.train_config.logger_config)

def setup_data(self):
print(os.path.join(self.job_config.model.hf_assets_path, "tokenizer.json"))
def setup_data(self, dataset_configs: list[dict]) -> StatefulDataLoader:
"""Instantiates datasets and returns a StatefulDataLoader.

Args:
dataset_configs (list[dict]): List of dataset config dicts used as `sft_iterable_dataset(**dataset_configs[i])`.

Returns:
StatefulDataLoader

Raises:
ValueError: If multiple datasets provided (not yet supported)
"""
# TODO felipemello: Currently only support single dataset
if len(dataset_configs) > 1:
raise ValueError(
f"Multiple training datasets not supported yet. "
f"Got {len(dataset_configs)} datasets. "
)

dataset_config = dataset_configs[0]

# TODO: Evaluate if tokenizers should be created once and shared for every dataset
# Load tokenizer
tokenizer = HuggingFaceModelTokenizer(
tokenizer_json_path=os.path.join(
self.job_config.model.hf_assets_path, "tokenizer.json"
@@ -165,18 +206,26 @@ def setup_data(self):
),
)

# Get DP mesh for data sharding
dp_mesh = None
if self.parallel_dims is not None and self.parallel_dims.dp_enabled:
dp_mesh = self.parallel_dims.world_mesh.get_group("dp")

# Pass config directly to dataset constructor
dataset = sft_iterable_dataset(
model_transform=tokenizer,
message_transform=AlpacaToMessages(),
path="yahma/alpaca-cleaned",
split="train",
dp_mesh=dp_mesh,
**dataset_config,
)

packer = TextPacker(padding_idx=0)
dataset = PackedDataset(
dataset=dataset,
packer=packer,
target_tokens_per_pack=self.job_config.training.seq_len, # TODO: get this from model
)

dataloader = StatefulDataLoader(
dataset=dataset,
batch_size=self.job_config.training.local_batch_size,
Contributor

Why not use drop_last=True here? If the dataset size is not divisible by batch_size * world_size, some ranks will have fewer batches, which could lead to a potential deadlock. Thoughts?

Contributor Author

This is true for map-style datasets, but for iterable datasets it is a no-op, so I thought it was misleading to have it there, i.e. "why do we need the StopAfterOneEpoch utility if we already have drop_last=True?".

One can make the argument: "what if the user implements their own dataset class as map-style?". Our PackedDataset and InterleavedDataset would still be iterable datasets, so the input to the dataloader would always be an iterable.

Let me know if that makes sense.
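
As a quick, standalone illustration of why the flag is moot for the infinite case (not code from this PR): an IterableDataset that loops forever never yields a final partial batch, so drop_last has nothing to drop.

import itertools

import torch
from torch.utils.data import DataLoader, IterableDataset


class InfiniteCounter(IterableDataset):
    """Loops over a tiny dataset forever, mimicking an epoch-cycling SFT dataset."""

    def __iter__(self):
        return itertools.cycle(torch.arange(10))


# drop_last only affects a trailing partial batch; an endless stream never produces one.
loader = DataLoader(InfiniteCounter(), batch_size=4, drop_last=True)
for step, batch in enumerate(loader):
    print(step, batch.shape)  # every batch is a full [4]; iteration must be stopped manually
    if step == 4:
        break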

@@ -192,8 +241,12 @@ def setup_data(self):
return dataloader

def forward_backward(
self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
self,
input_dict: dict[str, torch.Tensor],
labels: torch.Tensor,
skip_backward: bool = False,
) -> torch.Tensor:
"""Forward pass with optional backward."""
model_parts = self.model_parts
parallel_dims = self.parallel_dims

@@ -214,6 +267,7 @@

if parallel_dims.pp_enabled:
# Pipeline Parallel forward / backward inside step() call
# Note: PP backward only happens if not in torch.no_grad() context
with self.train_context(optional_context_parallel_ctx):
targets, losses = (
(labels, []) if self.pp_has_last_stage else (None, None)
@@ -235,15 +289,18 @@
else torch.tensor([-1.0], device=self.device)
)
else:
# Non-PP forward / backward
# Non-PP forward / backward - must happen inside same context
with self.train_context(optional_context_parallel_ctx):
assert len(model_parts) == 1
with self.maybe_enable_amp:
pred = model_parts[0](inputs)
loss = self.loss_fn(pred, labels)
# need to free pred before bwd to avoid a memory peak
del pred
loss.backward()

# Only run backward if requested. Useful for eval.
if not skip_backward:
loss.backward()

return loss

@@ -259,12 +316,108 @@ def train_step(self, batch) -> None:
loss = loss.item()

record_metric("ForgeSFTRecipe/train_step/loss", loss, Reduce.MEAN)
logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}")
if self.current_step % 10 == 0:
logger.info(
f"step {self.current_step} / {self.num_training_steps} | Loss: {loss}"
)

# self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
# self.pbar.update(1)
self.optimizers.step()
self.lr_schedulers.step()

async def evaluate(self) -> None:
"""Run evaluation on multiple datasets, one at a time.

1. Set models to eval mode
2. For each eval dataset:
- Create fresh iterator (starts from epoch 0)
- Use StopAfterOneEpoch to iterate until the epoch boundary. This utility
is necessary for infinite iterable datasets, whose epoch boundaries are not known in advance.
- Respect max_eval_steps cap if configured
- Record loss and step metrics (on DP rank 0 only)
3. Restore models to train mode
"""

# Set models to eval mode
for model_part in self.model_parts:
model_part.eval()

# Get DP process group for epoch synchronization
dp_mesh = None
if self.parallel_dims is not None and self.parallel_dims.dp_enabled:
dp_mesh = self.parallel_dims.world_mesh.get_group("dp")

# Evaluate each dataset sequentially
for dataset_name, val_dataloader in self.val_dataloaders.items():
logger.info(f"=====Evaluating dataset: {dataset_name}=====")

# Evaluation loop for this dataset
total_loss = torch.tensor(0.0, device=self.device)
num_steps = 0

# NOTE: Assumes batch contains field "metrics"
batch_iter = StopAfterOneEpoch(
iter=iter(val_dataloader), # Fresh iterator from epoch 0
device=self.device,
dp_mesh=dp_mesh,
)

with torch.no_grad():
for batch in batch_iter:
# Check max_eval_steps limit
if (
self.max_eval_steps is not None
and num_steps >= self.max_eval_steps
):
logger.info(
f"[{dataset_name}] Reached max_eval_steps cap of {self.max_eval_steps}"
)
break

# Move tensors to device
for key, value in batch.items():
if isinstance(value, torch.Tensor):
batch[key] = value.to(self.device)

# Process batch
labels = batch.pop("labels")
loss = self.forward_backward(batch, labels, skip_backward=True)
total_loss += loss
num_steps += 1

# Log progress (rank 0 only)
if num_steps % 50 == 0:
loss_val = loss.item()
logger.info(
f"[dataset {dataset_name}] Step {num_steps} | Loss: {loss_val:.4f}"
)

# Compute average loss
avg_loss = (total_loss / max(num_steps, 1)).item()
logger.info(
f"[dataset {dataset_name}] Final Step {num_steps} | Avg Loss: {avg_loss:.4f}"
)
# Record metrics only on DP rank 0 to avoid double counting
# record_metric aggregates across all processes via monarch
should_record = True
if dp_mesh is not None:
dp_rank = torch.distributed.get_rank(group=dp_mesh)
should_record = dp_rank == 0
Comment on lines +403 to +406
Contributor Author

The logic here is wrong. We should record for every DP rank, and have checks on the other types. Will do it on Monday.


if should_record:
record_metric(
f"evaluate/dataset_{dataset_name}_loss",
avg_loss,
Reduce.MEAN,
)

# Restore train mode
for model_part in self.model_parts:
model_part.train()

logger.info("==Evaluation complete==")

@endpoint
async def train(self) -> None:
dataloader = iter(self.train_dataloader)
@@ -289,18 +442,28 @@ async def train(self) -> None:
# self.profiler.step()
self.current_step += 1

# Flush metrics
if self._rank == 0:
logger.debug(f"Flushing metrics at step {self.current_step}")
await self.mlogger.flush.call_one(global_step=self.current_step)
# Run evaluation periodically if enabled
if (
self.validation_enabled
and self.current_step % self.eval_every_n_steps == 0
):
await self.evaluate()

self.checkpointer.save(
curr_step=self.current_step,
last_step=self.current_step == self.num_training_steps,
)

# Flush metrics
if self._rank == 0:
await self.mlogger.flush.call_one(global_step=self.current_step)

# self.pbar.close()

if self.validation_enabled:
logger.info("Running final evaluation at end of training...")
await self.evaluate()

@endpoint
async def cleanup(self) -> None:
if self.checkpointer:
11 changes: 10 additions & 1 deletion apps/sft/qwen3_8b.yaml
@@ -31,7 +31,16 @@ training:
max_norm: 1.0
steps: 1000
compile: false
dataset: "c4"
datasets:
- path: "yahma/alpaca-cleaned"
split: "train[:95%]"

eval:
eval_every_n_steps: 50 # null = disabled
max_eval_steps: null # null = run until epoch completes
datasets:
- path: "yahma/alpaca-cleaned"
split: "train[95%:]"

parallelism:
data_parallel_replicate_degree: 1