Add ZeroSumNormal distribution #6121
pymc/distributions/continuous.py:

@@ -69,10 +69,13 @@ def polyagamma_cdf(*args, **kwargs):
         raise RuntimeError("polyagamma package is not installed!")


+from numpy.core.numeric import normalize_axis_tuple
 from scipy import stats
 from scipy.interpolate import InterpolatedUnivariateSpline
 from scipy.special import expit

+import pymc as pm
+
 from pymc.aesaraf import floatX
 from pymc.distributions import transforms
 from pymc.distributions.dist_math import (
@@ -86,16 +89,28 @@ def polyagamma_cdf(*args, **kwargs):
     normal_lcdf,
     zvalue,
 )
-from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Continuous
-from pymc.distributions.shape_utils import rv_size_is_none
-from pymc.distributions.transforms import _default_transform
+from pymc.distributions.distribution import (
+    DIST_PARAMETER_TYPES,
+    Continuous,
+    Distribution,
+    SymbolicRandomVariable,
+    _moment,
+)
+from pymc.distributions.logprob import ignore_logprob
+from pymc.distributions.shape_utils import (
+    _change_dist_size,
+    convert_dims,
+    rv_size_is_none,
+)
+from pymc.distributions.transforms import ZeroSumTransform, _default_transform
 from pymc.math import invlogit, logdiffexp, logit

 __all__ = [
     "Uniform",
     "Flat",
     "HalfFlat",
     "Normal",
+    "ZeroSumNormal",
     "TruncatedNormal",
     "Beta",
     "Kumaraswamy",
@@ -585,6 +600,172 @@ def logcdf(value, mu, sigma):
         )


+class ZeroSumNormalRV(SymbolicRandomVariable):
+    """ZeroSumNormal random variable"""
+
+    _print_name = ("ZeroSumNormal", "\\operatorname{ZeroSumNormal}")
+    zerosum_axes = None
+
+    def __init__(self, *args, zerosum_axes, **kwargs):
+        self.zerosum_axes = zerosum_axes
+        super().__init__(*args, **kwargs)
+
+
+class ZeroSumNormal(Distribution):
+    r"""
+    ZeroSumNormal distribution, i.e. a Normal distribution where one or
+    several axes are constrained to sum to zero.
+    By default, the last axis is constrained to sum to zero.
+    See the ``zerosum_axes`` kwarg for more details.
+
+    Parameters
+    ----------
+    sigma : tensor_like of float
+        Standard deviation (sigma > 0).
+        Defaults to 1 if not specified.
+        For now, ``sigma`` has to be a scalar, to ensure the zero-sum constraint.
+    zerosum_axes : list or tuple of strings or integers
+        Axis (or axes) along which the zero-sum constraint is enforced.
+        Defaults to [-1], i.e. the last axis.
+        If strings are passed, then ``dims`` is needed.
+        Otherwise, ``shape`` and ``size`` work as they do for other PyMC distributions.
+    dims : list or tuple of strings, optional
+        The dimension names of the axes.
+        Necessary when ``zerosum_axes`` is specified with strings.
+
+    Warnings
+    --------
+    ``sigma`` has to be a scalar, to ensure the zero-sum constraint.
+    The ability to specify a vector of ``sigma`` may be added in future versions.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        COORDS = {
+            "regions": ["a", "b", "c"],
+            "answers": ["yes", "no", "whatever", "don't understand question"],
+        }
+        with pm.Model(coords=COORDS) as m:
+            v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes="answers")
+
+        with pm.Model(coords=COORDS) as m:
+            v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=("regions", "answers"))
+
+        with pm.Model(coords=COORDS) as m:
+            v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=1)
+    """
+    rv_type = ZeroSumNormalRV
+
+    def __new__(cls, *args, zerosum_axes=None, dims=None, **kwargs):
+        dims = convert_dims(dims)
+        if zerosum_axes is None:
+            zerosum_axes = [-1]
+        if not isinstance(zerosum_axes, (list, tuple)):
+            zerosum_axes = [zerosum_axes]
+
+        if isinstance(zerosum_axes[0], str):
+            if not dims:
+                raise ValueError("You need to specify dims if zerosum_axes are strings.")
+            else:
+                zerosum_axes_ = []
+                for axis in zerosum_axes:
+                    zerosum_axes_.append(dims.index(axis))
+                zerosum_axes = zerosum_axes_
+
+        return super().__new__(cls, *args, zerosum_axes=zerosum_axes, dims=dims, **kwargs)
+
+    @classmethod
+    def dist(cls, sigma=1, zerosum_axes=None, **kwargs):
+        if zerosum_axes is None:
+            zerosum_axes = [-1]
+
+        sigma = at.as_tensor_variable(floatX(sigma))
+        if sigma.ndim > 0:
+            raise ValueError("sigma has to be a scalar")
+
+        return super().dist([sigma], zerosum_axes=zerosum_axes, **kwargs)
+
+    # TODO: This is if we want ZeroSum constraint on other dists than Normal
+    # def dist(cls, dist, lower, upper, **kwargs):
+    #     if not isinstance(dist, TensorVariable) or not isinstance(
+    #         dist.owner.op, (RandomVariable, SymbolicRandomVariable)
+    #     ):
+    #         raise ValueError(
+    #             f"Censoring dist must be a distribution created via the `.dist()` API, got {type(dist)}"
+    #         )
+    #     if dist.owner.op.ndim_supp > 0:
+    #         raise NotImplementedError(
+    #             "Censoring of multivariate distributions has not been implemented yet"
+    #         )
+    #     check_dist_not_registered(dist)
+    #     return super().dist([dist, lower, upper], **kwargs)
+
+    @classmethod
+    def rv_op(cls, sigma, zerosum_axes, size=None):
+        if size is None:
+            zerosum_axes_ = np.asarray(zerosum_axes)
+            # just a placeholder size to infer minimum shape
+            size = np.ones(
+                max((max(np.abs(zerosum_axes_) - 1), max(zerosum_axes_))) + 1, dtype=int
+            ).tolist()
+
+        # check if zerosum_axes is valid
+        normalize_axis_tuple(zerosum_axes, len(size))
+
+        normal_dist = ignore_logprob(pm.Normal.dist(sigma=sigma, size=size))
+        normal_dist_, sigma_ = normal_dist.type(), sigma.type()
+
+        # Zero-summing is achieved by subtracting the mean along the given zerosum_axes
+        zerosum_rv_ = normal_dist_
+        for axis in zerosum_axes:
+            zerosum_rv_ -= zerosum_rv_.mean(axis=axis, keepdims=True)
+
+        return ZeroSumNormalRV(
+            inputs=[normal_dist_, sigma_],
+            outputs=[zerosum_rv_],
+            zerosum_axes=zerosum_axes,
+            ndim_supp=0,
+        )(normal_dist, sigma)
+
+
+@_change_dist_size.register(ZeroSumNormalRV)
+def change_zerosum_size(op, normal_dist, new_size, expand=False):
+    normal_dist, sigma = normal_dist.owner.inputs
+    if expand:
+        new_size = tuple(new_size) + tuple(normal_dist.shape)
+    return ZeroSumNormal.rv_op(sigma=sigma, zerosum_axes=op.zerosum_axes, size=new_size)
+
+
+@_moment.register(ZeroSumNormalRV)
+def zerosumnormal_moment(op, rv, *rv_inputs):
+    return at.zeros_like(rv)
+
+
+@_default_transform.register(ZeroSumNormalRV)
+def zerosum_default_transform(op, rv):
+    return ZeroSumTransform(op.zerosum_axes)
+
+
+@_logprob.register(ZeroSumNormalRV)
+def zerosumnormal_logp(op, values, normal_dist, sigma, **kwargs):
+    (value,) = values
+    shape = value.shape
+    _deg_free_shape = at.inc_subtensor(shape[at.as_tensor_variable(op.zerosum_axes)], -1)
+    _full_size = at.prod(shape)
+    _degrees_of_freedom = at.prod(_deg_free_shape)
+    zerosums = [
+        at.all(at.isclose(at.mean(value, axis=axis), 0, atol=1e-9)) for axis in op.zerosum_axes
+    ]
+    # out = at.sum(
+    #     pm.logp(dist, value) * _degrees_of_freedom / _full_size,
+    #     axis=op.zerosum_axes,
+    # )
+    # TODO: figure out how dimensionality should be handled for logp;
+    # for now, we assume ZSN is a scalar distribution, which is not correct
+    out = pm.logp(normal_dist, value) * _degrees_of_freedom / _full_size
+    return check_parameters(out, *zerosums, msg="at.mean(value, axis=zerosum_axes) == 0")
+
+
 class TruncatedNormalRV(RandomVariable):
     name = "truncated_normal"
     ndim_supp = 0
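For reviewers who want to try the branch, here is a minimal usage sketch based on the docstring above. The coordinate labels, the prior-predictive call, and the final check are illustrative additions, not part of the diff:

```python
import numpy as np
import pymc as pm

# Illustrative coordinates, mirroring the docstring example.
coords = {
    "regions": ["a", "b", "c"],
    "answers": ["yes", "no", "whatever"],
}

with pm.Model(coords=coords) as model:
    # Zero-sum constraint on the last axis ("answers"), which is the default.
    v = pm.ZeroSumNormal("v", sigma=1, dims=("regions", "answers"), zerosum_axes=[-1])
    prior = pm.sample_prior_predictive(samples=100)

draws = prior.prior["v"].values  # shape: (chain, draw, regions, answers)
# Every draw should sum to (numerically) zero along the constrained axis.
assert np.allclose(draws.sum(axis=-1), 0.0, atol=1e-8)
```

As `zerosumnormal_logp` above spells out, the density is currently a Normal logp rescaled by the ratio of constrained degrees of freedom to the full size (e.g. 3/4 for a single length-4 zero-sum axis), and `check_parameters` rejects values whose mean along `zerosum_axes` is not numerically zero.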
The second file in the diff, pymc/distributions/transforms.py, adds the ZeroSumTransform used above:
@@ -27,6 +27,10 @@
 from aesara.graph import Op
 from aesara.tensor import TensorVariable

+# ignore mypy error because it somehow considers that
+# "numpy.core.numeric has no attribute normalize_axis_tuple"
+from numpy.core.numeric import normalize_axis_tuple  # type: ignore
+
 __all__ = [
     "RVTransform",
     "simplex",
@@ -39,6 +43,7 @@
     "circular",
     "CholeskyCovPacked",
     "Chain",
+    "ZeroSumTransform",
 ]
@@ -266,6 +271,66 @@ def bounds_fn(*rv_inputs):
         super().__init__(args_fn=bounds_fn)


+class ZeroSumTransform(RVTransform):
+    """
+    Constrains the samples of a Normal distribution to sum to zero
+    along the user-provided ``zerosum_axes``.
+    By default (``zerosum_axes=[-1]``), the sum-to-zero constraint is imposed
+    on the last axis.
+    """
+
+    name = "zerosum"
+
+    __props__ = ("zerosum_axes",)
+
+    def __init__(self, zerosum_axes):
+        """
+        Parameters
+        ----------
+        zerosum_axes : list of ints
+            Must be a list of integers (positive or negative).
+            By default (``zerosum_axes=[-1]``), the sum-to-zero constraint is imposed
+            on the last axis.
+        """
+        self.zerosum_axes = zerosum_axes
+
+    def forward(self, value, *rv_inputs):
+        for axis in self.zerosum_axes:
+            value = extend_axis_rev(value, axis=axis)
+        return value
+
+    def backward(self, value, *rv_inputs):
+        for axis in self.zerosum_axes:
+            value = extend_axis(value, axis=axis)
+        return value
+
+    def log_jac_det(self, value, *rv_inputs):
+        return at.constant(0.0)
+
+
+def extend_axis(array, axis):
+    n = array.shape[axis] + 1
+    sum_vals = array.sum(axis, keepdims=True)
+    norm = sum_vals / (np.sqrt(n) + n)
+    fill_val = norm - sum_vals / np.sqrt(n)
+
+    out = at.concatenate([array, fill_val], axis=axis)
+    return out - norm
+
+
+def extend_axis_rev(array, axis):
+    normalized_axis = normalize_axis_tuple(axis, array.ndim)[0]
+
+    n = array.shape[normalized_axis]
+    last = at.take(array, [-1], axis=normalized_axis)
+
+    sum_vals = -last * np.sqrt(n)
+    norm = sum_vals / (np.sqrt(n) + n)
+    slice_before = (slice(None, None),) * normalized_axis
+
+    return array[slice_before + (slice(None, -1),)] + norm
+
+
 log_exp_m1 = LogExpM1()
 log_exp_m1.__doc__ = """
 Instantiation of :class:`pymc.distributions.transforms.LogExpM1`

Review thread on ``extend_axis``: "We could maybe add a comment here saying that this is using a Householder reflection plus a projection operator to move forward from the constrained space onto the zero-sum manifold. I'll look up our notes and write something here." Follow-up: "Did you find your notes @lucianopaz?"
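To make the review thread above concrete, here is a NumPy-only sketch of the two helpers (the `_np` names are hypothetical, and the axis is assumed non-negative). It shows that `extend_axis` maps n-1 unconstrained values to an n-vector that sums to zero, and that `extend_axis_rev` inverts that map:

```python
import numpy as np


def extend_axis_np(array, axis):
    # n-1 free values -> n values summing to zero
    # (Householder-style reflection plus projection, mirroring the diff).
    n = array.shape[axis] + 1
    sum_vals = array.sum(axis, keepdims=True)
    norm = sum_vals / (np.sqrt(n) + n)
    fill_val = norm - sum_vals / np.sqrt(n)
    out = np.concatenate([array, fill_val], axis=axis)
    return out - norm


def extend_axis_rev_np(array, axis):
    # Inverse map: recover the n-1 free values from a zero-sum vector.
    n = array.shape[axis]
    last = np.take(array, [-1], axis=axis)
    sum_vals = -last * np.sqrt(n)
    norm = sum_vals / (np.sqrt(n) + n)
    slice_before = (slice(None),) * axis  # axis assumed non-negative here
    return array[slice_before + (slice(None, -1),)] + norm


rng = np.random.default_rng(0)
free = rng.normal(size=(4, 6))               # 6 free values per row
constrained = extend_axis_np(free, axis=1)   # 7 values per row
assert np.allclose(constrained.sum(axis=1), 0.0)
assert np.allclose(extend_axis_rev_np(constrained, axis=1), free)
```

Because the map is built from a reflection plus a projection (an isometric embedding onto the zero-sum hyperplane), the volume change is constant, which is consistent with `log_jac_det` returning `at.constant(0.0)` in the diff.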