1616import torchstore as ts
1717from forge .actors .generator import Generator
1818
19- from forge .actors .trainer import RLTrainer
19+ from forge .actors .trainer import TitanTrainer
2020from forge .controller .provisioner import init_provisioner
2121
2222from forge .controller .service .service import uuid
5050TEST_DCP_DIR = "test_dcp_tmp"
5151
5252
53- class MockRLTrainer ( RLTrainer ):
53+ class MockTitanTrainer ( TitanTrainer ):
5454 @endpoint
5555 async def zero_out_model_states (self ):
5656 """This simply sets all model weights to zero."""
@@ -59,7 +59,7 @@ async def zero_out_model_states(self):
5959 for k in sd .keys ():
6060 if not torch .is_floating_point (sd [k ]):
6161 logger .info (
62- f"[MockRLTrainer ] zero_out_model_states(): skipping non-float param { k } "
62+ f"[MockTitanTrainer ] zero_out_model_states(): skipping non-float param { k } "
6363 )
6464 continue
6565 sd [k ] *= 0.0
@@ -199,22 +199,22 @@ async def _setup_and_teardown(request):
199199 )
200200 await ts .initialize (strategy = ts .ControllerStorageVolumes ())
201201
202- policy , rl_trainer = await asyncio .gather (
202+ policy , titan_trainer = await asyncio .gather (
203203 * [
204204 Generator .options (** services_policy_cfg ).as_service (** cfg .policy ),
205- MockRLTrainer .options (** cfg .actors .trainer ).as_actor (** trainer_cfg ),
205+ MockTitanTrainer .options (** cfg .actors .trainer ).as_actor (** trainer_cfg ),
206206 ]
207207 )
208208
209- yield policy , rl_trainer
209+ yield policy , titan_trainer
210210
211211 # ---- teardown ---- #
212212 logger .info ("Shutting down services and cleaning up DCP directory.." )
213213
214214 await asyncio .gather (
215215 policy .shutdown (),
216216 ts .shutdown (),
217- RLTrainer .shutdown (rl_trainer ),
217+ TitanTrainer .shutdown (titan_trainer ),
218218 )
219219
220220 # Cleanup DCP directory
@@ -235,7 +235,7 @@ class TestWeightSync:
235235 @requires_cuda
236236 async def test_sanity_check (self , _setup_and_teardown ):
237237 """
238- Sanity check for weight sync sharding between RLTrainer and Policy for a given model config.
238+ Sanity check for weight sync sharding between TitanTrainer and Policy for a given model config.
239239
240240 The check performs the following steps:
241241 - Initialize trainer and push weights v0 (original huggingface ckpt)
@@ -245,15 +245,15 @@ async def test_sanity_check(self, _setup_and_teardown):
245245
246246 """
247247
248- policy , rl_trainer = _setup_and_teardown
248+ policy , titan_trainer = _setup_and_teardown
249249
250250 v0 = uuid .uuid4 ().int
251251 v1 = v0 + 1
252252
253- await rl_trainer .push_weights .call (policy_version = v0 )
253+ await titan_trainer .push_weights .call (policy_version = v0 )
254254 # Setting everything to zero
255- await rl_trainer .zero_out_model_states .call ()
256- await rl_trainer .push_weights .call (policy_version = v1 )
255+ await titan_trainer .zero_out_model_states .call ()
256+ await titan_trainer .push_weights .call (policy_version = v1 )
257257 await policy .save_model_params .fanout ()
258258
259259 # Sanity check that before update all the tests pass
0 commit comments