111 changes: 111 additions & 0 deletions test/test_transforms.py
@@ -14014,6 +14014,117 @@ def test_transform_compose(self, reward_key, num_rewards, weights):
        )
        torch.testing.assert_close(td[out_keys], expected)

    def test_compose_with_reward_scaling(self):
        """Test that LineariseRewards properly registers output keys for use in Compose.

        This test reproduces the issue from GitHub #3237, where LineariseRewards
        did not register its output keys in the spec, causing subsequent transforms
        to fail during initialization.
        """
        # Create a simple env with multi-objective rewards
        env = self._DummyMultiObjectiveEnv(num_rewards=3)

        # Compose LineariseRewards with RewardScaling. This should work without a
        # KeyError since transform_output_spec properly registers the new reward key.
        transform = Compose(
            LineariseRewards(
                in_keys=[("reward",)],
                out_keys=[("nested", "scalar_reward")],
                weights=[1.0, 2.0, 3.0],
            ),
            RewardScaling(
                in_keys=[("nested", "scalar_reward")],
                loc=0.0,
                scale=10.0,
            ),
        )

        # Apply the transform to the environment
        transformed_env = TransformedEnv(env, transform)

        # Check that the specs are valid
        check_env_specs(transformed_env)

        # Verify that the transform works correctly
        rollout = transformed_env.rollout(5)
        assert ("next", "nested", "scalar_reward") in rollout.keys(True)
        assert rollout[("next", "nested", "scalar_reward")].shape[-1] == 1

    def test_compose_with_nested_keys(self):
        """Test LineariseRewards with nested keys as described in GitHub #3237."""

        # Create a dummy env that produces nested rewards
        class _NestedRewardEnv(EnvBase):
            def __init__(self):
                super().__init__()
                self.observation_spec = Composite(
                    observation=UnboundedContinuous((*self.batch_size, 3))
                )
                self.action_spec = Categorical(
                    2, (*self.batch_size, 1), dtype=torch.bool
                )
                self.done_spec = Categorical(2, (*self.batch_size, 1), dtype=torch.bool)
                self.full_done_spec["truncated"] = self.full_done_spec[
                    "terminated"
                ].clone()
                # Nested reward spec
                self.reward_spec = Composite(
                    agent1=Composite(
                        reward_vec=UnboundedContinuous((*self.batch_size, 2))
                    )
                )

            def _reset(self, tensordict: TensorDict | None = None) -> TensorDict:
                return self.observation_spec.sample()

            def _step(self, tensordict: TensorDict) -> TensorDict:
                return TensorDict(
                    {
                        "observation": self.observation_spec["observation"].sample(),
                        "done": torch.zeros(*self.batch_size, 1, dtype=torch.bool),
                        "terminated": torch.zeros(
                            *self.batch_size, 1, dtype=torch.bool
                        ),
                        ("agent1", "reward_vec"): torch.randn(*self.batch_size, 2),
                    },
                    batch_size=self.batch_size,
                )

            def _set_seed(self, seed: int | None = None) -> None:
                pass

        env = _NestedRewardEnv()

        # This is the exact scenario from the GitHub issue
        transform = Compose(
            LineariseRewards(
                in_keys=[("agent1", "reward_vec")],
                out_keys=[("agent1", "weighted_reward")],
            ),
            RewardScaling(
                in_keys=[("agent1", "weighted_reward")], loc=0.0, scale=2.0
            ),
        )

        # This should work without raising a KeyError
        transformed_env = TransformedEnv(env, transform)

        # Check that the specs are valid
        check_env_specs(transformed_env)

        # Verify that the transform works correctly
        rollout = transformed_env.rollout(5)
        assert ("next", "agent1", "weighted_reward") in rollout.keys(True)
        assert rollout[("next", "agent1", "weighted_reward")].shape[-1] == 1

    class _DummyMultiObjectiveEnv(EnvBase):
        """A dummy multi-objective environment."""

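As a quick sanity check on what the composed transform in the tests above computes, here is a small standalone sketch (not part of the diff; the reward values are made up, and RewardScaling is assumed to keep its default affine form reward * scale + loc): LineariseRewards collapses the multi-objective reward into a weighted sum along the last dimension, and RewardScaling then rescales that scalar.

import torch

# Weighted sum performed by LineariseRewards(weights=[1.0, 2.0, 3.0]) on one reward vector.
reward_vec = torch.tensor([0.5, -1.0, 2.0])  # hypothetical multi-objective reward
weights = torch.tensor([1.0, 2.0, 3.0])
scalar_reward = (weights * reward_vec).sum(-1, keepdim=True)  # tensor([4.5]), shape (1,)

# RewardScaling(loc=0.0, scale=10.0) then rescales the scalar (assumed affine form).
scaled_reward = scalar_reward * 10.0 + 0.0  # tensor([45.]), shape (1,)
assert scaled_reward.shape[-1] == 1  # matches the shape assertion in the tests above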
6 changes: 2 additions & 4 deletions torchrl/envs/transforms/transforms.py
@@ -592,13 +592,11 @@ def transform_output_spec(self, output_spec: Composite) -> Composite:
         in_keys = {unravel_key(k) for k in self.in_keys}
         for key in out_keys - in_keys:
             if unravel_key(key) not in output_spec_keys:
-                warnings.warn(
+                raise KeyError(
                     f"The key '{key}' is unaccounted for by the transform (expected keys {output_spec_keys}). "
                     f"Every new entry in the tensordict resulting from a call to a transform must be "
                     f"registered in the specs for torchrl rollouts to be consistently built. "
-                    f"Make sure transform_output_spec/transform_observation_spec/... is coded correctly. "
-                    "This warning will trigger a KeyError in v0.9, make sure to adapt your code accordingly.",
-                    category=FutureWarning,
+                    f"Make sure transform_output_spec/transform_observation_spec/... is coded correctly."
                 )
         return output_spec

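To make concrete what the stricter check now demands of transform authors, below is a minimal sketch of a reward transform that introduces a new output key and registers it in the reward spec, so the KeyError above is never hit. This is an illustration only, not code from this PR: the class name, keys, and shapes are invented, and it assumes the wrapped env already exposes ("agent1", "reward_vec") in its reward spec.

import torch
from torchrl.data import Composite, UnboundedContinuous
from torchrl.envs.transforms import Transform


class SumRewardVec(Transform):
    """Hypothetical transform: writes the sum of a reward vector under a new key."""

    def __init__(self):
        super().__init__(
            in_keys=[("agent1", "reward_vec")],
            out_keys=[("agent1", "reward_sum")],
        )

    def _apply_transform(self, reward: torch.Tensor) -> torch.Tensor:
        # Applied to each in_key and written to the matching out_key during a step.
        return reward.sum(dim=-1, keepdim=True)

    def transform_reward_spec(self, reward_spec: Composite) -> Composite:
        # Registering the new entry here is what keeps the validation in
        # transform_output_spec (above) from raising a KeyError.
        reward_spec[self.out_keys[0]] = UnboundedContinuous(
            shape=(*reward_spec[self.in_keys[0]].shape[:-1], 1),
            device=reward_spec.device,
        )
        return reward_spec

A transform that adds a new observation key would perform the same registration in transform_observation_spec instead.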