Skip to content

Commit e973d93

Browse files
committed
amend
1 parent 8b9508f commit e973d93

File tree

8 files changed

+236
-450
lines changed

8 files changed

+236
-450
lines changed

test/test_collector.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3938,13 +3938,12 @@ def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
39383938
def all_worker_ids(self) -> list[int] | list[torch.device]:
39393939
return list(range(self.num_workers))
39403940

3941-
@pytest.mark.skipif(not _has_cuda, reason="requires cuda another device than CPU.")
39423941
@pytest.mark.skipif(not _has_gym, reason="requires gym")
39433942
@pytest.mark.parametrize(
39443943
"weight_updater", ["scheme_shared", "scheme_pipe", "weight_updater"]
39453944
)
39463945
def test_weight_update(self, weight_updater):
3947-
device = "cuda:0"
3946+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
39483947
env_maker = lambda: GymEnv(PENDULUM_VERSIONED(), device="cpu")
39493948
policy_factory = lambda: TensorDictModule(
39503949
nn.Linear(3, 1, device=device), in_keys=["observation"], out_keys=["action"]

torchrl/collectors/_multi_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -835,9 +835,9 @@ def _run_processes(self) -> None:
835835
# can be initialized here since all required resources exist
836836
if self._weight_sync_schemes:
837837
for model_id, scheme in self._weight_sync_schemes.items():
838-
if hasattr(scheme, "init_on_sender"):
838+
if not scheme.initialized_on_sender:
839839
scheme.init_on_sender(model_id=model_id, context=self)
840-
self._weight_senders[model_id] = scheme.get_sender()
840+
self._weight_senders[model_id] = scheme.get_sender()
841841

842842
# Create a policy on the right device
843843
policy_factory = self.policy_factory

torchrl/collectors/_runner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ def _make_policy_factory(
4040
if weight_sync_scheme is not None:
4141
# Initialize the receiver on the worker side
4242
weight_sync_scheme.init_on_receiver(
43-
model=policy, model_id="policy", worker_idx=worker_idx, pipe=pipe
43+
model=policy,
44+
model_id="policy",
45+
worker_idx=worker_idx,
4446
)
4547
# Get the receiver and synchronize initial weights
4648
receiver = weight_sync_scheme.get_receiver()

0 commit comments

Comments
 (0)