4 changes: 2 additions & 2 deletions apps/grpo/main.py
@@ -23,7 +23,7 @@
from forge.actors.generator import Generator
from forge.actors.reference_model import ReferenceModel
from forge.actors.replay_buffer import ReplayBuffer
from forge.actors.trainer import RLTrainer
from forge.actors.trainer import TitanTrainer
from forge.controller.actor import ForgeActor
from forge.controller.provisioner import init_provisioner, shutdown
from forge.data.rewards import MathReward, ThinkingReward
@@ -318,7 +318,7 @@ async def main(cfg: DictConfig):
) = await asyncio.gather(
DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),
Policy.options(**cfg.services.policy).as_service(**cfg.policy),
RLTrainer.options(**cfg.actors.trainer).as_actor(
TitanTrainer.options(**cfg.actors.trainer).as_actor(
**cfg.trainer, loss=simple_grpo_loss
),
ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor(
6 changes: 3 additions & 3 deletions docs/source/api_trainer.md
@@ -7,17 +7,17 @@
The Trainer manages model training in TorchForge, built on top of TorchTitan.
It handles forward/backward passes, weight updates, and checkpoint management for reinforcement learning workflows.

## RLTrainer
## TitanTrainer

```{eval-rst}
.. autoclass:: RLTrainer
.. autoclass:: TitanTrainer
:members: train_step, push_weights, cleanup
:exclude-members: __init__
```
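
A minimal sketch of driving these endpoints, assuming the actor options and TorchTitan-style config dicts shown in the tutorial snippets later in this PR (the values are illustrative, not canonical):

```python
from forge.actors.trainer import TitanTrainer


async def train_once(inputs, targets):
    # Spawn the trainer as a single actor; it manages its own internal distribution.
    # The config dicts follow the TorchTitan-style layout described below.
    trainer = await TitanTrainer.options(procs=1, with_gpus=True).as_actor(
        model={"name": "qwen3", "flavor": "1.7B", "hf_assets_path": "hf://Qwen/Qwen3-1.7B"},
        optimizer={"name": "AdamW", "lr": 1e-5},
        training={"local_batch_size": 2, "seq_len": 2048},
    )
    # One optimization step, then publish the updated weights for inference.
    loss = await trainer.train_step.call(inputs, targets)
    await trainer.push_weights.call(policy_version=1)
    await TitanTrainer.shutdown(trainer)
    return loss
```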

## Configuration

The RLTrainer uses TorchTitan's configuration system with the following components:
The TitanTrainer uses TorchTitan's configuration system with the following components:

### Job Configuration

@@ -96,7 +96,7 @@ graph LR
S3["RewardActor"]
S4["ReferenceModel"]
S5["ReplayBuffer"]
S6["RLTrainer"]
S6["TitanTrainer"]
end

C1 --> S1
@@ -306,7 +306,7 @@ TorchForge handles behind the scenes:
from forge.actors.generator import Generator as Policy
from forge.actors.replay_buffer import ReplayBuffer
from forge.actors.reference_model import ReferenceModel
from forge.actors.trainer import RLTrainer
from forge.actors.trainer import TitanTrainer
from apps.grpo.main import DatasetActor, RewardActor, ComputeAdvantages
from forge.data.rewards import MathReward, ThinkingReward
import asyncio
@@ -348,7 +348,7 @@ group_size = 1
}
),
# Trainer actor with GPU
RLTrainer.options(procs=1, with_gpus=True).as_actor(
TitanTrainer.options(procs=1, with_gpus=True).as_actor(
# Trainer config would come from YAML in real usage
model={"name": "qwen3", "flavor": "1.7B", "hf_assets_path": f"hf://{model}"},
optimizer={"name": "AdamW", "lr": 1e-5},
@@ -378,12 +378,12 @@

TorchForge has two types of distributed components:
- **Services**: Multiple replicas with automatic load balancing (like Policy, RewardActor)
- **Actors**: Single instances that handle their own internal distribution (like RLTrainer, ReplayBuffer)
- **Actors**: Single instances that handle their own internal distribution (like TitanTrainer, ReplayBuffer)

We cover this distinction in detail in Part 2, but for now this explains the scaling patterns:
- Policy service: num_replicas=8 for high inference demand
- RewardActor service: num_replicas=16 for parallel evaluation
- RLTrainer actor: Single instance with internal distributed training
- TitanTrainer actor: Single instance with internal distributed training

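A compact sketch of what those two creation patterns look like in code; the replica counts and config values come from the figures quoted above, and the keyword names follow the tutorial snippets in this series rather than a definitive API reference:

```python
from forge.actors.generator import Generator as Policy
from forge.actors.trainer import TitanTrainer


async def spawn_components():
    # Service: replicated and load-balanced behind a single handle.
    policy = await Policy.options(procs=1, num_replicas=8, with_gpus=True).as_service(
        engine_config={"model": "Qwen/Qwen3-1.7B", "tensor_parallel_size": 1},
        sampling_config={"n": 1, "max_tokens": 512},
    )
    # Actor: one instance that owns its own (possibly distributed) training loop.
    trainer = await TitanTrainer.options(procs=1, with_gpus=True).as_actor(
        model={"name": "qwen3", "flavor": "1.7B"},
    )
    return policy, trainer
```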

### Fault Tolerance
10 changes: 5 additions & 5 deletions docs/source/tutorial_sources/zero-to-forge/2_Forge_Internals.md
@@ -470,7 +470,7 @@ async def simple_rl_step():
if batch is not None:
print("Training on batch...")
inputs, targets = batch # GRPO returns (inputs, targets) tuple
loss = await trainer.train_step.call(inputs, targets) # RLTrainer is an actor
loss = await trainer.train_step.call(inputs, targets) # TitanTrainer is an actor
print(f"Training loss: {loss}")
return loss
else:
@@ -507,7 +507,7 @@ reward_actor = await RewardActor.options(
)

# Training needs fewer but more powerful replicas
trainer = await RLTrainer.options(
trainer = await TitanTrainer.options(
procs=1, with_gpus=True # Fewer but GPU-heavy
).as_actor( # Trainer typically uses .as_actor() not .as_service()
model={"name": "qwen3", "flavor": "1.7B"},
@@ -580,7 +580,7 @@ import torch
from forge.actors.generator import Generator as Policy
from forge.actors.reference_model import ReferenceModel
from forge.actors.replay_buffer import ReplayBuffer
from forge.actors.trainer import RLTrainer
from forge.actors.trainer import TitanTrainer
from apps.grpo.main import DatasetActor, RewardActor, ComputeAdvantages
from forge.data.rewards import MathReward, ThinkingReward

@@ -603,7 +603,7 @@ print("Initializing all services...")
engine_config={"model": "Qwen/Qwen3-1.7B", "tensor_parallel_size": 1},
sampling_config={"n": 1, "max_tokens": 512}
),
RLTrainer.options(procs=1, with_gpus=True).as_actor(
TitanTrainer.options(procs=1, with_gpus=True).as_actor(
model={"name": "qwen3", "flavor": "1.7B", "hf_assets_path": "hf://Qwen/Qwen3-1.7B"},
optimizer={"name": "AdamW", "lr": 1e-5},
training={"local_batch_size": 2, "seq_len": 2048}
@@ -667,7 +667,7 @@ print("Shutting down services...")
await asyncio.gather(
DatasetActor.shutdown(dataloader),
policy.shutdown(),
RLTrainer.shutdown(trainer),
TitanTrainer.shutdown(trainer),
ReplayBuffer.shutdown(replay_buffer),
ComputeAdvantages.shutdown(compute_advantages),
ReferenceModel.shutdown(ref_model),
15 changes: 14 additions & 1 deletion src/forge/actors/__init__.py
@@ -4,9 +4,12 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import warnings

__all__ = [
"Generator",
"RLTrainer",
"TitanTrainer",
"RLTrainer", # Deprecated, use TitanTrainer
"ReplayBuffer",
"ReferenceModel",
"SandboxedPythonCoder",
@@ -18,7 +21,17 @@ def __getattr__(name):
from .generator import Generator

return Generator
elif name == "TitanTrainer":
from .trainer import TitanTrainer

return TitanTrainer
elif name == "RLTrainer":
warnings.warn(
"RLTrainer is deprecated and will be removed in a future version. "
"Please use TitanTrainer instead.",
FutureWarning,
stacklevel=2,
)
from .trainer import RLTrainer

return RLTrainer
23 changes: 23 additions & 0 deletions src/forge/actors/trainer/__init__.py
@@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import warnings

from .titan import TitanTrainer

__all__ = ["TitanTrainer", "RLTrainer"]


def __getattr__(name):
if name == "RLTrainer":
warnings.warn(
"RLTrainer is deprecated and will be removed in a future version. "
"Please use TitanTrainer instead.",
FutureWarning,
stacklevel=2,
)
return TitanTrainer
raise AttributeError(f"module {__name__} has no attribute {name}")
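
A small usage sketch of the shim above: importing the deprecated name still works, but it goes through the module-level `__getattr__`, emits a `FutureWarning`, and returns `TitanTrainer`.

```python
import warnings

from forge.actors.trainer import TitanTrainer

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Falls through to the module-level __getattr__ defined above.
    from forge.actors.trainer import RLTrainer

assert RLTrainer is TitanTrainer
assert any(issubclass(w.category, FutureWarning) for w in caught)
```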
@@ -53,8 +53,8 @@


@dataclass
class RLTrainer(ForgeActor):
"""A reinforcement learning trainer actor for policy optimization training.
class TitanTrainer(ForgeActor):
"""A generic trainer actor implementation built on top of TorchTitan.

Built on top of TorchTitan's training engine, this actor provides a complete training
loop for reinforcement learning. It performs forward and backward passes with gradient
24 changes: 12 additions & 12 deletions tests/integration_tests/test_policy_update.py
@@ -16,7 +16,7 @@
import torchstore as ts
from forge.actors.generator import Generator

from forge.actors.trainer import RLTrainer
from forge.actors.trainer import TitanTrainer
from forge.controller.provisioner import init_provisioner

from forge.controller.service.service import uuid
@@ -50,7 +50,7 @@
TEST_DCP_DIR = "test_dcp_tmp"


class MockRLTrainer(RLTrainer):
class MockTitanTrainer(TitanTrainer):
@endpoint
async def zero_out_model_states(self):
"""This simply sets all model weights to zero."""
@@ -59,7 +59,7 @@ async def zero_out_model_states(self):
for k in sd.keys():
if not torch.is_floating_point(sd[k]):
logger.info(
f"[MockRLTrainer] zero_out_model_states(): skipping non-float param {k}"
f"[MockTitanTrainer] zero_out_model_states(): skipping non-float param {k}"
)
continue
sd[k] *= 0.0
@@ -199,22 +199,22 @@ async def _setup_and_teardown(request):
)
await ts.initialize(strategy=ts.ControllerStorageVolumes())

policy, rl_trainer = await asyncio.gather(
policy, titan_trainer = await asyncio.gather(
*[
Generator.options(**services_policy_cfg).as_service(**cfg.policy),
MockRLTrainer.options(**cfg.actors.trainer).as_actor(**trainer_cfg),
MockTitanTrainer.options(**cfg.actors.trainer).as_actor(**trainer_cfg),
]
)

yield policy, rl_trainer
yield policy, titan_trainer

# ---- teardown ---- #
logger.info("Shutting down services and cleaning up DCP directory..")

await asyncio.gather(
policy.shutdown(),
ts.shutdown(),
RLTrainer.shutdown(rl_trainer),
TitanTrainer.shutdown(titan_trainer),
)

# Cleanup DCP directory
@@ -235,7 +235,7 @@ class TestWeightSync:
@requires_cuda
async def test_sanity_check(self, _setup_and_teardown):
"""
Sanity check for weight sync sharding between RLTrainer and Policy for a given model config.
Sanity check for weight sync sharding between TitanTrainer and Policy for a given model config.

The check performs the following steps:
- Initialize trainer and push weights v0 (original huggingface ckpt)
@@ -245,15 +245,15 @@ async def test_sanity_check(self, _setup_and_teardown):

"""

policy, rl_trainer = _setup_and_teardown
policy, titan_trainer = _setup_and_teardown

v0 = uuid.uuid4().int
v1 = v0 + 1

await rl_trainer.push_weights.call(policy_version=v0)
await titan_trainer.push_weights.call(policy_version=v0)
# Setting everything to zero
await rl_trainer.zero_out_model_states.call()
await rl_trainer.push_weights.call(policy_version=v1)
await titan_trainer.zero_out_model_states.call()
await titan_trainer.push_weights.call(policy_version=v1)
await policy.save_model_params.fanout()

# Sanity check that before update all the tests pass
4 changes: 2 additions & 2 deletions tests/sandbox/rl_trainer/main.py
@@ -10,7 +10,7 @@

import torch
import torchstore as ts
from forge.actors.trainer import RLTrainer
from forge.actors.trainer import TitanTrainer
from forge.controller.launcher import JOB_NAME_KEY, LAUNCHER_KEY
from forge.controller.provisioner import init_provisioner, shutdown
from forge.observability.metric_actors import get_or_create_metric_logger
@@ -182,7 +182,7 @@ async def main(cfg: DictConfig):
await ts.initialize(strategy=ts.ControllerStorageVolumes())
# Initialize trainer only
print("Initializing trainer...")
trainer = await RLTrainer.options(**cfg.actors.trainer).as_actor(
trainer = await TitanTrainer.options(**cfg.actors.trainer).as_actor(
**cfg.trainer, loss=simple_grpo_loss
)
print("Trainer initialized successfully with following configs!")