Skip to content

expects decorator #316

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 30 commits into from
Aug 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
60230d8
add `separate`
keewis Aug 9, 2025
2da7313
function to get unique values while order-preserving
keewis Aug 9, 2025
04afb75
util to zip mappings
keewis Aug 9, 2025
f5dadd5
reimplement `zip_mappings` to be more robust
keewis Aug 9, 2025
cfcab93
implement `expects`
keewis Aug 9, 2025
1bcc6c7
test that expects correctly converts args
keewis Aug 9, 2025
1b53c7c
check that default args work, too
keewis Aug 10, 2025
09525d9
support checking for single errors, as well
keewis Aug 10, 2025
6832739
check that units in kwargs work
keewis Aug 10, 2025
3778c2b
raise an error for all parameters without unit spec
keewis Aug 10, 2025
852dc89
check that the return value units are attached properly
keewis Aug 10, 2025
2205297
use `ureg.Quantity` instead of `unit.m_from`
keewis Aug 10, 2025
4d3edcd
check that return values can not have units
keewis Aug 10, 2025
19a5d72
check that functions can not return a result
keewis Aug 10, 2025
0cb6aff
check for various errors when returning results
keewis Aug 10, 2025
1e01f1a
don't cover the version fallback
keewis Aug 10, 2025
b45181a
add api docs
keewis Aug 10, 2025
6f1dce7
errors → error
keewis Aug 10, 2025
449181d
check the error type
keewis Aug 10, 2025
77f9cd3
change the raised error to `TypeError`
keewis Aug 10, 2025
a14a405
copy the docstring from #143
keewis Aug 10, 2025
5a6dcb4
styling
keewis Aug 10, 2025
69f62d2
see also
keewis Aug 10, 2025
911987b
changelog
keewis Aug 10, 2025
785bee4
terminology
keewis Aug 10, 2025
0e0897c
extend the tests
keewis Aug 15, 2025
68973ed
raise an error if the return value is unexpectedly `None`
keewis Aug 15, 2025
1fcbebc
more explicitly select the error types
keewis Aug 15, 2025
b13a17d
more tests
keewis Aug 15, 2025
36fdb41
add a dev env
keewis Aug 15, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ DataArray
xarray.DataArray.pint.bfill
xarray.DataArray.pint.interpolate_na

Wrapping quantity-unaware functions
-----------------------------------
.. autosummary::
:toctree: generated/

pint_xarray.expects

Testing
-------

Expand Down
3 changes: 2 additions & 1 deletion docs/terminology.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Terminology

unit-like
A `pint`_ unit definition, as accepted by :py:class:`pint.Unit`.
May be either a :py:class:`str` or a :py:class:`pint.Unit` instance.
May be a :py:class:`str`, a :py:class:`pint.Unit` instance or
:py:obj:`None`.

.. _pint: https://pint.readthedocs.io/en/stable
2 changes: 2 additions & 0 deletions docs/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ What's new
By `Justus Magin <https://github.com/keewis>`_.
- Switch to using pixi for all dependency management (:pull:`314`).
By `Justus Magin <https://github.com/keewis>`_.
- Added the :py:func:`pint_xarray.expects` decorator to allow wrapping quantity-unaware functions (:issue:`141`, :pull:`316`).
By `Justus Magin <https://github.com/keewis>`_ and `Tom Nicholas <https://github.com/TomNicholas>`_.

0.5.1 (10 Aug 2025)
-------------------
Expand Down
4 changes: 3 additions & 1 deletion pint_xarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import pint

from pint_xarray import accessors, formatting, testing # noqa: F401
from pint_xarray._expects import expects
from pint_xarray.accessors import default_registry as unit_registry
from pint_xarray.accessors import setup_registry
from pint_xarray.index import PintIndex

try:
__version__ = version("pint-xarray")
except Exception:
except Exception: # pragma: no cover
# Local copy or not installed with setuptools.
# Disable minimum version checks on downstream libraries.
__version__ = "999"
Expand All @@ -23,4 +24,5 @@
"unit_registry",
"setup_registry",
"PintIndex",
"expects",
]
260 changes: 260 additions & 0 deletions pint_xarray/_expects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
import functools
import inspect
import itertools
from inspect import Parameter

import pint
import pint.testing
import xarray as xr

from pint_xarray.accessors import get_registry
from pint_xarray.conversion import extract_units
from pint_xarray.itertools import zip_mappings

variable_parameters = (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD)


def _number_of_results(result):
if isinstance(result, tuple):
return len(result)
elif result is None:
return 0
else:
return 1


