Skip to content

Commit 8c71575

Browse files
committed
Remove merge_into and just have merged which copies inputs to avoid footguns
1 parent 29541fd commit 8c71575

File tree

1 file changed

+33
-27
lines changed

1 file changed

+33
-27
lines changed

docker/configure_workers_and_start.py

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
import sys
5656
from argparse import ArgumentParser
5757
from collections import defaultdict
58+
from copy import deepcopy
5859
from dataclasses import dataclass, field
5960
from itertools import chain
6061
from pathlib import Path
@@ -321,37 +322,42 @@ def flush_buffers() -> None:
321322
sys.stderr.flush()
322323

323324

324-
def merge_into(dest: Any, new: Any) -> None:
325+
def merged(a: Dict[str, Any], b: Dict[str, Any]) -> Dict[str, Any]:
325326
"""
326-
Merges `new` into `dest` with the following rules:
327+
Merges `a` and `b` together, returning the result.
328+
329+
The merge is performed with the following rules:
327330
328331
- dicts: values with the same key will be merged recursively
329332
- lists: `new` will be appended to `dest`
330333
- primitives: they will be checked for equality and inequality will result
331334
in a ValueError
332335
333-
It is an error for `dest` and `new` to be of different types.
334-
"""
335-
if isinstance(dest, dict) and isinstance(new, dict):
336-
for k, v in new.items():
337-
if k in dest:
338-
merge_into(dest[k], v)
339-
else:
340-
dest[k] = v
341-
elif isinstance(dest, list) and isinstance(new, list):
342-
dest.extend(new)
343-
elif type(dest) != type(new):
344-
raise TypeError(f"Cannot merge {type(dest).__name__} and {type(new).__name__}")
345-
elif dest != new:
346-
raise ValueError(f"Cannot merge primitive values: {dest!r} != {new!r}")
347-
348336
349-
def merged(a: Dict[str, Any], b: Dict[str, Any]) -> Dict[str, Any]:
350-
"""
351-
Merges `b` into `a` and returns `a`. Here because we can't use `merge_into`
352-
in a lamba conveniently.
337+
It is an error for `a` and `b` to be of different types.
353338
"""
354-
merge_into(a, b)
339+
if isinstance(a, dict) and isinstance(b, dict):
340+
result = {}
341+
for key in set(a.keys()) | set(b.keys):
342+
if key in a and key in b:
343+
result[key] = merged(a[key], b[key])
344+
elif key in a:
345+
result[key] = deepcopy(a[key])
346+
else:
347+
result[key] = deepcopy(b[key])
348+
349+
return result
350+
elif isinstance(a, list) and isinstance(b, list):
351+
return deepcopy(a) + deepcopy(b)
352+
elif type(a) != type(b):
353+
raise TypeError(f"Cannot merge {type(a).__name__} and {type(b).__name__}")
354+
elif a != b:
355+
raise ValueError(f"Cannot merge primitive values: {a!r} != {b!r}")
356+
357+
if type(a) not in {str, int, float, bool, None.__class__}:
358+
raise TypeError(
359+
f"Cannot use `merged` on type {a} as it may not be safe (must either be an immutable primitive or must have special copy/merge logic)"
360+
)
355361
return a
356362

357363

@@ -454,10 +460,10 @@ def instantiate_worker_template(
454460
Returns: worker configuration dictionary
455461
"""
456462
worker_config_dict = dataclasses.asdict(template)
457-
stream_writers_dict = {
458-
writer: worker_name for writer in template.stream_writers
459-
}
460-
worker_config_dict["shared_extra_conf"] = merged(template.shared_extra_conf(worker_name), stream_writers_dict)
463+
stream_writers_dict = {writer: worker_name for writer in template.stream_writers}
464+
worker_config_dict["shared_extra_conf"] = merged(
465+
template.shared_extra_conf(worker_name), stream_writers_dict
466+
)
461467
worker_config_dict["endpoint_patterns"] = sorted(template.endpoint_patterns)
462468
worker_config_dict["listener_resources"] = sorted(template.listener_resources)
463469
return worker_config_dict
@@ -786,7 +792,7 @@ def generate_worker_files(
786792
)
787793

788794
# Update the shared config with any options needed to enable this worker.
789-
merge_into(shared_config, worker_config["shared_extra_conf"])
795+
shared_config = merged(shared_config, worker_config["shared_extra_conf"])
790796

791797
if using_unix_sockets:
792798
healthcheck_urls.append(

0 commit comments

Comments
 (0)