Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1717,6 +1717,139 @@ def env_fn():
total_frames=frames_per_batch * 100,
)

class FixedIDEnv(EnvBase):
    """Mock environment whose only observation is a constant identifier.

    Every instance is built with its own ``env_id`` and reports that value
    (as a one-element float tensor) on reset and on every step. The
    surrounding test uses this to check which environment produced which
    slice of a collected batch.
    """

    def __init__(self, env_id: int, max_steps: int = 10, **kwargs):
        """Build the environment and declare its specs.

        Args:
            env_id: Identifier emitted as the observation.
            max_steps: Number of steps after which ``done`` and
                ``terminated`` are reported as ``True``.
            **kwargs: Accepted for signature compatibility; not used here.
        """
        super().__init__(device="cpu", batch_size=torch.Size([]))
        self.env_id = env_id
        self.max_steps = max_steps
        self._step_count = 0

        # Every leaf is a one-element tensor; only the dtype differs.
        leaf_shape = (1,)
        self.observation_spec = Composite(
            observation=Unbounded(shape=leaf_shape, dtype=torch.float32)
        )
        self.action_spec = Composite(
            action=Unbounded(shape=leaf_shape, dtype=torch.float32)
        )
        self.reward_spec = Composite(
            reward=Unbounded(shape=leaf_shape, dtype=torch.float32)
        )
        self.done_spec = Composite(
            done=Unbounded(shape=leaf_shape, dtype=torch.bool),
            terminated=Unbounded(shape=leaf_shape, dtype=torch.bool),
            truncated=Unbounded(shape=leaf_shape, dtype=torch.bool),
        )

    def _reset(self, tensordict: TensorDict | None = None, **kwargs) -> TensorDict:
        """Restart the step counter and emit the initial observation."""
        # Sleep for a random duration (at most 10ms) so that environments
        # finish resetting at different times, as they would in practice.
        reset_delay = torch.rand(1).item() * 0.01
        time.sleep(reset_delay)

        self._step_count = 0
        initial = {
            "observation": torch.tensor([float(self.env_id)], dtype=torch.float32)
        }
        # A freshly reset episode is neither done, terminated nor truncated.
        for flag_name in ("done", "terminated", "truncated"):
            initial[flag_name] = torch.tensor([False], dtype=torch.bool)
        return TensorDict(initial, batch_size=self.batch_size)

    def _step(self, tensordict: TensorDict) -> TensorDict:
        """Advance one step; the observation is always ``env_id``."""
        self._step_count += 1
        finished = self._step_count >= self.max_steps

        transition = {
            "observation": torch.tensor([float(self.env_id)], dtype=torch.float32),
            "reward": torch.tensor([1.0], dtype=torch.float32),
            "done": torch.tensor([finished], dtype=torch.bool),
            "terminated": torch.tensor([finished], dtype=torch.bool),
            # This environment never reports truncation.
            "truncated": torch.tensor([False], dtype=torch.bool),
        }
        return TensorDict(transition, batch_size=self.batch_size)

    def _set_seed(self, seed: int | None) -> int | None:
        """Seed torch's global RNG when a seed is given; return the seed."""
        if seed is None:
            return seed
        torch.manual_seed(seed)
        return seed

@pytest.mark.parametrize("num_envs", [8])
def test_multi_sync_data_collector_ordering(self, num_envs: int):
    """Check that MultiSyncDataCollector returns worker data in order.

    ``num_envs`` environments are created, each emitting its own ``env_id``
    as the observation. After collecting one batch, the slice at index ``i``
    must contain only the value ``i``.

    The collector is shut down in a ``finally`` clause so that its workers
    are released even when an assertion fails.
    """
    frames_per_batch = num_envs * 5  # Collect 5 steps per environment

    # One environment factory per env_id, built with functools.partial.
    # This pattern mirrors CrossPlayEvaluator._rollout usage
    env_factories = [
        functools.partial(self.FixedIDEnv, env_id=i, max_steps=10)
        for i in range(num_envs)
    ]

    # A single policy instance is handed to the collector.
    policy = ParametricPolicy()

    collector = MultiSyncDataCollector(
        create_env_fn=env_factories,
        policy=policy,
        frames_per_batch=frames_per_batch,
        total_frames=frames_per_batch,
        device="cpu",
    )

    try:
        # Collect one batch
        for batch in collector:
            # The batch is indexed by environment along its first dimension
            # (expected shape: [num_envs, frames_per_env]).
            for env_idx in range(num_envs):
                env_data = batch[env_idx]
                observations = env_data["observation"]

                # All observations from this environment should equal its env_id
                expected_id = float(env_idx)
                actual_ids = observations.flatten().unique()

                assert len(actual_ids) == 1, (
                    f"Env {env_idx} should only produce observations with value {expected_id}, "
                    f"but got {actual_ids.tolist()}"
                )
                assert (
                    actual_ids[0].item() == expected_id
                ), f"Environment {env_idx} should produce observation {expected_id}, but got {actual_ids[0].item()}"

            # Only process the first batch
            break
    finally:
        # Always release the workers, including on assertion failure.
        collector.shutdown()


class TestCollectorDevices:
class DeviceLessEnv(EnvBase):
Expand Down
60 changes: 22 additions & 38 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3760,8 +3760,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
cat_results = self.cat_results
if cat_results is None:
cat_results = "stack"

