|
55 | 55 | import sys |
56 | 56 | from argparse import ArgumentParser |
57 | 57 | from collections import defaultdict |
| 58 | +from copy import deepcopy |
58 | 59 | from dataclasses import dataclass, field |
59 | 60 | from itertools import chain |
60 | 61 | from pathlib import Path |
@@ -321,37 +322,42 @@ def flush_buffers() -> None: |
321 | 322 | sys.stderr.flush() |
322 | 323 |
|
323 | 324 |
|
324 | | -def merge_into(dest: Any, new: Any) -> None: |
| 325 | +def merged(a: Dict[str, Any], b: Dict[str, Any]) -> Dict[str, Any]: |
325 | 326 | """ |
326 | | - Merges `new` into `dest` with the following rules: |
| 327 | + Merges `a` and `b` together, returning the result. |
| 328 | +
|
| 329 | + The merge is performed with the following rules: |
327 | 330 |
|
328 | 331 | - dicts: values with the same key will be merged recursively |
329 | 332 | - lists: `new` will be appended to `dest` |
330 | 333 | - primitives: they will be checked for equality and inequality will result |
331 | 334 | in a ValueError |
332 | 335 |
|
333 | | - It is an error for `dest` and `new` to be of different types. |
334 | | - """ |
335 | | - if isinstance(dest, dict) and isinstance(new, dict): |
336 | | - for k, v in new.items(): |
337 | | - if k in dest: |
338 | | - merge_into(dest[k], v) |
339 | | - else: |
340 | | - dest[k] = v |
341 | | - elif isinstance(dest, list) and isinstance(new, list): |
342 | | - dest.extend(new) |
343 | | - elif type(dest) != type(new): |
344 | | - raise TypeError(f"Cannot merge {type(dest).__name__} and {type(new).__name__}") |
345 | | - elif dest != new: |
346 | | - raise ValueError(f"Cannot merge primitive values: {dest!r} != {new!r}") |
347 | | - |
348 | 336 |
|
349 | | -def merged(a: Dict[str, Any], b: Dict[str, Any]) -> Dict[str, Any]: |
350 | | - """ |
351 | | - Merges `b` into `a` and returns `a`. Here because we can't use `merge_into` |
352 | | - in a lamba conveniently. |
| 337 | + It is an error for `a` and `b` to be of different types. |
353 | 338 | """ |
354 | | - merge_into(a, b) |
| 339 | + if isinstance(a, dict) and isinstance(b, dict): |
| 340 | + result = {} |
| 341 | + for key in set(a.keys()) | set(b.keys): |
| 342 | + if key in a and key in b: |
| 343 | + result[key] = merged(a[key], b[key]) |
| 344 | + elif key in a: |
| 345 | + result[key] = deepcopy(a[key]) |
| 346 | + else: |
| 347 | + result[key] = deepcopy(b[key]) |
| 348 | + |
| 349 | + return result |
| 350 | + elif isinstance(a, list) and isinstance(b, list): |
| 351 | + return deepcopy(a) + deepcopy(b) |
| 352 | + elif type(a) != type(b): |
| 353 | + raise TypeError(f"Cannot merge {type(a).__name__} and {type(b).__name__}") |
| 354 | + elif a != b: |
| 355 | + raise ValueError(f"Cannot merge primitive values: {a!r} != {b!r}") |
| 356 | + |
| 357 | + if type(a) not in {str, int, float, bool, None.__class__}: |
| 358 | + raise TypeError( |
| 359 | + f"Cannot use `merged` on type {a} as it may not be safe (must either be an immutable primitive or must have special copy/merge logic)" |
| 360 | + ) |
355 | 361 | return a |
356 | 362 |
|
357 | 363 |
|
@@ -454,10 +460,10 @@ def instantiate_worker_template( |
454 | 460 | Returns: worker configuration dictionary |
455 | 461 | """ |
456 | 462 | worker_config_dict = dataclasses.asdict(template) |
457 | | - stream_writers_dict = { |
458 | | - writer: worker_name for writer in template.stream_writers |
459 | | - } |
460 | | - worker_config_dict["shared_extra_conf"] = merged(template.shared_extra_conf(worker_name), stream_writers_dict) |
| 463 | + stream_writers_dict = {writer: worker_name for writer in template.stream_writers} |
| 464 | + worker_config_dict["shared_extra_conf"] = merged( |
| 465 | + template.shared_extra_conf(worker_name), stream_writers_dict |
| 466 | + ) |
461 | 467 | worker_config_dict["endpoint_patterns"] = sorted(template.endpoint_patterns) |
462 | 468 | worker_config_dict["listener_resources"] = sorted(template.listener_resources) |
463 | 469 | return worker_config_dict |
@@ -786,7 +792,7 @@ def generate_worker_files( |
786 | 792 | ) |
787 | 793 |
|
788 | 794 | # Update the shared config with any options needed to enable this worker. |
789 | | - merge_into(shared_config, worker_config["shared_extra_conf"]) |
| 795 | + shared_config = merged(shared_config, worker_config["shared_extra_conf"]) |
790 | 796 |
|
791 | 797 | if using_unix_sockets: |
792 | 798 | healthcheck_urls.append( |
|
0 commit comments