def expects(*args_units, return_value=None, **kwargs_units):
"""
Decorator which ensures the inputs and outputs of the decorated
function are expressed in the expected units.

Arguments to the decorated function are checked for the specified
units, converting to those units if necessary, and then stripped
of their units before being passed into the undecorated
function. Therefore the undecorated function should expect
unquantified DataArrays, Datasets, or numpy-like arrays, but with
the values expressed in specific units.

Parameters
----------
func : callable
Function to decorate, which accepts zero or more
xarray.DataArrays or numpy-like arrays as inputs, and may
optionally return one or more xarray.DataArrays or numpy-like
arrays.
*args_units : unit-like or mapping of hashable to unit-like, optional
Units to expect for each positional argument given to func.

The decorator will first check that arguments passed to the
decorated function possess these specific units (or will
attempt to convert the argument to these units), then will
strip the units before passing the magnitude to the wrapped
function.

A value of None indicates not to check that argument for units
(suitable for flags and other non-data arguments).
return_value : unit-like or list of unit-like or mapping of hashable to unit-like \
or list of mapping of hashable to unit-like, optional
The expected units of the returned value(s), either as a
single unit or as a list of units. The decorator will attach
these units to the variables returned from the function.

A value of None indicates not to attach any units to that
return value (suitable for flags and other non-data results).
**kwargs_units : mapping of hashable to unit-like, optional
Unit to expect for each keyword argument given to func.

The decorator will first check that arguments passed to the decorated
function possess these specific units (or will attempt to convert the
argument to these units), then will strip the units before passing the
magnitude to the wrapped function.

A value of None indicates not to check that argument for units (suitable
for flags and other non-data arguments).

Returns
-------
return_values : Any
Return values of the wrapped function, either a single value or a tuple
of values. These will be given units according to ``return_value``.

Raises
------
TypeError
If any of the units are not a valid type.
ValueError
If the number of arguments or return values does not match the number of
units specified. Also thrown if any parameter does not have a unit
specified.

See Also
--------
pint.wraps

Examples
--------
Decorating a function which takes one quantified input, but
returns a non-data value (in this case a boolean).

>>> @expects("deg C")
... def above_freezing(temp):
... return temp > 0
...

Decorating a function which allows any dimensions for the array, but also
accepts an optional `weights` keyword argument, which must be dimensionless.

>>> @expects(None, weights="dimensionless")
... def mean(da, weights=None):
... if weights:
... return da.weighted(weights=weights).mean()
... else:
... return da.mean()
...
"""

def outer(func):
signature = inspect.signature(func)

params_units = signature.bind(*args_units, **kwargs_units)

missing_params = [
name
for name, p in signature.parameters.items()
if p.kind not in variable_parameters and name not in params_units.arguments
]
if missing_params:
raise ValueError(
"Missing units for the following parameters: "
+ ", ".join(map(repr, missing_params))
)

n_expected_results = _number_of_results(return_value)

@functools.wraps(func)
def wrapper(*args, **kwargs):
nonlocal return_value

params = signature.bind(*args, **kwargs)
# don't apply defaults, as those can't be quantities and thus must
# already be in the correct units

spec_units = dict(
enumerate(
itertools.chain.from_iterable(
spec.values() if isinstance(spec, dict) else (spec,)
for spec in params_units.arguments.values()
if spec is not None
)
)
)
params_units_ = dict(
enumerate(
itertools.chain.from_iterable(
(
extract_units(param)
if isinstance(param, (xr.DataArray, xr.Dataset))
else (param.units,)
)
for name, param in params.arguments.items()
if isinstance(param, (xr.DataArray, xr.Dataset, pint.Quantity))
)
)
)

ureg = get_registry(
None,
dict(spec_units) if spec_units else {},
dict(params_units_) if params_units else {},
)

errors = []
for name, (value, units) in zip_mappings(
params.arguments, params_units.arguments
):
try:
if units is None:
if isinstance(value, pint.Quantity) or (
isinstance(value, (xr.DataArray, xr.Dataset))
and value.pint.units
):
raise TypeError(
"Passed in a quantity where none was expected"
)
continue
if isinstance(value, pint.Quantity):
params.arguments[name] = value.m_as(units)
elif isinstance(value, (xr.DataArray, xr.Dataset)):
params.arguments[name] = value.pint.to(units).pint.dequantify()
else:
raise TypeError(
f"Attempting to convert non-quantity {value} to {units}."
)
except (
TypeError,
pint.errors.UndefinedUnitError,
pint.errors.DimensionalityError,
) as e:
e.add_note(
f"expects: raised while trying to convert parameter {name}"
)
errors.append(e)

if errors:
raise ExceptionGroup("Errors while converting parameters", errors)

result = func(*params.args, **params.kwargs)

n_results = _number_of_results(result)
if return_value is not None and (
(isinstance(result, tuple) ^ isinstance(return_value, tuple))
or (n_results != n_expected_results)
):
message = "mismatched number of return values:"
if n_results != n_expected_results:
message += f" expected {n_expected_results} but got {n_results}."
elif isinstance(result, tuple) and not isinstance(return_value, tuple):
message += (
" expected a single return value but got a 1-sized tuple."
)
else:
message += (
" expected a 1-sized tuple but got a single return value."
)
raise ValueError(message)

if result is None:
return

if not isinstance(result, tuple):
result = (result,)
if not isinstance(return_value, tuple):
return_value = (return_value,)

final_result = []
errors = []
for index, (value, units) in enumerate(zip(result, return_value)):
if units is not None:
try:
if isinstance(value, (xr.Dataset, xr.DataArray)):
value = value.pint.quantify(units)
else:
value = ureg.Quantity(value, units)
except Exception as e:
e.add_note(
f"expects: raised while trying to convert return value {index}"
)
errors.append(e)

final_result.append(value)

if errors:
raise ExceptionGroup("Errors while converting return values", errors)

if n_results == 1:
return final_result[0]
return tuple(final_result)

return wrapper

return outer
30 changes: 30 additions & 0 deletions pint_xarray/itertools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import itertools
from functools import reduce


def separate(predicate, iterable):
evaluated = ((predicate(el), el) for el in iterable)

key = lambda x: x[0]
grouped = itertools.groupby(sorted(evaluated, key=key), key=key)

groups = {label: [el for _, el in group] for label, group in grouped}

return groups[False], groups[True]


def unique(iterable):
return list(dict.fromkeys(iterable))


def zip_mappings(*mappings):
def common_keys(a, b):
all_keys = unique(itertools.chain(a.keys(), b.keys()))
intersection = set(a.keys()).intersection(b.keys())

return [key for key in all_keys if key in intersection]

keys = list(reduce(common_keys, mappings))

for key in keys:
yield key, tuple(m[key] for m in mappings)
Loading
Loading