diff --git a/test/test_collector.py b/test/test_collector.py index 73c6e5c3d21..c0058068ac2 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1717,6 +1717,139 @@ def env_fn(): total_frames=frames_per_batch * 100, ) + class FixedIDEnv(EnvBase): + """ + A simple mock environment that returns a fixed ID as its sole observation. + + This environment is designed to test MultiSyncDataCollector ordering. + Each environment instance is initialized with a unique env_id, which it + returns as the observation at every step. + """ + + def __init__(self, env_id: int, max_steps: int = 10, **kwargs): + """ + Args: + env_id: The ID to return as observation. This will be returned as a tensor. + max_steps: Maximum number of steps before the environment terminates. + """ + super().__init__(device="cpu", batch_size=torch.Size([])) + self.env_id = env_id + self.max_steps = max_steps + self._step_count = 0 + + # Define specs + self.observation_spec = Composite( + observation=Unbounded(shape=(1,), dtype=torch.float32) + ) + self.action_spec = Composite( + action=Unbounded(shape=(1,), dtype=torch.float32) + ) + self.reward_spec = Composite( + reward=Unbounded(shape=(1,), dtype=torch.float32) + ) + self.done_spec = Composite( + done=Unbounded(shape=(1,), dtype=torch.bool), + terminated=Unbounded(shape=(1,), dtype=torch.bool), + truncated=Unbounded(shape=(1,), dtype=torch.bool), + ) + + def _reset(self, tensordict: TensorDict | None = None, **kwargs) -> TensorDict: + """Reset the environment and return initial observation.""" + # Add random sleep to simulate real-world timing variations + # This helps test that the collector properly handles different reset times + time.sleep(torch.rand(1).item() * 0.01) # Random sleep up to 10ms + + self._step_count = 0 + return TensorDict( + { + "observation": torch.tensor( + [float(self.env_id)], dtype=torch.float32 + ), + "done": torch.tensor([False], dtype=torch.bool), + "terminated": torch.tensor([False], dtype=torch.bool), + "truncated": torch.tensor([False], dtype=torch.bool), + }, + batch_size=self.batch_size, + ) + + def _step(self, tensordict: TensorDict) -> TensorDict: + """Execute one step and return the env_id as observation.""" + self._step_count += 1 + done = self._step_count >= self.max_steps + + return TensorDict( + { + "observation": torch.tensor( + [float(self.env_id)], dtype=torch.float32 + ), + "reward": torch.tensor([1.0], dtype=torch.float32), + "done": torch.tensor([done], dtype=torch.bool), + "terminated": torch.tensor([done], dtype=torch.bool), + "truncated": torch.tensor([False], dtype=torch.bool), + }, + batch_size=self.batch_size, + ) + + def _set_seed(self, seed: int | None) -> int | None: + """Set the seed for reproducibility.""" + if seed is not None: + torch.manual_seed(seed) + return seed + + @pytest.mark.parametrize("num_envs", [8]) + def test_multi_sync_data_collector_ordering(self, num_envs: int): + """ + Test that MultiSyncDataCollector returns data in the correct order. + + We create num_envs environments, each returning its env_id as the observation. + After collection, we verify that the observations correspond to the correct env_ids in order + """ + frames_per_batch = num_envs * 5 # Collect 5 steps per environment + + # Create environment factories using partial - one for each env_id + # This pattern mirrors CrossPlayEvaluator._rollout usage + env_factories = [ + functools.partial(self.FixedIDEnv, env_id=i, max_steps=10) + for i in range(num_envs) + ] + + # Create policy factories using partial + policy = ParametricPolicy() + + # Initialize MultiSyncDataCollector + collector = MultiSyncDataCollector( + create_env_fn=env_factories, + policy=policy, + frames_per_batch=frames_per_batch, + total_frames=frames_per_batch, + device="cpu", + ) + + # Collect one batch + for batch in collector: + # Verify that each environment's observations match its env_id + # batch has shape [num_envs, frames_per_env] + for env_idx in range(num_envs): + env_data = batch[env_idx] + observations = env_data["observation"] + + # All observations from this environment should equal its env_id + expected_id = float(env_idx) + actual_ids = observations.flatten().unique() + + assert len(actual_ids) == 1, ( + f"Env {env_idx} should only produce observations with value {expected_id}, " + f"but got {actual_ids.tolist()}" + ) + assert ( + actual_ids[0].item() == expected_id + ), f"Environment {env_idx} should produce observation {expected_id}, but got {actual_ids[0].item()}" + + # Only process the first batch + break + + collector.shutdown() + class TestCollectorDevices: class DeviceLessEnv(EnvBase): diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index b7be73d243f..06dc2c1a95e 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -3760,8 +3760,7 @@ def iterator(self) -> Iterator[TensorDictBase]: cat_results = self.cat_results if cat_results is None: cat_results = "stack" - - self.buffers = {} + self.buffers = [None for _ in range(self.num_workers)] dones = [False for _ in range(self.num_workers)] workers_frames = [0 for _ in range(self.num_workers)] same_device = None @@ -3781,7 +3780,6 @@ def iterator(self) -> Iterator[TensorDictBase]: msg = "continue_random" else: msg = "continue" - # Debug: sending 'continue' self.pipes[idx].send((None, msg)) self._iter += 1 @@ -3844,8 +3842,10 @@ def iterator(self) -> Iterator[TensorDictBase]: if preempt: # mask buffers if cat, and create a mask if stack if cat_results != "stack": - buffers = {} - for worker_idx, buffer in self.buffers.items(): + buffers = [None] * self.num_workers + for worker_idx, buffer in enumerate( + filter(lambda x: x is not None, self.buffers) + ): valid = buffer.get(("collector", "traj_ids")) != -1 if valid.ndim > 2: valid = valid.flatten(0, -2) @@ -3853,7 +3853,7 @@ def iterator(self) -> Iterator[TensorDictBase]: valid = valid.any(0) buffers[worker_idx] = buffer[..., valid] else: - for buffer in self.buffers.values(): + for buffer in filter(lambda x: x is not None, self.buffers): with buffer.unlock_(): buffer.set( ("collector", "mask"), @@ -3865,7 +3865,7 @@ def iterator(self) -> Iterator[TensorDictBase]: # Skip frame counting if this worker didn't send data this iteration # (happens when reusing buffers or on first iteration with some workers) - if idx not in buffers: + if self.buffers[idx] is None: continue workers_frames[idx] = workers_frames[idx] + buffers[idx].numel() @@ -3876,17 +3876,15 @@ def iterator(self) -> Iterator[TensorDictBase]: if self.replay_buffer is not None: yield self._frames += sum( - [ - self.frames_per_batch_worker(worker_idx) - for worker_idx in range(self.num_workers) - ] + self.frames_per_batch_worker(worker_idx) + for worker_idx in range(self.num_workers) ) continue # we have to correct the traj_ids to make sure that they don't overlap # We can count the number of frames collected for free in this loop n_collected = 0 - for idx in buffers.keys(): + for idx in range(self.num_workers): buffer = buffers[idx] traj_ids = buffer.get(("collector", "traj_ids")) if preempt: @@ -3901,7 +3899,7 @@ def iterator(self) -> Iterator[TensorDictBase]: if same_device is None: prev_device = None same_device = True - for item in self.buffers.values(): + for item in filter(lambda x: x is not None, self.buffers): if prev_device is None: prev_device = item.device else: @@ -3912,10 +3910,12 @@ def iterator(self) -> Iterator[TensorDictBase]: torch.stack if self._use_buffers else TensorDict.maybe_dense_stack ) if same_device: - self.out_buffer = stack(list(buffers.values()), 0) + self.out_buffer = stack( + [item for item in buffers if item is not None], 0 + ) else: self.out_buffer = stack( - [item.cpu() for item in buffers.values()], 0 + [item.cpu() for item in buffers if item is not None], 0 ) else: if self._use_buffers is None: @@ -3928,10 +3928,13 @@ def iterator(self) -> Iterator[TensorDictBase]: ) try: if same_device: - self.out_buffer = torch.cat(list(buffers.values()), cat_results) + self.out_buffer = torch.cat( + [item for item in buffers if item is not None], cat_results + ) else: self.out_buffer = torch.cat( - [item.cpu() for item in buffers.values()], cat_results + [item.cpu() for item in buffers if item is not None], + cat_results, ) except RuntimeError as err: if (