Skip to content

Commit 3c12abb

Browse files
Mayankm96ClemensSchwarke
authored andcommitted
Renames observation types to observation sets
Approved-by: Clemens Schwarke
1 parent 332661e commit 3c12abb

File tree

4 files changed

+95
-65
lines changed

4 files changed

+95
-65
lines changed

rsl_rl/env/vec_env.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -67,28 +67,32 @@ def step(self, actions: torch.Tensor) -> tuple[TensorDict, torch.Tensor, torch.T
6767
dones (torch.Tensor): Done flags from the environment. Shape: (num_envs,)
6868
extras (dict): Extra information from the environment.
6969
70-
Notes:
71-
Observations:
72-
The observations TensorDict usually contains multiple observation groups. The `obs_groups` dictionary of the
73-
runner configuration specifies which observation groups are used for which purpose, i.e., it maps the
74-
available observation groups to observation types. The observation types (keys of the `obs_groups`
75-
dictionary) currently used by rsl_rl are:
76-
- "policy": Specified observation groups are used as input to the policy/actor network.
77-
- "critic": Specified observation groups are used as input to the critic network.
78-
- "teacher": Specified observation groups are used as input to the teacher network.
79-
- "rnd_state": Specified observation groups are used as input to the RND network.
80-
The "policy" observation type is always required. The other observation types are optional and if not
81-
provided, the observation groups specified for the "policy" observation type are used.
82-
83-
Extras:
70+
Observations:
71+
72+
The observations TensorDict usually contains multiple observation groups. The `obs_groups`
73+
dictionary of the runner configuration specifies which observation groups are used for which
74+
purpose, i.e., it maps the available observation groups to observation sets. The observation sets
75+
(keys of the `obs_groups` dictionary) currently used by rsl_rl are:
76+
77+
- "policy": Specified observation groups are used as input to the actor/student network.
78+
- "critic": Specified observation groups are used as input to the critic network.
79+
- "teacher": Specified observation groups are used as input to the teacher network.
80+
- "rnd_state": Specified observation groups are used as input to the RND network.
81+
82+
Incomplete or incorrect configurations are handled in the `resolve_obs_groups()` function in
83+
`rsl_rl/utils/utils.py`.
84+
85+
Extras:
86+
8487
The extras dictionary includes metrics such as the episode reward, episode length, etc. The following
8588
dictionary keys are used by rsl_rl:
86-
- "time_outs" (torch.Tensor): Timeouts for the environments. These correspond to terminations that
87-
happen due to time limits and not due to the environment reaching a terminal state. This is useful
88-
for environments that have a fixed episode length.
8989
90-
- "log" (dict[str, float | torch.Tensor]): Additional information for logging and debugging purposes.
91-
The key should be a string and start with "/" for namespacing. The value can be a scalar or a
92-
tensor. If it is a tensor, the mean of the tensor is used for logging.
90+
- "time_outs" (torch.Tensor): Timeouts for the environments. These correspond to terminations that
91+
happen due to time limits and not due to the environment reaching a terminal state. This is useful
92+
for environments that have a fixed episode length.
93+
94+
- "log" (dict[str, float | torch.Tensor]): Additional information for logging and debugging purposes.
95+
The key should be a string and start with "/" for namespacing. The value can be a scalar or a
96+
tensor. If it is a tensor, the mean of the tensor is used for logging.
9397
"""
9498
raise NotImplementedError

rsl_rl/runners/distillation_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from rsl_rl.env import VecEnv
1616
from rsl_rl.modules import StudentTeacher, StudentTeacherRecurrent
1717
from rsl_rl.runners import OnPolicyRunner
18-
from rsl_rl.utils import resolve_obs_types, store_code_state
18+
from rsl_rl.utils import resolve_obs_groups, store_code_state
1919

2020

2121
class DistillationRunner(OnPolicyRunner):
@@ -37,7 +37,7 @@ def __init__(self, env: VecEnv, train_cfg: dict, log_dir: str | None = None, dev
3737

3838
# query observations from environment for algorithm construction
3939
obs = self.env.get_observations()
40-
self.cfg["obs_groups"] = resolve_obs_types(obs, self.cfg["obs_groups"], default_types=["teacher"])
40+
self.cfg["obs_groups"] = resolve_obs_groups(obs, self.cfg["obs_groups"], default_sets=["teacher"])
4141

4242
# create the algorithm
4343
self.alg = self._construct_algorithm(obs)

rsl_rl/runners/on_policy_runner.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from rsl_rl.algorithms import PPO
1717
from rsl_rl.env import VecEnv
1818
from rsl_rl.modules import ActorCritic, ActorCriticRecurrent, resolve_rnd_config, resolve_symmetry_config
19-
from rsl_rl.utils import resolve_obs_types, store_code_state
19+
from rsl_rl.utils import resolve_obs_groups, store_code_state
2020

2121

2222
class OnPolicyRunner:
@@ -38,10 +38,10 @@ def __init__(self, env: VecEnv, train_cfg: dict, log_dir: str | None = None, dev
3838

3939
# query observations from environment for algorithm construction
4040
obs = self.env.get_observations()
41-
default_types = ["critic"]
41+
default_sets = ["critic"]
4242
if "rnd_cfg" in self.alg_cfg and self.alg_cfg["rnd_cfg"] is not None:
43-
default_types.append("rnd_state")
44-
self.cfg["obs_groups"] = resolve_obs_types(obs, self.cfg["obs_groups"], default_types)
43+
default_sets.append("rnd_state")
44+
self.cfg["obs_groups"] = resolve_obs_groups(obs, self.cfg["obs_groups"], default_sets)
4545

4646
# create the algorithm
4747
self.alg = self._construct_algorithm(obs)

rsl_rl/utils/utils.py

Lines changed: 65 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import os
1111
import pathlib
1212
import torch
13+
import warnings
1314
from tensordict import TensorDict
1415
from typing import Callable
1516

@@ -198,81 +199,106 @@ def string_to_callable(name: str) -> Callable:
198199
raise ValueError(msg)
199200

200201

201-
def resolve_obs_types(
202-
obs: TensorDict, obs_groups: dict[str, list[str]], default_types: list[str]
202+
def resolve_obs_groups(
203+
obs: TensorDict, obs_groups: dict[str, list[str]], default_sets: list[str]
203204
) -> dict[str, list[str]]:
204-
"""Validates the observation types configuration and defaults missing observation types.
205+
"""Validates the observation configuration and defaults missing observation sets.
205206
206207
The input is an observation dictionary `obs` containing observation groups and a configuration dictionary
207-
`obs_groups` where the keys are the observation types and the values are lists of observation groups. A detailed
208-
description is found in `rsl_rl/env/vec_env.py`.
208+
`obs_groups` where the keys are the observation sets and the values are lists of observation groups.
209209
210210
The configuration dictionary could for example look like:
211211
{
212212
"policy": ["group_1", "group_2"],
213213
"critic": ["group_1", "group_3"]
214214
}
215215
216-
This means that the 'policy' type will contain the observations "group_1" and "group_2" and the 'critic' type will
217-
contain the observations "group_1" and "group_3". The function will check that all the observations in the 'policy'
218-
and 'critic' groups are present in the observation dictionary from the environment.
216+
This means that the 'policy' observation set will contain the observations "group_1" and "group_2" and the
217+
'critic' observation set will contain the observations "group_1" and "group_3". This function will check that all
218+
the observations in the 'policy' and 'critic' observation sets are present in the observation dictionary from the
219+
environment.
219220
220-
Additionally, if one of the `default_types`, e.g. "critic", is not present in the configuration dictionary,
221+
Additionally, if one of the `default_sets`, e.g. "critic", is not present in the configuration dictionary,
221222
this function will:
222-
1. Check if a group with the same name exists in the observations and assign this group to the observation type.
223-
2. If not, it will assign the observations from the 'policy' type to the observation type.
223+
224+
1. Check if a group with the same name exists in the observations and assign this group to the observation set.
225+
2. If 1. fails, it will assign the observations from the 'policy' observation set to the default observation set.
224226
225227
Args:
226228
obs: Observations from the environment in the form of a dictionary.
227-
obs_groups: Observation types configuration.
228-
default_types: Reserved type names used by the algorithm (besides 'policy').
229+
obs_groups: Observation sets configuration.
230+
default_sets: Reserved observation set names used by the algorithm (besides 'policy').
229231
If not provided in 'obs_groups', a default behavior gets triggered.
230232
231233
Returns:
232234
The resolved observation groups.
233235
234236
Raises:
235-
ValueError: If the "policy" observation type is not present in the provided observation groups configuration.
236-
ValueError: If any observation type is an empty list.
237-
ValueError: If any observation type contains an observation term that is not present in the observations.
237+
ValueError: If any observation set is an empty list.
238+
ValueError: If any observation set contains an observation term that is not present in the observations.
238239
"""
239-
# check if policy observation type exists
240+
# check if policy observation set exists
240241
if "policy" not in obs_groups.keys():
241-
raise ValueError(
242-
"The observation type configuration dictionary must contain the 'policy' key."
243-
f" Found keys: {list(obs_groups.keys())}"
244-
)
242+
if "policy" in obs:
243+
obs_groups["policy"] = ["policy"]
244+
warnings.warn(
245+
"The observation configuration dictionary 'obs_groups' must contain the 'policy' key."
246+
" As an observation group with the name 'policy' was found, this is assumed to be the observation set."
247+
" Consider adding the 'policy' key to the 'obs_groups' dictionary for clarity."
248+
" This behavior will be removed in a future version."
249+
)
250+
else:
251+
raise ValueError(
252+
"The observation configuration dictionary 'obs_groups' must contain the 'policy' key."
253+
f" Found keys: {list(obs_groups.keys())}"
254+
)
245255

246-
# check all observation types for valid observation groups
247-
for type, groups in obs_groups.items():
256+
# check all observation sets for valid observation groups
257+
for set_name, groups in obs_groups.items():
248258
# check if the list is empty
249259
if len(groups) == 0:
250-
msg = f"The '{type}' key in the 'obs_groups' dictionary can not be an empty list."
251-
if type in default_types:
252-
if type not in obs:
253-
msg += " Consider removing the key to default to the observations used for the 'policy' type."
260+
msg = f"The '{set_name}' key in the 'obs_groups' dictionary can not be an empty list."
261+
if set_name in default_sets:
262+
if set_name not in obs:
263+
msg += " Consider removing the key to default to the observations used for the 'policy' set."
254264
else:
255-
msg += f" Consider removing the key to default to the observation '{type}' from the environment."
265+
msg += (
266+
f" Consider removing the key to default to the observation '{set_name}' from the environment."
267+
)
256268
raise ValueError(msg)
257269
# check groups exist inside the observations from the environment
258270
for group in groups:
259271
if group not in obs:
260272
raise ValueError(
261-
f"Observation '{group}' in observation type '{type}' not found in the observations from the"
273+
f"Observation '{group}' in observation set '{set_name}' not found in the observations from the"
262274
f" environment. Available observations from the environment: {list(obs.keys())}"
263275
)
264276

265-
# fill missing observation types
266-
for default_type in default_types:
267-
if default_type not in obs_groups.keys():
268-
if default_type in obs:
269-
obs_groups[default_type] = [default_type]
277+
# fill missing observation sets
278+
for default_set_name in default_sets:
279+
if default_set_name not in obs_groups.keys():
280+
if default_set_name in obs:
281+
obs_groups[default_set_name] = [default_set_name]
282+
warnings.warn(
283+
f"The observation configuration dictionary 'obs_groups' must contain the '{default_set_name}' key."
284+
f" As an observation group with the name '{default_set_name}' was found, this is assumed to be the"
285+
f" observation set. Consider adding the '{default_set_name}' key to the 'obs_groups' dictionary for"
286+
" clarity. This behavior will be removed in a future version."
287+
)
270288
else:
271-
obs_groups[default_type] = obs_groups["policy"].copy()
289+
obs_groups[default_set_name] = obs_groups["policy"].copy()
290+
warnings.warn(
291+
f"The observation configuration dictionary 'obs_groups' must contain the '{default_set_name}' key."
292+
f" As the configuration for '{default_set_name}' is missing, the observations from the 'policy' set"
293+
f" are used. Consider adding the '{default_set_name}' key to the 'obs_groups' dictionary for"
294+
" clarity. This behavior will be removed in a future version."
295+
)
272296

273-
# print the final parsed observation types
274-
print("Resolved observation types: ")
275-
for type, groups in obs_groups.items():
276-
print("\t", type, ": ", groups)
297+
# print the final parsed observation sets
298+
print("-" * 80)
299+
print("Resolved observation sets: ")
300+
for set_name, groups in obs_groups.items():
301+
print("\t", set_name, ": ", groups)
302+
print("-" * 80)
277303

278304
return obs_groups

0 commit comments

Comments
 (0)