
Commit bb57589

Allen Wang (allenwang28) authored
Rename RLTrainer to TitanTrainer (#538)
Co-authored-by: Allen Wang <[email protected]>
1 parent 29c6584 commit bb57589

9 files changed: 68 additions & 32 deletions


apps/grpo/main.py

Lines changed: 2 additions & 2 deletions
@@ -23,7 +23,7 @@
 from forge.actors.generator import Generator
 from forge.actors.reference_model import ReferenceModel
 from forge.actors.replay_buffer import ReplayBuffer
-from forge.actors.trainer import RLTrainer
+from forge.actors.trainer import TitanTrainer
 from forge.controller.actor import ForgeActor
 from forge.controller.provisioner import init_provisioner, shutdown
 from forge.data.rewards import MathReward, ThinkingReward
@@ -318,7 +318,7 @@ async def main(cfg: DictConfig):
     ) = await asyncio.gather(
         DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),
         Policy.options(**cfg.services.policy).as_service(**cfg.policy),
-        RLTrainer.options(**cfg.actors.trainer).as_actor(
+        TitanTrainer.options(**cfg.actors.trainer).as_actor(
             **cfg.trainer, loss=simple_grpo_loss
         ),
         ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor(

docs/source/api_trainer.md

Lines changed: 3 additions & 3 deletions
@@ -7,17 +7,17 @@
 The Trainer manages model training in TorchForge, built on top of TorchTitan.
 It handles forward/backward passes, weight updates, and checkpoint management for reinforcement learning workflows.
 
-## RLTrainer
+## TitanTrainer
 
 ```{eval-rst}
-.. autoclass:: RLTrainer
+.. autoclass:: TitanTrainer
    :members: train_step, push_weights, cleanup
    :exclude-members: __init__
 ```
 
 ## Configuration
 
-The RLTrainer uses TorchTitan's configuration system with the following components:
+The TitanTrainer uses TorchTitan's configuration system with the following components:
 
 
 ### Job Configuration
docs/source/tutorial_sources/zero-to-forge/1_RL_and_Forge_Fundamentals.md

Lines changed: 5 additions & 5 deletions
@@ -96,7 +96,7 @@ graph LR
         S3["RewardActor"]
         S4["ReferenceModel"]
         S5["ReplayBuffer"]
-        S6["RLTrainer"]
+        S6["TitanTrainer"]
     end
 
     C1 --> S1
@@ -306,7 +306,7 @@ TorchForge handles behind the scenes:
 from forge.actors.generator import Generator as Policy
 from forge.actors.replay_buffer import ReplayBuffer
 from forge.actors.reference_model import ReferenceModel
-from forge.actors.trainer import RLTrainer
+from forge.actors.trainer import TitanTrainer
 from apps.grpo.main import DatasetActor, RewardActor, ComputeAdvantages
 from forge.data.rewards import MathReward, ThinkingReward
 import asyncio
@@ -348,7 +348,7 @@ group_size = 1
         }
     ),
     # Trainer actor with GPU
-    RLTrainer.options(procs=1, with_gpus=True).as_actor(
+    TitanTrainer.options(procs=1, with_gpus=True).as_actor(
        # Trainer config would come from YAML in real usage
        model={"name": "qwen3", "flavor": "1.7B", "hf_assets_path": f"hf://{model}"},
        optimizer={"name": "AdamW", "lr": 1e-5},
@@ -378,12 +378,12 @@ group_size = 1
 
 TorchForge has two types of distributed components:
 - **Services**: Multiple replicas with automatic load balancing (like Policy, RewardActor)
-- **Actors**: Single instances that handle their own internal distribution (like RLTrainer, ReplayBuffer)
+- **Actors**: Single instances that handle their own internal distribution (like TitanTrainer, ReplayBuffer)
 
 We cover this distinction in detail in Part 2, but for now this explains the scaling patterns:
 - Policy service: num_replicas=8 for high inference demand
 - RewardActor service: num_replicas=16 for parallel evaluation
-- RLTrainer actor: Single instance with internal distributed training
+- TitanTrainer actor: Single instance with internal distributed training
 
 
 ### Fault Tolerance

docs/source/tutorial_sources/zero-to-forge/2_Forge_Internals.md

Lines changed: 5 additions & 5 deletions
@@ -470,7 +470,7 @@ async def simple_rl_step():
     if batch is not None:
         print("Training on batch...")
         inputs, targets = batch  # GRPO returns (inputs, targets) tuple
-        loss = await trainer.train_step.call(inputs, targets)  # RLTrainer is an actor
+        loss = await trainer.train_step.call(inputs, targets)  # TitanTrainer is an actor
         print(f"Training loss: {loss}")
         return loss
     else:
@@ -507,7 +507,7 @@ reward_actor = await RewardActor.options(
 )
 
 # Training needs fewer but more powerful replicas
-trainer = await RLTrainer.options(
+trainer = await TitanTrainer.options(
     procs=1, with_gpus=True  # Fewer but GPU-heavy
 ).as_actor(  # Trainer typically uses .as_actor() not .as_service()
     model={"name": "qwen3", "flavor": "1.7B"},
@@ -580,7 +580,7 @@ import torch
 from forge.actors.generator import Generator as Policy
 from forge.actors.reference_model import ReferenceModel
 from forge.actors.replay_buffer import ReplayBuffer
-from forge.actors.trainer import RLTrainer
+from forge.actors.trainer import TitanTrainer
 from apps.grpo.main import DatasetActor, RewardActor, ComputeAdvantages
 from forge.data.rewards import MathReward, ThinkingReward
 
@@ -603,7 +603,7 @@ print("Initializing all services...")
         engine_config={"model": "Qwen/Qwen3-1.7B", "tensor_parallel_size": 1},
         sampling_config={"n": 1, "max_tokens": 512}
     ),
-    RLTrainer.options(procs=1, with_gpus=True).as_actor(
+    TitanTrainer.options(procs=1, with_gpus=True).as_actor(
         model={"name": "qwen3", "flavor": "1.7B", "hf_assets_path": "hf://Qwen/Qwen3-1.7B"},
         optimizer={"name": "AdamW", "lr": 1e-5},
         training={"local_batch_size": 2, "seq_len": 2048}
@@ -667,7 +667,7 @@ print("Shutting down services...")
 await asyncio.gather(
     DatasetActor.shutdown(dataloader),
     policy.shutdown(),
-    RLTrainer.shutdown(trainer),
+    TitanTrainer.shutdown(trainer),
     ReplayBuffer.shutdown(replay_buffer),
     ComputeAdvantages.shutdown(compute_advantages),
     ReferenceModel.shutdown(ref_model),

src/forge/actors/__init__.py

Lines changed: 14 additions & 1 deletion
@@ -4,9 +4,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import warnings
+
 __all__ = [
     "Generator",
-    "RLTrainer",
+    "TitanTrainer",
+    "RLTrainer",  # Deprecated, use TitanTrainer
     "ReplayBuffer",
     "ReferenceModel",
     "SandboxedPythonCoder",
@@ -18,7 +21,17 @@ def __getattr__(name):
         from .generator import Generator
 
         return Generator
+    elif name == "TitanTrainer":
+        from .trainer import TitanTrainer
+
+        return TitanTrainer
     elif name == "RLTrainer":
+        warnings.warn(
+            "RLTrainer is deprecated and will be removed in a future version. "
+            "Please use TitanTrainer instead.",
+            FutureWarning,
+            stacklevel=2,
+        )
        from .trainer import RLTrainer

        return RLTrainer
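
The backward-compatible path works through the module-level `__getattr__` hook shown above: the old attribute name still resolves, but a FutureWarning is emitted at lookup time. A minimal sketch of the expected caller-side behavior (assumes the package is importable; the assertion is illustrative only, not part of this commit):

```python
import warnings

# Resolving the deprecated name goes through forge.actors.__getattr__,
# which emits a FutureWarning before handing back the class.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    from forge.actors import RLTrainer  # old name, kept for compatibility

assert any(issubclass(w.category, FutureWarning) for w in caught)

# The new name resolves without a deprecation warning.
from forge.actors import TitanTrainer
```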
src/forge/actors/trainer/__init__.py (new file)

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import warnings
+
+from .titan import TitanTrainer
+
+__all__ = ["TitanTrainer", "RLTrainer"]
+
+
+def __getattr__(name):
+    if name == "RLTrainer":
+        warnings.warn(
+            "RLTrainer is deprecated and will be removed in a future version. "
+            "Please use TitanTrainer instead.",
+            FutureWarning,
+            stacklevel=2,
+        )
+        return TitanTrainer
+    raise AttributeError(f"module {__name__} has no attribute {name}")
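
Usage note (a sketch, assuming nothing else shadows these names): both import paths resolve to the same class object, so existing call sites keep working while new code switches to TitanTrainer.

```python
from forge.actors.trainer import TitanTrainer  # preferred name going forward
from forge.actors.trainer import RLTrainer     # emits FutureWarning via __getattr__

# The deprecated name is an alias for the same class, not a separate subclass.
assert RLTrainer is TitanTrainer
```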

src/forge/actors/trainer.py renamed to src/forge/actors/trainer/titan.py

Lines changed: 2 additions & 2 deletions
@@ -53,8 +53,8 @@
 
 
 @dataclass
-class RLTrainer(ForgeActor):
-    """A reinforcement learning trainer actor for policy optimization training.
+class TitanTrainer(ForgeActor):
+    """A generic trainer actor implementation built on top of TorchTitan.
 
     Built on top of TorchTitan's training engine, this actor provides a complete training
     loop for reinforcement learning. It performs forward and backward passes with gradient
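
For orientation, a hypothetical end-to-end sketch of the renamed actor, stitched together from the call sites touched in this commit (`cfg`, `loss_fn`, `inputs`, and `targets` are placeholders, not defined here):

```python
from forge.actors.trainer import TitanTrainer

async def run_one_step(cfg, loss_fn, inputs, targets):
    # Spawn the trainer as a single actor; it manages its own internal distribution.
    trainer = await TitanTrainer.options(**cfg.actors.trainer).as_actor(
        **cfg.trainer, loss=loss_fn
    )
    # One forward/backward pass plus optimizer step.
    loss = await trainer.train_step.call(inputs, targets)
    # Publish the updated weights under a new policy version.
    await trainer.push_weights.call(policy_version=1)
    await TitanTrainer.shutdown(trainer)
    return loss
```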

tests/integration_tests/test_policy_update.py

Lines changed: 12 additions & 12 deletions
@@ -16,7 +16,7 @@
 import torchstore as ts
 from forge.actors.generator import Generator
 
-from forge.actors.trainer import RLTrainer
+from forge.actors.trainer import TitanTrainer
 from forge.controller.provisioner import init_provisioner
 
 from forge.controller.service.service import uuid
@@ -50,7 +50,7 @@
 TEST_DCP_DIR = "test_dcp_tmp"
 
 
-class MockRLTrainer(RLTrainer):
+class MockTitanTrainer(TitanTrainer):
     @endpoint
     async def zero_out_model_states(self):
         """This simply sets all model weights to zero."""
@@ -59,7 +59,7 @@ async def zero_out_model_states(self):
         for k in sd.keys():
             if not torch.is_floating_point(sd[k]):
                 logger.info(
-                    f"[MockRLTrainer] zero_out_model_states(): skipping non-float param {k}"
+                    f"[MockTitanTrainer] zero_out_model_states(): skipping non-float param {k}"
                 )
                 continue
             sd[k] *= 0.0
@@ -199,22 +199,22 @@ async def _setup_and_teardown(request):
     )
     await ts.initialize(strategy=ts.ControllerStorageVolumes())
 
-    policy, rl_trainer = await asyncio.gather(
+    policy, titan_trainer = await asyncio.gather(
         *[
             Generator.options(**services_policy_cfg).as_service(**cfg.policy),
-            MockRLTrainer.options(**cfg.actors.trainer).as_actor(**trainer_cfg),
+            MockTitanTrainer.options(**cfg.actors.trainer).as_actor(**trainer_cfg),
         ]
     )
 
-    yield policy, rl_trainer
+    yield policy, titan_trainer
 
     # ---- teardown ---- #
     logger.info("Shutting down services and cleaning up DCP directory..")
 
     await asyncio.gather(
         policy.shutdown(),
         ts.shutdown(),
-        RLTrainer.shutdown(rl_trainer),
+        TitanTrainer.shutdown(titan_trainer),
     )
 
     # Cleanup DCP directory
@@ -235,7 +235,7 @@ class TestWeightSync:
     @requires_cuda
     async def test_sanity_check(self, _setup_and_teardown):
         """
-        Sanity check for weight sync sharding between RLTrainer and Policy for a given model config.
+        Sanity check for weight sync sharding between TitanTrainer and Policy for a given model config.
 
         The check performs the following steps:
         - Initialize trainer and push weights v0 (original huggingface ckpt)
@@ -245,15 +245,15 @@ async def test_sanity_check(self, _setup_and_teardown):
 
         """
 
-        policy, rl_trainer = _setup_and_teardown
+        policy, titan_trainer = _setup_and_teardown
 
         v0 = uuid.uuid4().int
         v1 = v0 + 1
 
-        await rl_trainer.push_weights.call(policy_version=v0)
+        await titan_trainer.push_weights.call(policy_version=v0)
         # Setting everything to zero
-        await rl_trainer.zero_out_model_states.call()
-        await rl_trainer.push_weights.call(policy_version=v1)
+        await titan_trainer.zero_out_model_states.call()
+        await titan_trainer.push_weights.call(policy_version=v1)
         await policy.save_model_params.fanout()
 
         # Sanity check that before update all the tests pass

tests/sandbox/rl_trainer/main.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
 
 import torch
 import torchstore as ts
-from forge.actors.trainer import RLTrainer
+from forge.actors.trainer import TitanTrainer
 from forge.controller.launcher import JOB_NAME_KEY, LAUNCHER_KEY
 from forge.controller.provisioner import init_provisioner, shutdown
 from forge.observability.metric_actors import get_or_create_metric_logger
@@ -182,7 +182,7 @@ async def main(cfg: DictConfig):
     await ts.initialize(strategy=ts.ControllerStorageVolumes())
     # Initialize trainer only
     print("Initializing trainer...")
-    trainer = await RLTrainer.options(**cfg.actors.trainer).as_actor(
+    trainer = await TitanTrainer.options(**cfg.actors.trainer).as_actor(
         **cfg.trainer, loss=simple_grpo_loss
     )
     print("Trainer initialized successfully with following configs!")