self.buffers = {}
self.buffers = [None for _ in range(self.num_workers)]
dones = [False for _ in range(self.num_workers)]
workers_frames = [0 for _ in range(self.num_workers)]
same_device = None
Expand All @@ -3781,7 +3780,6 @@ def iterator(self) -> Iterator[TensorDictBase]:
msg = "continue_random"
else:
msg = "continue"
# Debug: sending 'continue'
self.pipes[idx].send((None, msg))

self._iter += 1
Expand Down Expand Up @@ -3844,16 +3842,16 @@ def iterator(self) -> Iterator[TensorDictBase]:
if preempt:
# mask buffers if cat, and create a mask if stack
if cat_results != "stack":
buffers = {}
for worker_idx, buffer in self.buffers.items():
buffers = [None] * self.num_workers
for idx, buffer in enumerate(filter(None.__ne__, self.buffers)):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to filter out the buffers of workers that did not return their experience. I use the `enumerate(filter(None.__ne__, self.buffers))` idiom to keep this compact and hopefully readable; I'm open to better ideas.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's ok

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But this redefines `idx`, which was already defined earlier (LoC 3829 or 3840). It should be `worker_idx`, I believe.
See my comment below.

valid = buffer.get(("collector", "traj_ids")) != -1
if valid.ndim > 2:
valid = valid.flatten(0, -2)
if valid.ndim == 2:
valid = valid.any(0)
buffers[worker_idx] = buffer[..., valid]
buffers[idx] = buffer[..., valid]
else:
for buffer in self.buffers.values():
for buffer in filter(None.__ne__, self.buffers):
with buffer.unlock_():
buffer.set(
("collector", "mask"),
Expand All @@ -3863,11 +3861,6 @@ def iterator(self) -> Iterator[TensorDictBase]:
else:
buffers = self.buffers

# Skip frame counting if this worker didn't send data this iteration
# (happens when reusing buffers or on first iteration with some workers)
if idx not in buffers:
continue

Comment on lines -3866 to -3870
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am puzzled by this code: I don't see where it could happen that `idx` is defined but the related buffer is `None`.

The equivalent code here would be `if buffers[idx] is None: continue`.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This happens during preemption: if we say that we're OK with 80% of the data, it could be that we don't have data for one of the workers, and we just return whatever we have at this stage.

workers_frames[idx] = workers_frames[idx] + buffers[idx].numel()

if workers_frames[idx] >= self.total_frames:
Expand All @@ -3876,17 +3869,15 @@ def iterator(self) -> Iterator[TensorDictBase]:
if self.replay_buffer is not None:
yield
self._frames += sum(
[
self.frames_per_batch_worker(worker_idx)
for worker_idx in range(self.num_workers)
]
self.frames_per_batch_worker(worker_idx)
for worker_idx in range(self.num_workers)
)
continue

# we have to correct the traj_ids to make sure that they don't overlap
# We can count the number of frames collected for free in this loop
n_collected = 0
for idx in buffers.keys():
for idx in range(self.num_workers):
buffer = buffers[idx]
traj_ids = buffer.get(("collector", "traj_ids"))
if preempt:
Expand All @@ -3901,7 +3892,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
if same_device is None:
prev_device = None
same_device = True
for item in self.buffers.values():
for item in filter(None.__ne__, self.buffers):
if prev_device is None:
prev_device = item.device
else:
Expand All @@ -3912,33 +3903,30 @@ def iterator(self) -> Iterator[TensorDictBase]:
torch.stack if self._use_buffers else TensorDict.maybe_dense_stack
)
if same_device:
self.out_buffer = stack(list(buffers.values()), 0)
self.out_buffer = stack(
[item for item in buffers if item is not None], 0
)
else:
self.out_buffer = stack(
[item.cpu() for item in buffers.values()], 0
[item.cpu() for item in buffers if item is not None], 0
)
else:
if self._use_buffers is None:
torchrl_logger.warning(
"use_buffer not specified and not yet inferred from data, assuming `True`."
)
torchrl_logger.warning("use_buffer not specified and not yet inferred from data, assuming `True`.")
elif not self._use_buffers:
raise RuntimeError(
"Cannot concatenate results with use_buffers=False"
)
raise RuntimeError("Cannot concatenate results with use_buffers=False")
try:
if same_device:
self.out_buffer = torch.cat(list(buffers.values()), cat_results)
self.out_buffer = torch.cat(
[item for item in buffers if item is not None], cat_results
)
else:
self.out_buffer = torch.cat(
[item.cpu() for item in buffers.values()], cat_results
[item.cpu() for item in buffers if item is not None],
cat_results,
)
except RuntimeError as err:
if (
preempt
and cat_results != -1
and "Sizes of tensors must match" in str(err)
):
if preempt and cat_results != -1 and "Sizes of tensors must match" in str(err):
raise RuntimeError(
"The value provided to cat_results isn't compatible with the collectors outputs. "
"Consider using `cat_results=-1`."
Expand All @@ -3956,11 +3944,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
self._frames += n_collected

if self.postprocs:
self.postprocs = (
self.postprocs.to(out.device)
if hasattr(self.postprocs, "to")
else self.postprocs
)
self.postprocs = self.postprocs.to(out.device) if hasattr(self.postprocs, "to") else self.postprocs
out = self.postprocs(out)
if self._exclude_private_keys:
excluded_keys = [key for key in out.keys() if key.startswith("_")]
Expand Down
Loading