|
14 | 14 |
|
15 | 15 | import array |
16 | 16 | from enum import Enum |
| 17 | +from typing import Dict |
| 18 | +from typing import List |
| 19 | +from typing import Optional |
17 | 20 |
|
18 | 21 | from rcl_interfaces.msg import Parameter as ParameterMsg |
19 | | -from rcl_interfaces.msg import ParameterType, ParameterValue |
| 22 | +from rcl_interfaces.msg import ParameterType |
| 23 | +from rcl_interfaces.msg import ParameterValue |
| 24 | +import yaml |
20 | 25 |
|
21 | 26 | PARAMETER_SEPARATOR_STRING = '.' |
22 | 27 |
|
@@ -177,6 +182,50 @@ def to_parameter_msg(self): |
177 | 182 | return ParameterMsg(name=self.name, value=self.get_parameter_value()) |
178 | 183 |
|
179 | 184 |
|
| 185 | +def get_parameter_value(string_value: str) -> ParameterValue: |
| 186 | + """ |
| 187 | + Guess the desired type of the parameter based on the string value. |
| 188 | +
|
| 189 | + :param string_value: The string value to be converted to a ParameterValue. |
| 190 | + :return: The ParameterValue. |
| 191 | + """ |
| 192 | + value = ParameterValue() |
| 193 | + try: |
| 194 | + yaml_value = yaml.safe_load(string_value) |
| 195 | + except yaml.parser.ParserError: |
| 196 | + yaml_value = string_value |
| 197 | + |
| 198 | + if isinstance(yaml_value, bool): |
| 199 | + value.type = ParameterType.PARAMETER_BOOL |
| 200 | + value.bool_value = yaml_value |
| 201 | + elif isinstance(yaml_value, int): |
| 202 | + value.type = ParameterType.PARAMETER_INTEGER |
| 203 | + value.integer_value = yaml_value |
| 204 | + elif isinstance(yaml_value, float): |
| 205 | + value.type = ParameterType.PARAMETER_DOUBLE |
| 206 | + value.double_value = yaml_value |
| 207 | + elif isinstance(yaml_value, list): |
| 208 | + if all((isinstance(v, bool) for v in yaml_value)): |
| 209 | + value.type = ParameterType.PARAMETER_BOOL_ARRAY |
| 210 | + value.bool_array_value = yaml_value |
| 211 | + elif all((isinstance(v, int) for v in yaml_value)): |
| 212 | + value.type = ParameterType.PARAMETER_INTEGER_ARRAY |
| 213 | + value.integer_array_value = yaml_value |
| 214 | + elif all((isinstance(v, float) for v in yaml_value)): |
| 215 | + value.type = ParameterType.PARAMETER_DOUBLE_ARRAY |
| 216 | + value.double_array_value = yaml_value |
| 217 | + elif all((isinstance(v, str) for v in yaml_value)): |
| 218 | + value.type = ParameterType.PARAMETER_STRING_ARRAY |
| 219 | + value.string_array_value = yaml_value |
| 220 | + else: |
| 221 | + value.type = ParameterType.PARAMETER_STRING |
| 222 | + value.string_value = string_value |
| 223 | + else: |
| 224 | + value.type = ParameterType.PARAMETER_STRING |
| 225 | + value.string_value = yaml_value |
| 226 | + return value |
| 227 | + |
| 228 | + |
180 | 229 | def parameter_value_to_python(parameter_value: ParameterValue): |
181 | 230 | """ |
182 | 231 | Get the value for the Python builtin type from a rcl_interfaces/msg/ParameterValue object. |
@@ -211,3 +260,79 @@ def parameter_value_to_python(parameter_value: ParameterValue): |
211 | 260 | raise RuntimeError(f'unexpected parameter type {parameter_value.type}') |
212 | 261 |
|
213 | 262 | return value |
| 263 | + |
| 264 | + |
| 265 | +def parameter_dict_from_yaml_file( |
| 266 | + parameter_file: str, |
| 267 | + use_wildcard: bool = False, |
| 268 | + target_nodes: Optional[List[str]] = None, |
| 269 | + namespace: str = '' |
| 270 | +) -> Dict[str, ParameterMsg]: |
| 271 | + """ |
| 272 | + Build a dict of parameters from a YAML file. |
| 273 | +
|
| 274 | + Will load all parameters if ``target_nodes`` is None or empty. |
| 275 | +
|
| 276 | + :raises RuntimeError: if a target node is not in the file |
| 277 | + :raises RuntimeError: if the is not a valid ROS parameter file |
| 278 | +
|
| 279 | + :param parameter_file: Path to the YAML file to load parameters from. |
| 280 | + :param use_wildcard: Use wildcard matching for the target nodes. |
| 281 | + :param target_nodes: List of nodes in the YAML file to load parameters from. |
| 282 | + :param namespace: Namespace to prepend to all parameters. |
| 283 | + :return: A dict of Parameter messages keyed by the parameter names |
| 284 | + """ |
| 285 | + with open(parameter_file, 'r') as f: |
| 286 | + param_file = yaml.safe_load(f) |
| 287 | + param_keys = [] |
| 288 | + param_dict = {} |
| 289 | + |
| 290 | + if use_wildcard and '/**' in param_file: |
| 291 | + param_keys.append('/**') |
| 292 | + |
| 293 | + if target_nodes: |
| 294 | + for n in target_nodes: |
| 295 | + if n not in param_file.keys(): |
| 296 | + raise RuntimeError(f'Param file does not contain parameters for {n},' |
| 297 | + f'only for nodes: {list(param_file.keys())} ') |
| 298 | + param_keys.append(n) |
| 299 | + else: |
| 300 | + # wildcard key must go to the front of param_keys so that |
| 301 | + # node-namespaced parameters will override the wildcard parameters |
| 302 | + keys = set(param_file.keys()) |
| 303 | + keys.discard('/**') |
| 304 | + param_keys.extend(keys) |
| 305 | + |
| 306 | + if len(param_keys) == 0: |
| 307 | + raise RuntimeError('Param file does not contain selected parameters') |
| 308 | + |
| 309 | + for n in param_keys: |
| 310 | + value = param_file[n] |
| 311 | + if type(value) != dict or 'ros__parameters' not in value: |
| 312 | + raise RuntimeError(f'YAML file is not a valid ROS parameter file for node {n}') |
| 313 | + param_dict.update(value['ros__parameters']) |
| 314 | + return _unpack_parameter_dict(namespace, param_dict) |
| 315 | + |
| 316 | + |
| 317 | +def _unpack_parameter_dict(namespace, parameter_dict): |
| 318 | + """ |
| 319 | + Flatten a parameter dictionary recursively. |
| 320 | +
|
| 321 | + :param namespace: The namespace to prepend to the parameter names. |
| 322 | + :param parameter_dict: A dictionary of parameters keyed by the parameter names |
| 323 | + :return: A dict of Parameter objects keyed by the parameter names |
| 324 | + """ |
| 325 | + parameters: Dict[str, ParameterMsg] = {} |
| 326 | + for param_name, param_value in parameter_dict.items(): |
| 327 | + full_param_name = namespace + param_name |
| 328 | + # Unroll nested parameters |
| 329 | + if type(param_value) == dict: |
| 330 | + parameters.update(_unpack_parameter_dict( |
| 331 | + namespace=full_param_name + PARAMETER_SEPARATOR_STRING, |
| 332 | + parameter_dict=param_value)) |
| 333 | + else: |
| 334 | + parameter = ParameterMsg() |
| 335 | + parameter.name = full_param_name |
| 336 | + parameter.value = get_parameter_value(str(param_value)) |
| 337 | + parameters[full_param_name] = parameter |
| 338 | + return parameters |
0 commit comments