|
| 1 | +import functools |
| 2 | +import inspect |
| 3 | +import itertools |
| 4 | +from inspect import Parameter |
| 5 | + |
| 6 | +import pint |
| 7 | +import pint.testing |
| 8 | +import xarray as xr |
| 9 | + |
| 10 | +from pint_xarray.accessors import get_registry |
| 11 | +from pint_xarray.conversion import extract_units |
| 12 | +from pint_xarray.itertools import zip_mappings |
| 13 | + |
| 14 | +variable_parameters = (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD) |
| 15 | + |
| 16 | + |
| 17 | +def _number_of_results(result): |
| 18 | + if isinstance(result, tuple): |
| 19 | + return len(result) |
| 20 | + elif result is None: |
| 21 | + return 0 |
| 22 | + else: |
| 23 | + return 1 |
| 24 | + |
| 25 | + |
| 26 | +def expects(*args_units, return_value=None, **kwargs_units): |
| 27 | + """ |
| 28 | + Decorator which ensures the inputs and outputs of the decorated |
| 29 | + function are expressed in the expected units. |
| 30 | +
|
| 31 | + Arguments to the decorated function are checked for the specified |
| 32 | + units, converting to those units if necessary, and then stripped |
| 33 | + of their units before being passed into the undecorated |
| 34 | + function. Therefore the undecorated function should expect |
| 35 | + unquantified DataArrays, Datasets, or numpy-like arrays, but with |
| 36 | + the values expressed in specific units. |
| 37 | +
|
| 38 | + Parameters |
| 39 | + ---------- |
| 40 | + func : callable |
| 41 | + Function to decorate, which accepts zero or more |
| 42 | + xarray.DataArrays or numpy-like arrays as inputs, and may |
| 43 | + optionally return one or more xarray.DataArrays or numpy-like |
| 44 | + arrays. |
| 45 | + *args_units : unit-like or mapping of hashable to unit-like, optional |
| 46 | + Units to expect for each positional argument given to func. |
| 47 | +
|
| 48 | + The decorator will first check that arguments passed to the |
| 49 | + decorated function possess these specific units (or will |
| 50 | + attempt to convert the argument to these units), then will |
| 51 | + strip the units before passing the magnitude to the wrapped |
| 52 | + function. |
| 53 | +
|
| 54 | + A value of None indicates not to check that argument for units |
| 55 | + (suitable for flags and other non-data arguments). |
| 56 | + return_value : unit-like or list of unit-like or mapping of hashable to unit-like \ |
| 57 | + or list of mapping of hashable to unit-like, optional |
| 58 | + The expected units of the returned value(s), either as a |
| 59 | + single unit or as a list of units. The decorator will attach |
| 60 | + these units to the variables returned from the function. |
| 61 | +
|
| 62 | + A value of None indicates not to attach any units to that |
| 63 | + return value (suitable for flags and other non-data results). |
| 64 | + **kwargs_units : mapping of hashable to unit-like, optional |
| 65 | + Unit to expect for each keyword argument given to func. |
| 66 | +
|
| 67 | + The decorator will first check that arguments passed to the decorated |
| 68 | + function possess these specific units (or will attempt to convert the |
| 69 | + argument to these units), then will strip the units before passing the |
| 70 | + magnitude to the wrapped function. |
| 71 | +
|
| 72 | + A value of None indicates not to check that argument for units (suitable |
| 73 | + for flags and other non-data arguments). |
| 74 | +
|
| 75 | + Returns |
| 76 | + ------- |
| 77 | + return_values : Any |
| 78 | + Return values of the wrapped function, either a single value or a tuple |
| 79 | + of values. These will be given units according to ``return_value``. |
| 80 | +
|
| 81 | + Raises |
| 82 | + ------ |
| 83 | + TypeError |
| 84 | + If any of the units are not a valid type. |
| 85 | + ValueError |
| 86 | + If the number of arguments or return values does not match the number of |
| 87 | + units specified. Also thrown if any parameter does not have a unit |
| 88 | + specified. |
| 89 | +
|
| 90 | + See Also |
| 91 | + -------- |
| 92 | + pint.wraps |
| 93 | +
|
| 94 | + Examples |
| 95 | + -------- |
| 96 | + Decorating a function which takes one quantified input, but |
| 97 | + returns a non-data value (in this case a boolean). |
| 98 | +
|
| 99 | + >>> @expects("deg C") |
| 100 | + ... def above_freezing(temp): |
| 101 | + ... return temp > 0 |
| 102 | + ... |
| 103 | +
|
| 104 | + Decorating a function which allows any dimensions for the array, but also |
| 105 | + accepts an optional `weights` keyword argument, which must be dimensionless. |
| 106 | +
|
| 107 | + >>> @expects(None, weights="dimensionless") |
| 108 | + ... def mean(da, weights=None): |
| 109 | + ... if weights: |
| 110 | + ... return da.weighted(weights=weights).mean() |
| 111 | + ... else: |
| 112 | + ... return da.mean() |
| 113 | + ... |
| 114 | + """ |
| 115 | + |
| 116 | + def outer(func): |
| 117 | + signature = inspect.signature(func) |
| 118 | + |
| 119 | + params_units = signature.bind(*args_units, **kwargs_units) |
| 120 | + |
| 121 | + missing_params = [ |
| 122 | + name |
| 123 | + for name, p in signature.parameters.items() |
| 124 | + if p.kind not in variable_parameters and name not in params_units.arguments |
| 125 | + ] |
| 126 | + if missing_params: |
| 127 | + raise ValueError( |
| 128 | + "Missing units for the following parameters: " |
| 129 | + + ", ".join(map(repr, missing_params)) |
| 130 | + ) |
| 131 | + |
| 132 | + n_expected_results = _number_of_results(return_value) |
| 133 | + |
| 134 | + @functools.wraps(func) |
| 135 | + def wrapper(*args, **kwargs): |
| 136 | + nonlocal return_value |
| 137 | + |
| 138 | + params = signature.bind(*args, **kwargs) |
| 139 | + # don't apply defaults, as those can't be quantities and thus must |
| 140 | + # already be in the correct units |
| 141 | + |
| 142 | + spec_units = dict( |
| 143 | + enumerate( |
| 144 | + itertools.chain.from_iterable( |
| 145 | + spec.values() if isinstance(spec, dict) else (spec,) |
| 146 | + for spec in params_units.arguments.values() |
| 147 | + if spec is not None |
| 148 | + ) |
| 149 | + ) |
| 150 | + ) |
| 151 | + params_units_ = dict( |
| 152 | + enumerate( |
| 153 | + itertools.chain.from_iterable( |
| 154 | + ( |
| 155 | + extract_units(param) |
| 156 | + if isinstance(param, (xr.DataArray, xr.Dataset)) |
| 157 | + else (param.units,) |
| 158 | + ) |
| 159 | + for name, param in params.arguments.items() |
| 160 | + if isinstance(param, (xr.DataArray, xr.Dataset, pint.Quantity)) |
| 161 | + ) |
| 162 | + ) |
| 163 | + ) |
| 164 | + |
| 165 | + ureg = get_registry( |
| 166 | + None, |
| 167 | + dict(spec_units) if spec_units else {}, |
| 168 | + dict(params_units_) if params_units else {}, |
| 169 | + ) |
| 170 | + |
| 171 | + errors = [] |
| 172 | + for name, (value, units) in zip_mappings( |
| 173 | + params.arguments, params_units.arguments |
| 174 | + ): |
| 175 | + try: |
| 176 | + if units is None: |
| 177 | + if isinstance(value, pint.Quantity) or ( |
| 178 | + isinstance(value, (xr.DataArray, xr.Dataset)) |
| 179 | + and value.pint.units |
| 180 | + ): |
| 181 | + raise TypeError( |
| 182 | + "Passed in a quantity where none was expected" |
| 183 | + ) |
| 184 | + continue |
| 185 | + if isinstance(value, pint.Quantity): |
| 186 | + params.arguments[name] = value.m_as(units) |
| 187 | + elif isinstance(value, (xr.DataArray, xr.Dataset)): |
| 188 | + params.arguments[name] = value.pint.to(units).pint.dequantify() |
| 189 | + else: |
| 190 | + raise TypeError( |
| 191 | + f"Attempting to convert non-quantity {value} to {units}." |
| 192 | + ) |
| 193 | + except ( |
| 194 | + TypeError, |
| 195 | + pint.errors.UndefinedUnitError, |
| 196 | + pint.errors.DimensionalityError, |
| 197 | + ) as e: |
| 198 | + e.add_note( |
| 199 | + f"expects: raised while trying to convert parameter {name}" |
| 200 | + ) |
| 201 | + errors.append(e) |
| 202 | + |
| 203 | + if errors: |
| 204 | + raise ExceptionGroup("Errors while converting parameters", errors) |
| 205 | + |
| 206 | + result = func(*params.args, **params.kwargs) |
| 207 | + |
| 208 | + n_results = _number_of_results(result) |
| 209 | + if return_value is not None and ( |
| 210 | + (isinstance(result, tuple) ^ isinstance(return_value, tuple)) |
| 211 | + or (n_results != n_expected_results) |
| 212 | + ): |
| 213 | + message = "mismatched number of return values:" |
| 214 | + if n_results != n_expected_results: |
| 215 | + message += f" expected {n_expected_results} but got {n_results}." |
| 216 | + elif isinstance(result, tuple) and not isinstance(return_value, tuple): |
| 217 | + message += ( |
| 218 | + " expected a single return value but got a 1-sized tuple." |
| 219 | + ) |
| 220 | + else: |
| 221 | + message += ( |
| 222 | + " expected a 1-sized tuple but got a single return value." |
| 223 | + ) |
| 224 | + raise ValueError(message) |
| 225 | + |
| 226 | + if result is None: |
| 227 | + return |
| 228 | + |
| 229 | + if not isinstance(result, tuple): |
| 230 | + result = (result,) |
| 231 | + if not isinstance(return_value, tuple): |
| 232 | + return_value = (return_value,) |
| 233 | + |
| 234 | + final_result = [] |
| 235 | + errors = [] |
| 236 | + for index, (value, units) in enumerate(zip(result, return_value)): |
| 237 | + if units is not None: |
| 238 | + try: |
| 239 | + if isinstance(value, (xr.Dataset, xr.DataArray)): |
| 240 | + value = value.pint.quantify(units) |
| 241 | + else: |
| 242 | + value = ureg.Quantity(value, units) |
| 243 | + except Exception as e: |
| 244 | + e.add_note( |
| 245 | + f"expects: raised while trying to convert return value {index}" |
| 246 | + ) |
| 247 | + errors.append(e) |
| 248 | + |
| 249 | + final_result.append(value) |
| 250 | + |
| 251 | + if errors: |
| 252 | + raise ExceptionGroup("Errors while converting return values", errors) |
| 253 | + |
| 254 | + if n_results == 1: |
| 255 | + return final_result[0] |
| 256 | + return tuple(final_result) |
| 257 | + |
| 258 | + return wrapper |
| 259 | + |
| 260 | + return outer |
0 commit comments