|
| 1 | +import collections |
| 2 | +import logging |
| 3 | +from functools import partial |
| 4 | + |
| 5 | +from termcolor import colored |
| 6 | + |
| 7 | +from .helpers import add |
| 8 | + |
| 9 | + |
| 10 | +class DiffMount(type): |
| 11 | + """Metaclass for Diff plugin system""" |
| 12 | + # noinspection PyUnusedLocal,PyMissingConstructor |
| 13 | + def __init__(cls, *args, **kwargs): |
| 14 | + if not hasattr(cls, 'plugins'): |
| 15 | + cls.plugins = dict() |
| 16 | + else: |
| 17 | + cls.plugins[cls.__name__] = cls |
| 18 | + |
| 19 | + |
| 20 | +class DiffBase(metaclass=DiffMount): |
| 21 | + """Superclass for diff plugins""" |
| 22 | + def __init__(self, remote, local): |
| 23 | + self.logger = logging.getLogger(self.__module__) |
| 24 | + self.remote_flat, self.local_flat = self._flatten(remote), self._flatten(local) |
| 25 | + self.remote_set, self.local_set = set(self.remote_flat.keys()), set(self.local_flat.keys()) |
| 26 | + |
| 27 | + # noinspection PyUnusedLocal |
| 28 | + @classmethod |
| 29 | + def get_plugin(cls, name): |
| 30 | + if name in cls.plugins: |
| 31 | + return cls.plugins[name]() |
| 32 | + |
| 33 | + @classmethod |
| 34 | + def configure(cls, args): |
| 35 | + """Extract class-specific configurations from CLI args and pre-configure the __init__ method using functools.partial""" |
| 36 | + return cls |
| 37 | + |
| 38 | + @classmethod |
| 39 | + def _flatten(cls, d, current_path='', sep='/'): |
| 40 | + """Convert a nested dict structure into a "flattened" dict i.e. {"full/path": "value", ...}""" |
| 41 | + items = [] |
| 42 | + for k in d: |
| 43 | + new = current_path + sep + k if current_path else k |
| 44 | + if isinstance(d[k], collections.MutableMapping): |
| 45 | + items.extend(cls._flatten(d[k], new, sep=sep).items()) |
| 46 | + else: |
| 47 | + items.append((sep + new, d[k])) |
| 48 | + return dict(items) |
| 49 | + |
| 50 | + @classmethod |
| 51 | + def _unflatten(cls, d, sep='/'): |
| 52 | + """Converts a "flattened" dict i.e. {"full/path": "value", ...} into a nested dict structure""" |
| 53 | + output = {} |
| 54 | + for k in d: |
| 55 | + add( |
| 56 | + obj=output, |
| 57 | + path=k, |
| 58 | + value=d[k], |
| 59 | + sep=sep, |
| 60 | + ) |
| 61 | + return output |
| 62 | + |
| 63 | + @classmethod |
| 64 | + def describe_diff(cls, plan): |
| 65 | + """Return a (multi-line) string describing all differences""" |
| 66 | + description = "" |
| 67 | + for k, v in plan['add'].items(): |
| 68 | + # { key: new_value } |
| 69 | + description += colored("+", 'green'), "{} = {}".format(k, v) + '\n' |
| 70 | + |
| 71 | + for k in plan['delete']: |
| 72 | + # { key: old_value } |
| 73 | + description += colored("-", 'red'), k + '\n' |
| 74 | + |
| 75 | + for k, v in plan['change'].items(): |
| 76 | + # { key: {'old': value, 'new': value} } |
| 77 | + description += colored("~", 'yellow'), "{}:\n\t< {}\n\t> {}".format(k, v['old'], v['new']) + '\n' |
| 78 | + |
| 79 | + return description |
| 80 | + |
| 81 | + @property |
| 82 | + def plan(self): |
| 83 | + """Returns a `dict` of operations for updating the remote storage i.e. {'add': {...}, 'change': {...}, 'delete': {...}}""" |
| 84 | + raise NotImplementedError |
| 85 | + |
| 86 | + def merge(self): |
| 87 | + """Generate a merge of the local and remote dicts, following configurations set during __init__""" |
| 88 | + raise NotImplementedError |
| 89 | + |
| 90 | + |
| 91 | +class DiffResolver(DiffBase): |
| 92 | + """Determines diffs between two dicts, where the remote copy is considered the baseline""" |
| 93 | + def __init__(self, remote, local, force=False): |
| 94 | + super().__init__(remote, local) |
| 95 | + self.intersection = self.remote_set.intersection(self.local_set) |
| 96 | + self.force = force |
| 97 | + |
| 98 | + if self.added() or self.removed() or self.changed(): |
| 99 | + self.differ = True |
| 100 | + else: |
| 101 | + self.differ = False |
| 102 | + |
| 103 | + @classmethod |
| 104 | + def configure(cls, args): |
| 105 | + return partial(cls, force=args.diffresolver_force) |
| 106 | + |
| 107 | + def added(self): |
| 108 | + """Returns a (flattened) dict of added leaves i.e. {"full/path": value, ...}""" |
| 109 | + return self.local_set - self.intersection |
| 110 | + |
| 111 | + def removed(self): |
| 112 | + """Returns a (flattened) dict of removed leaves i.e. {"full/path": value, ...}""" |
| 113 | + return self.remote_set - self.intersection |
| 114 | + |
| 115 | + def changed(self): |
| 116 | + """Returns a (flattened) dict of changed leaves i.e. {"full/path": value, ...}""" |
| 117 | + return set(k for k in self.intersection if self.remote_flat[k] != self.local_flat[k]) |
| 118 | + |
| 119 | + def unchanged(self): |
| 120 | + """Returns a (flattened) dict of unchanged leaves i.e. {"full/path": value, ...}""" |
| 121 | + return set(k for k in self.intersection if self.remote_flat[k] == self.local_flat[k]) |
| 122 | + |
| 123 | + @property |
| 124 | + def plan(self): |
| 125 | + return { |
| 126 | + 'add': { |
| 127 | + k: self.local_flat[k] for k in self.added() |
| 128 | + }, |
| 129 | + 'delete': { |
| 130 | + k: self.remote_flat[k] for k in self.removed() |
| 131 | + }, |
| 132 | + 'change': { |
| 133 | + k: {'old': self.remote_flat[k], 'new': self.local_flat[k]} for k in self.changed() |
| 134 | + } |
| 135 | + } |
| 136 | + |
| 137 | + def merge(self): |
| 138 | + dictfilter = lambda original, keep_keys: dict([(i, original[i]) for i in original if i in set(keep_keys)]) |
| 139 | + if self.force: |
| 140 | + # Overwrite local changes (i.e. only preserve added keys) |
| 141 | + # NOTE: Currently the system cannot tell the difference between a remote delete and a local add |
| 142 | + prior_set = self.changed().union(self.removed()).union(self.unchanged()) |
| 143 | + current_set = self.added() |
| 144 | + else: |
| 145 | + # Preserve added keys and changed keys |
| 146 | + # NOTE: Currently the system cannot tell the difference between a remote delete and a local add |
| 147 | + prior_set = self.unchanged().union(self.removed()) |
| 148 | + current_set = self.added().union(self.changed()) |
| 149 | + state = dictfilter(original=self.remote_flat, keep_keys=prior_set) |
| 150 | + state.update(dictfilter(original=self.local_flat, keep_keys=current_set)) |
| 151 | + return self._unflatten(state) |
0 commit comments