|
10 | 10 | import os
|
11 | 11 | import pathlib
|
12 | 12 | import torch
|
| 13 | +import warnings |
13 | 14 | from tensordict import TensorDict
|
14 | 15 | from typing import Callable
|
15 | 16 |
|
@@ -198,81 +199,106 @@ def string_to_callable(name: str) -> Callable:
|
198 | 199 | raise ValueError(msg)
|
199 | 200 |
|
200 | 201 |
|
def resolve_obs_groups(
    obs: TensorDict, obs_groups: dict[str, list[str]], default_sets: list[str]
) -> dict[str, list[str]]:
    """Validates the observation configuration and defaults missing observation sets.

    The input is an observation dictionary `obs` containing observation groups and a configuration dictionary
    `obs_groups` where the keys are the observation sets and the values are lists of observation groups.

    The configuration dictionary could for example look like:
        {
            "policy": ["group_1", "group_2"],
            "critic": ["group_1", "group_3"]
        }

    This means that the 'policy' observation set will contain the observations "group_1" and "group_2" and the
    'critic' observation set will contain the observations "group_1" and "group_3". This function will check that all
    the observations in the 'policy' and 'critic' observation sets are present in the observation dictionary from the
    environment.

    If the 'policy' observation set is not present in the configuration dictionary but an observation group named
    "policy" exists in the observations, that group is assigned to the 'policy' observation set and a warning is
    issued.

    Additionally, if one of the `default_sets`, e.g. "critic", is not present in the configuration dictionary,
    this function will:

    1. Check if a group with the same name exists in the observations and assign this group to the observation set.
    2. If 1. fails, it will assign the observations from the 'policy' observation set to the default observation set.

    A warning is issued in both cases.

    Note:
        `obs_groups` is modified in place: defaulted observation sets are inserted into the dictionary that was
        passed in, and that same dictionary is returned.

    Args:
        obs: Observations from the environment in the form of a dictionary.
        obs_groups: Observation sets configuration.
        default_sets: Reserved observation set names used by the algorithm (besides 'policy').
            If not provided in 'obs_groups', a default behavior gets triggered.

    Returns:
        The resolved observation groups.

    Raises:
        ValueError: If the 'policy' observation set is neither configured in `obs_groups` nor available as an
            observation group named "policy" in the observations.
        ValueError: If any observation set is an empty list.
        ValueError: If any observation set contains an observation term that is not present in the observations.
    """
    # check if policy observation set exists, otherwise fall back to an observation group named 'policy'
    if "policy" not in obs_groups:
        if "policy" not in obs:
            raise ValueError(
                "The observation configuration dictionary 'obs_groups' must contain the 'policy' key."
                f" Found keys: {list(obs_groups.keys())}"
            )
        obs_groups["policy"] = ["policy"]
        # stacklevel=2 attributes the warning to the caller, whose configuration needs the fix
        warnings.warn(
            "The observation configuration dictionary 'obs_groups' must contain the 'policy' key."
            " As an observation group with the name 'policy' was found, this is assumed to be the observation set."
            " Consider adding the 'policy' key to the 'obs_groups' dictionary for clarity."
            " This behavior will be removed in a future version.",
            stacklevel=2,
        )

    # check all observation sets for valid observation groups
    for set_name, groups in obs_groups.items():
        # check if the list is empty
        if not groups:
            msg = f"The '{set_name}' key in the 'obs_groups' dictionary can not be an empty list."
            # for reserved sets, hint at what removing the key would default to
            if set_name in default_sets:
                if set_name not in obs:
                    msg += " Consider removing the key to default to the observations used for the 'policy' set."
                else:
                    msg += (
                        f" Consider removing the key to default to the observation '{set_name}' from the environment."
                    )
            raise ValueError(msg)
        # check groups exist inside the observations from the environment
        for group in groups:
            if group not in obs:
                raise ValueError(
                    f"Observation '{group}' in observation set '{set_name}' not found in the observations from the"
                    f" environment. Available observations from the environment: {list(obs.keys())}"
                )

    # fill missing observation sets
    for default_set_name in default_sets:
        if default_set_name in obs_groups:
            continue
        if default_set_name in obs:
            # an observation group with the same name exists: use it as the observation set
            obs_groups[default_set_name] = [default_set_name]
            warnings.warn(
                f"The observation configuration dictionary 'obs_groups' must contain the '{default_set_name}' key."
                f" As an observation group with the name '{default_set_name}' was found, this is assumed to be the"
                f" observation set. Consider adding the '{default_set_name}' key to the 'obs_groups' dictionary for"
                " clarity. This behavior will be removed in a future version.",
                stacklevel=2,
            )
        else:
            # copy so that later edits to one set do not leak into the 'policy' set
            obs_groups[default_set_name] = obs_groups["policy"].copy()
            warnings.warn(
                f"The observation configuration dictionary 'obs_groups' must contain the '{default_set_name}' key."
                f" As the configuration for '{default_set_name}' is missing, the observations from the 'policy' set"
                f" are used. Consider adding the '{default_set_name}' key to the 'obs_groups' dictionary for"
                " clarity. This behavior will be removed in a future version.",
                stacklevel=2,
            )

    # print the final parsed observation sets
    print("-" * 80)
    print("Resolved observation sets: ")
    for set_name, groups in obs_groups.items():
        print("\t", set_name, ": ", groups)
    print("-" * 80)

    return obs_groups
|
0 commit comments