12 changes: 11 additions & 1 deletion apps/sft/llama3_8b.yaml
@@ -32,7 +32,16 @@ training:
max_norm: 1.0
steps: 1000
compile: false
dataset: "c4"
datasets:
- path: "yahma/alpaca-cleaned"
split: "train[:95%]"

eval:
eval_every_n_steps: 50 # null = disabled
max_eval_steps: null # null = run until epoch completes
datasets:
- path: "yahma/alpaca-cleaned"
split: "train[95%:]"

parallelism:
data_parallel_replicate_degree: 1
@@ -62,6 +71,7 @@ metric_logging:
group: sft_exp_${oc.env:USER}
logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce


# profiling:
# enable_profiling: false

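The `train[:95%]` / `train[95%:]` split strings above carve a single Hugging Face dataset into disjoint train and eval partitions. A minimal sketch of how those strings resolve, assuming the Hugging Face `datasets` package (not part of this PR):

# Example only, not part of the diff. Assumes the `datasets` package is installed.
from datasets import load_dataset

train_ds = load_dataset("yahma/alpaca-cleaned", split="train[:95%]")  # first 95% of rows
eval_ds = load_dataset("yahma/alpaca-cleaned", split="train[95%:]")   # remaining 5%
print(len(train_ds), len(eval_ds))  # two non-overlapping slices of the same "train" split
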
261 changes: 234 additions & 27 deletions apps/sft/main.py
@@ -11,7 +11,7 @@
"""

import asyncio

import contextlib
import logging
import math
import os
@@ -27,6 +27,7 @@
from forge.data.datasets.packed import PackedDataset, TextPacker
from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
from forge.data.tokenizer import HuggingFaceModelTokenizer
from forge.data.utils import StopAfterOneEpoch
from forge.observability import get_or_create_metric_logger, record_metric, Reduce
from forge.util.config import parse

@@ -81,9 +82,19 @@ def __init__(self, config: DictConfig):
self.gradient_accumulation_steps = 1 # Example value, adjust as needed
self._rank = current_rank().rank
self._size = math.prod(current_size().values())

self._init_dist()
super().__init__(job_config)

# For Pipeline Parallelism (PP): Only the last PP stage computes the actual loss
# For non-PP setups: All ranks compute loss
self.rank_should_record_loss = True
if hasattr(self, "pp_has_last_stage") and not self.pp_has_last_stage:
self.rank_should_record_loss = False

# Logging frequency
self.log_every_n_steps = self.job_config.get("log_every_n_steps", 10)
[Member comment] Can we make this mandatory, no getter?

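A sketch of what the reviewer's suggestion could look like; the key name comes from the diff, but treating it as required (no default) is an assumption:

# Sketch of the "mandatory, no getter" suggestion (assumption, not part of the diff):
# plain indexing fails loudly if log_every_n_steps is missing from the config,
# instead of silently falling back to 10.
self.log_every_n_steps = self.job_config["log_every_n_steps"]
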
def _init_dist(self):
"""Initializes torch distributed.

@@ -122,28 +133,67 @@ def record_batch_metrics(self, data_metrics: list):

@endpoint
async def setup(self):
self.train_dataloader = self.setup_data()

# metric logger
self.mlogger = await self.setup_metric_logger()

# self.train_dataloader = self.setup_data(
# self.train_config.train_dataset_config,
# self.train_config.train_dataloader_config,
# self.train_config.packing_config,
# )
# self.val_dataloader = self.setup_data(
# self.train_config.val_dataset_config,
# self.train_config.val_dataloader_config,
# self.train_config.packing_config,
# )
# Load training datasets
logger.info("Setting training datasets")
train_datasets_config = self.job_config.training.datasets
self.train_dataloader = self.setup_data(train_datasets_config)

# Load eval datasets
eval_config = self.job_config.get("eval", {})
self.val_dataloaders = {}
self.eval_every_n_steps = eval_config.get("eval_every_n_steps", None)
[Member comment] No getter
max_eval_steps = eval_config.get("max_eval_steps", None)
[Member comment] No getter
self.max_eval_steps = (
max_eval_steps if max_eval_steps and max_eval_steps > 0 else None
)
self.validation_enabled = (
self.eval_every_n_steps is not None and self.eval_every_n_steps > 0
)
if self.validation_enabled:
logger.info("Setting eval datasets")
self.eval_datasets_config = eval_config.datasets

for i, dataset_config in enumerate(self.eval_datasets_config):
ds_name = dataset_config.get("dataset_name", i)

# TODO: Support separate eval batch size from config (eval.local_batch_size)
dataloader = self.setup_data([dataset_config])
self.val_dataloaders[ds_name] = dataloader

# TODO: confirm that this is working properly
# Should also use load, not dcp_load
self.checkpointer.load(step=self.current_step)

# self.profiler = self.setup_profiler(self.train_config.profiler_config)
# self.logger = self.setup_logger(self.train_config.logger_config)

def setup_data(self):
print(os.path.join(self.job_config.model.hf_assets_path, "tokenizer.json"))
def setup_data(self, dataset_configs: list[dict]) -> StatefulDataLoader:
"""Instantiates datasets and returns a StatefulDataLoader.

Args:
dataset_configs (list[dict]): List of dataset config dicts used as `sft_iterable_dataset(**dataset_configs[i])`.

Returns:
StatefulDataLoader

Raises:
ValueError: If multiple datasets provided (not yet supported)
"""
# TODO felipemello: Currently only support single dataset
if len(dataset_configs) > 1:
raise ValueError(
f"Multiple training datasets not supported yet. "
f"Got {len(dataset_configs)} datasets. "
)

dataset_config = dataset_configs[0]

# TODO: Evaluate if tokenizers should be created once and shared for every dataset
# Load tokenizer
tokenizer = HuggingFaceModelTokenizer(
tokenizer_json_path=os.path.join(
self.job_config.model.hf_assets_path, "tokenizer.json"
@@ -165,18 +215,26 @@ def setup_data(self):
),
)

# Get DP mesh for data sharding
dp_mesh = None
if self.parallel_dims is not None and self.parallel_dims.dp_enabled:
dp_mesh = self.parallel_dims.world_mesh.get_group("dp")

# Pass config directly to dataset constructor
dataset = sft_iterable_dataset(
model_transform=tokenizer,
message_transform=AlpacaToMessages(),
path="yahma/alpaca-cleaned",
split="train",
dp_mesh=dp_mesh,
**dataset_config,
)

packer = TextPacker(padding_idx=0)
dataset = PackedDataset(
dataset=dataset,
packer=packer,
target_tokens_per_pack=self.job_config.training.seq_len, # TODO: get this from model
)

dataloader = StatefulDataLoader(
dataset=dataset,
batch_size=self.job_config.training.local_batch_size,
@@ -192,8 +250,12 @@ def setup_data(self):
return dataloader

def forward_backward(
self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
self,
input_dict: dict[str, torch.Tensor],
labels: torch.Tensor,
skip_backward: bool = False,
) -> torch.Tensor:
"""Forward pass with optional backward."""
[Member comment] nit: No need for this comment
model_parts = self.model_parts
parallel_dims = self.parallel_dims

@@ -230,10 +292,15 @@ def forward_backward(
# accumulate losses across pipeline microbatches
# TODO: PP+FSDP unexpectedly puts the loss back to the CPU
loss = (
torch.mean(torch.stack(losses)).to(self.device)
torch.sum(torch.stack(losses)).to(self.device)
[Member comment] Why did you change this?
if self.pp_has_last_stage
else torch.tensor([-1.0], device=self.device)
else torch.tensor(-1.0, device=self.device)
)

# TODO: PP requires gradients enabled and can't be deactivated with no_grad
if skip_backward:
loss = loss.detach()

else:
# Non-PP forward / backward
with self.train_context(optional_context_parallel_ctx):
@@ -243,7 +310,10 @@ def forward_backward(
loss = self.loss_fn(pred, labels)
# need to free before bwd to avoid peaking memory
del pred
loss.backward()

# Only run backward if requested. Useful for eval.
if not skip_backward:
loss.backward()

return loss

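The `skip_backward` flag lets one forward path serve both training and evaluation. A usage sketch mirroring the calls made in `train_step` and `evaluate` below:

# Usage sketch (names as in this diff):
labels = batch.pop("labels")
train_loss = self.forward_backward(batch, labels)                     # forward + backward
eval_loss = self.forward_backward(batch, labels, skip_backward=True)  # forward only; loss is detached
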
@@ -256,15 +326,142 @@ def train_step(self, batch) -> None:
# ) as grad_acc:
labels = batch.pop("labels")
loss = self.forward_backward(batch, labels)
loss = loss.item()

record_metric("ForgeSFTRecipe/train_step/loss", loss, Reduce.MEAN)
logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}")
if self.rank_should_record_loss:
loss_val = loss.item()
record_metric("ForgeSFTRecipe/train_step/loss", loss_val, Reduce.MEAN)
if self.current_step % self.log_every_n_steps == 0:
logger.info(
[Member comment] I thought this was handled by the MetricLogger?
f"step {self.current_step} / {self.num_training_steps} | Loss: {loss_val}"
)

# self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
# self.pbar.update(1)
self.optimizers.step()
self.lr_schedulers.step()

async def evaluate(self) -> None:
"""Run evaluation on multiple datasets, one at a time.

1. Set models to eval mode
2. For each eval dataset:
- Create fresh iterator (starts from epoch 0)
- Use StopAfterOneEpoch to iterate until the epoch boundary. This utility
is necessary for infinite iterable datasets, since epoch boundaries are not known in advance.
- Respect max_eval_steps cap if configured
- Record loss and step metrics (on dp rank only)
3. Restore models to train mode
"""

# Set models to eval mode
for model_part in self.model_parts:
model_part.eval()

# Get DP process group for epoch synchronization
dp_mesh = None
if self.parallel_dims is not None and self.parallel_dims.dp_enabled:
dp_mesh = self.parallel_dims.world_mesh.get_group("dp")

# For non-PP: disable gradients to save memory
# TODO: For PP, disabling gradients throws an error
maybe_no_grad = (
contextlib.nullcontext()
if self.parallel_dims.pp_enabled
else torch.no_grad()
)

# Evaluate each dataset sequentially
all_dataset_losses = []
all_dataset_steps = []
for dataset_name, val_dataloader in self.val_dataloaders.items():
logger.info(f"=====Evaluating dataset: {dataset_name}=====")

# Evaluation loop for this dataset
total_loss = torch.tensor(0.0, device=self.device)
num_steps = 0

# NOTE: Assumes batch contains field "metrics"
batch_iter = StopAfterOneEpoch(
iter=iter(val_dataloader),  # Fresh iterator from epoch 0
device=self.device,
dp_mesh=dp_mesh,
)

with maybe_no_grad:
for batch in batch_iter:
# If max_eval_steps > len(dataset), iteration is stopped earlier by StopAfterOneEpoch.
if (
self.max_eval_steps is not None
[Member comment] What if max_eval_steps > num steps per epoch?
[Author reply, @felipemello1, Nov 10, 2025] batch_iter will stop first; I can add a comment to clarify.
and num_steps >= self.max_eval_steps
):
logger.info(
f"[{dataset_name}] Reached max_eval_steps cap of {self.max_eval_steps}"
)
break

# Move tensors to device
for key, value in batch.items():
[Member comment] We have a helper function to do this.
if isinstance(value, torch.Tensor):
batch[key] = value.to(self.device)

# Process batch
labels = batch.pop("labels")
loss = self.forward_backward(batch, labels, skip_backward=True)
total_loss += loss
num_steps += 1

# Log progress
if (
self.rank_should_record_loss
and num_steps % self.log_every_n_steps == 0
):
loss_val = loss.item()
logger.info(
f"[dataset {dataset_name}] Step {num_steps} | Loss: {loss_val:.4f}"
)

# log loss
avg_loss = (total_loss / max(num_steps, 1)).item()
all_dataset_losses.append(avg_loss)
all_dataset_steps.append(num_steps)
logger.info(
f"[dataset {dataset_name}] Final Step {num_steps} | Avg Loss: {avg_loss:.4f}"
)
if self.rank_should_record_loss:
record_metric(
f"evaluate/dataset_{dataset_name}_avg_loss",
avg_loss,
Reduce.MEAN,
)

# Record macro and micro average losses across datasets (only if multiple datasets)
if self.rank_should_record_loss and len(all_dataset_losses) > 1:
# Macro: same weight for all datasets
macro_avg_loss = sum(all_dataset_losses) / len(all_dataset_losses)
record_metric("evaluate/macro_avg_loss", macro_avg_loss, Reduce.MEAN)

# Micro: weighted mean by dataset size
total_steps = sum(all_dataset_steps)
micro_avg_loss = (
sum(
loss * steps
for loss, steps in zip(all_dataset_losses, all_dataset_steps)
)
/ total_steps
)
record_metric("evaluate/micro_avg_loss", micro_avg_loss, Reduce.MEAN)

logger.info(
f"Macro avg loss (unweighted): {macro_avg_loss:.4f}, "
f"Micro avg loss (weighted): {micro_avg_loss:.4f}"
)

# Restore train mode
for model_part in self.model_parts:
model_part.train()

logger.info("==Evaluation complete==")

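A quick arithmetic check of the macro vs micro averages computed above, with purely illustrative numbers:

# Illustrative numbers only (not from a real run):
losses = [1.0, 2.0]  # per-dataset average losses
steps = [10, 30]     # eval steps per dataset
macro = sum(losses) / len(losses)                               # (1.0 + 2.0) / 2 = 1.5
micro = sum(l * s for l, s in zip(losses, steps)) / sum(steps)  # (10 + 60) / 40  = 1.75
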
@endpoint
async def train(self) -> None:
dataloader = iter(self.train_dataloader)
@@ -289,18 +486,28 @@ async def train(self) -> None:
# self.profiler.step()
self.current_step += 1

# Flush metrics
if self._rank == 0:
logger.debug(f"Flushing metrics at step {self.current_step}")
await self.mlogger.flush.call_one(global_step=self.current_step)
# Run evaluation periodically if enabled
if (
self.validation_enabled
and self.current_step % self.eval_every_n_steps == 0
):
await self.evaluate()

self.checkpointer.save(
curr_step=self.current_step,
last_step=self.current_step == self.num_training_steps,
)

# Flush metrics
if self._rank == 0:
await self.mlogger.flush.call_one(global_step=self.current_step)

# self.pbar.close()

if self.validation_enabled:
logger.info("Running final evaluation at end of training...")
await self.evaluate()

@endpoint
async def cleanup(self) -> None:
if self.checkpointer:
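`forge.data.utils.StopAfterOneEpoch` is imported and used in `evaluate` but its implementation is not part of this diff. Based only on how it is called (wraps an infinite iterator, takes `device` and `dp_mesh`, and yields batches until the epoch boundary, which is why it stops before `max_eval_steps` when the epoch is shorter), a heavily simplified sketch of that behavior might look like the following. The per-batch epoch counter and the synchronization details are assumptions, not the real implementation:

# Assumption-laden sketch; the real StopAfterOneEpoch lives in forge.data.utils
# and may work differently. A hypothetical per-batch "num_epochs" counter signals
# the epoch boundary, and an all-reduce keeps dp ranks stopping together.
import torch
import torch.distributed as dist


class StopAfterOneEpochSketch:
    def __init__(self, iter, device, dp_mesh=None):
        self._iter = iter
        self._device = device
        self._dp_mesh = dp_mesh

    def __iter__(self):
        start_epoch = None
        for batch in self._iter:
            epoch = batch.get("num_epochs", 0)  # hypothetical epoch counter in the batch
            if start_epoch is None:
                start_epoch = epoch
            crossed = torch.tensor(float(epoch > start_epoch), device=self._device)
            if self._dp_mesh is not None:
                # Stop as soon as any dp rank crosses the boundary so ranks stay in step.
                dist.all_reduce(crossed, op=dist.ReduceOp.MAX, group=self._dp_mesh)
            if crossed.item() > 0:
                break
            yield batch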