diff --git a/test/test_transforms.py b/test/test_transforms.py
index 567c0995d20..e710a93781b 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -14014,6 +14014,117 @@ def test_transform_compose(self, reward_key, num_rewards, weights):
         )
         torch.testing.assert_close(td[out_keys], expected)
 
+    def test_compose_with_reward_scaling(self):
+        """Test that LineariseRewards properly registers output keys for use in Compose.
+
+        This test reproduces the issue from GitHub #3237 where LineariseRewards
+        does not register its output keys in the spec, causing subsequent transforms
+        to fail during initialization.
+        """
+        # Create a simple env with multi-objective rewards
+        env = self._DummyMultiObjectiveEnv(num_rewards=3)
+
+        # Create a composed transform with LineariseRewards and RewardScaling
+        # This should work without KeyError since transform_output_spec properly validates
+        transform = Compose(
+            LineariseRewards(
+                in_keys=[("reward",)],
+                out_keys=[
+                    (
+                        "nested",
+                        "scalar_reward",
+                    )
+                ],
+                weights=[1.0, 2.0, 3.0],
+            ),
+            RewardScaling(
+                in_keys=[
+                    (
+                        "nested",
+                        "scalar_reward",
+                    )
+                ],
+                loc=0.0,
+                scale=10.0,
+            ),
+        )
+
+        # Apply transform to environment
+        transformed_env = TransformedEnv(env, transform)
+
+        # Check that specs are valid
+        check_env_specs(transformed_env)
+
+        # Verify the transform works correctly
+        rollout = transformed_env.rollout(5)
+        assert ("next", "nested", "scalar_reward") in rollout.keys(True)
+        assert rollout[("next", "nested", "scalar_reward")].shape[-1] == 1
+
+    def test_compose_with_nested_keys(self):
+        """Test LineariseRewards with nested keys as described in GitHub #3237."""
+        # Create a dummy env that produces nested rewards
+        class _NestedRewardEnv(EnvBase):
+            def __init__(self):
+                super().__init__()
+                self.observation_spec = Composite(
+                    observation=UnboundedContinuous((*self.batch_size, 3))
+                )
+                self.action_spec = Categorical(
+                    2, (*self.batch_size, 1), dtype=torch.bool
+                )
+                self.done_spec = Categorical(2, (*self.batch_size, 1), dtype=torch.bool)
+                self.full_done_spec["truncated"] = self.full_done_spec[
+                    "terminated"
+                ].clone()
+                # Nested reward spec; shape is one tuple, like the observation spec
+                self.reward_spec = Composite(
+                    agent1=Composite(
+                        reward_vec=UnboundedContinuous((*self.batch_size, 2))
+                    )
+                )
+
+            def _reset(self, tensordict: TensorDict) -> TensorDict:
+                return self.observation_spec.sample()
+
+            def _step(self, tensordict: TensorDict) -> TensorDict:
+                return TensorDict(
+                    {
+                        "observation": self.observation_spec["observation"].sample(),
+                        "done": False,
+                        "terminated": False,
+                        ("agent1", "reward_vec"): torch.randn(2),
+                    }
+                )
+
+            def _set_seed(self, seed: int | None = None) -> None:
+                pass
+
+        env = _NestedRewardEnv()
+
+        # This is the exact scenario from the GitHub issue
+        transform = Compose(
+            transforms=[
+                LineariseRewards(
+                    in_keys=[("agent1", "reward_vec")],
+                    out_keys=[("agent1", "weighted_reward")],
+                ),
+                RewardScaling(
+                    in_keys=[("agent1", "weighted_reward")], loc=0.0, scale=2.0
+                ),
+            ],
+        )
+
+        # This should work without KeyError
+        transformed_env = TransformedEnv(env, transform)
+
+        # Check that specs are valid
+        check_env_specs(transformed_env)
+
+        # Verify the transform works correctly
+        rollout = transformed_env.rollout(5)
+        assert ("next", "agent1", "weighted_reward") in rollout.keys(True)
+        assert rollout[("next", "agent1", "weighted_reward")].shape[-1] == 1
+
     class _DummyMultiObjectiveEnv(EnvBase):
         """A dummy multi-objective environment."""
 
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index ca9ab70f184..227e0e038f1 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -592,13 +592,11 @@ def transform_output_spec(self, output_spec: Composite) -> Composite:
         in_keys = {unravel_key(k) for k in self.in_keys}
         for key in out_keys - in_keys:
             if unravel_key(key) not in output_spec_keys:
-                warnings.warn(
+                raise KeyError(
                     f"The key '{key}' is unaccounted for by the transform (expected keys {output_spec_keys}). "
                     f"Every new entry in the tensordict resulting from a call to a transform must be "
                     f"registered in the specs for torchrl rollouts to be consistently built. "
-                    f"Make sure transform_output_spec/transform_observation_spec/... is coded correctly. "
-                    "This warning will trigger a KeyError in v0.9, make sure to adapt your code accordingly.",
-                    category=FutureWarning,
+                    f"Make sure transform_output_spec/transform_observation_spec/... is coded correctly."
                 )
         return output_spec
 