Added ZeroSumNormal Distribution #4776
Changes from all commits
@@ -65,6 +65,7 @@
     "Lognormal",
     "ChiSquared",
     "HalfNormal",
+    "ZeroSumNormal",
     "Wald",
     "Pareto",
     "InverseGamma",

@@ -924,6 +925,73 @@ def logcdf(self, value):
        )

class ZeroSumNormal(Continuous):
    def __new__(cls, name, *args, **kwargs):
        zerosum_axes = kwargs.get("zerosum_axes", None)
        zerosum_dims = kwargs.get("zerosum_dims", None)
        dims = kwargs.get("dims", None)

        if isinstance(zerosum_dims, str):
            zerosum_dims = (zerosum_dims,)
        if isinstance(dims, str):
            dims = (dims,)

        if zerosum_dims is not None:
            if dims is None:
                raise ValueError("zerosum_dims can only be used with the dims kwargs.")
            if zerosum_axes is not None:
                raise ValueError("Only one of zerosum_axes and zerosum_dims can be specified.")
            zerosum_axes = []
            for dim in zerosum_dims:
                zerosum_axes.append(dims.index(dim))
            kwargs["zerosum_axes"] = zerosum_axes

        return super().__new__(cls, name, *args, **kwargs)

    def __init__(self, sigma=1, zerosum_axes=None, zerosum_dims=None, **kwargs):
        shape = kwargs.get("shape", ())
        if isinstance(shape, int):
            shape = (shape,)

        self.mu = self.median = self.mode = tt.zeros(shape)
        self.sigma = tt.as_tensor_variable(sigma)

        if zerosum_axes is None:
            if shape:
                zerosum_axes = (-1,)
            else:
                zerosum_axes = ()

Review comment: I think it makes no sense to have a […]

        if isinstance(zerosum_axes, int):
            zerosum_axes = (zerosum_axes,)

        self.zerosum_axes = [a if a >= 0 else len(shape) + a for a in zerosum_axes]

Review comment: Enforcing positive axes here leads to problems when you draw samples from the prior predictive. It's better to replace this line with this suggested change: […]

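Purely as a hypothetical illustration of that point (the reviewer's actual suggested change is not visible above): prior-predictive sampling prepends extra sample dimensions on the left, so positive axis indices stop pointing at the constrained axes, while negative ones keep working. One assumed way to normalize to negative axes instead could be:

# Assumed alternative, not the reviewer's elided suggestion: keep every zero-sum
# axis as a negative index so it still refers to the same trailing axis after
# sample dimensions are prepended on the left.
self.zerosum_axes = [a if a < 0 else a - len(shape) for a in zerosum_axes]
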
if "transform" not in kwargs or kwargs["transform"] is None: | ||||||||||||||||||
kwargs["transform"] = transforms.ZeroSumTransform(zerosum_axes) | ||||||||||||||||||
|
||||||||||||||||||
super().__init__(**kwargs) | ||||||||||||||||||
|
||||||||||||||||||
def logp(self, value): | ||||||||||||||||||
return Normal.dist(sigma=self.sigma).logp(value) | ||||||||||||||||||

Suggested change for the logp body above:

-        return Normal.dist(sigma=self.sigma).logp(value)
+        zerosums = [tt.all(tt.abs_(tt.mean(value, axis=axis)) <= 1e-9) for axis in self.zerosum_axes]
+        return bound(
+            Normal.dist(sigma=self.sigma).logp(value),
+            tt.all(self.sigma > 0),
+            broadcast_conditions=False,
+            *zerosums,
+        )
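
Before the logp discussion below, a minimal usage sketch of the API this diff introduces, assuming the PyMC3 3.x model/coords interface; the coords, variable names and priors are made up for illustration:

import pymc3 as pm

coords = {"group": ["a", "b", "c", "d"]}

with pm.Model(coords=coords) as model:
    intercept = pm.Normal("intercept", 0.0, 5.0)
    # One offset per group, constrained to sum to zero along the "group" axis,
    # which removes the additive non-identifiability with the intercept.
    group_offset = pm.ZeroSumNormal("group_offset", sigma=1.0, dims="group", zerosum_dims="group")
    mu = intercept + group_offset
    # ... likelihood using mu goes here ...

With a single entry in dims, zerosum_dims="group" matches the default behaviour (the last axis sums to zero), so the keyword mainly matters when there are several dims and only some of them should be constrained.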
Review comment: I came across this wiki section that talks about the degenerate MvNormal case (which is what we have with the ZeroSumNormal). We could use that formula as the expected logp value and test whether the logp we are using in the distribution matches it. The expected logp would look something like this:

def pseudo_log_det(A, tol=1e-13):
    v, w = np.linalg.eigh(A)
    return np.sum(np.log(np.where(np.abs(v) >= tol, v, 1)), axis=-1)

def logp(value, sigma):
    n = value.shape[-1]
    cov = np.asarray(sigma)[..., None, None] ** 2 * (np.eye(n) - np.ones((n, n)) / n)
    psdet = 0.5 * pseudo_log_det(2 * np.pi * cov)
    exp = 0.5 * (value[..., None, :] @ np.linalg.pinv(cov) @ value[..., None])[..., 0, 0]
    return np.where(np.abs(np.sum(value, axis=-1)) < 1e-9, -psdet - exp, -np.inf)
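
A minimal sketch of the check this comment proposes, in plain NumPy/SciPy rather than against the PR code itself; the function name and the test values are made up, and a zero-sum draw is obtained by simply centering a Normal draw:

import numpy as np
from scipy import stats

def degenerate_mvnormal_logp(value, sigma):
    # Same density as in the comment above: covariance sigma**2 * (I - J/n),
    # pseudo-determinant over the nonzero eigenvalues, pseudo-inverse in the
    # quadratic form; only evaluated on the zero-sum hyperplane here.
    n = value.shape[-1]
    cov = sigma**2 * (np.eye(n) - np.ones((n, n)) / n)
    v = np.linalg.eigvalsh(cov)
    psdet = 0.5 * (np.sum(np.log(v[1:])) + (n - 1) * np.log(2 * np.pi))
    quad = 0.5 * value @ np.linalg.pinv(cov) @ value
    return -psdet - quad

rng = np.random.default_rng(0)
n, sigma = 5, 1.3

# A value on the constraint surface: center a Normal draw so it sums to zero.
x = rng.normal(scale=sigma, size=n)
x -= x.mean()

# logp currently used in the PR: sum of independent Normal(0, sigma) log-densities.
naive = stats.norm(0, sigma).logpdf(x).sum()
expected = degenerate_mvnormal_logp(x, sigma)

# For a single zero-sum axis the two differ by the constant 0.5*log(2*pi) + log(sigma),
# which is exactly the normalization mismatch analysed in the next comment.
print(naive - expected, -(0.5 * np.log(2 * np.pi) + np.log(sigma)))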
Review comment: I ran a few tests with the logp, and it looks like the logp we are using in this PR doesn't match what one would expect from a degenerate multivariate normal distribution. In my comment above, I posted what a degenerate MvNormal logp looks like. For this particular problem, where we know that we have only one eigenvector with zero eigenvalue, we can rewrite the logp as:

def logp(value, sigma):
    n = value.shape[-1]
    cov = np.asarray(sigma)[..., None, None] ** 2 * (np.eye(n) - np.ones((n, n)) / n)
    v, w = np.linalg.eigh(cov)
    psdet = 0.5 * (np.sum(np.log(v[..., 1:])) + (n - 1) * np.log(2 * np.pi))
    cov_pinv = w[:, 1:] @ np.diag(1 / v[1:]) @ w[:, 1:].T
    exp = 0.5 * (value[..., None, :] @ cov_pinv @ value[..., None])[..., 0, 0]
    return np.where(np.abs(np.sum(value, axis=-1)) < 1e-9, -psdet - exp, -np.inf)

This is different from the logp we are currently using in this PR. The difference is in the normalization constant psdet = 0.5 * (np.sum(np.log(v[..., 1:])) + (n - 1) * np.log(2 * np.pi)). In particular, since all eigenvalues v except the first one are the same and equal to sigma**2, psdet = (n - 1) * (0.5 * np.log(2 * np.pi) + np.log(sigma)), whereas with the current Normal.dist(sigma=self.sigma).logp(value) the normalization factor we are getting is psdet = n * (0.5 * np.log(2 * np.pi) + np.log(sigma)). This means that we have to multiply the logp that we are using by (n - 1)/n (in the case where only one axis sums to zero) to get the correct log probability density. I'll check what happens when more than one axis has to sum to zero.
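
A quick numeric sanity check of the eigenvalue and normalization-constant claims above, with made-up values for n and sigma:

import numpy as np

n, sigma = 5, 1.3
cov = sigma**2 * (np.eye(n) - np.ones((n, n)) / n)

# One (numerically) zero eigenvalue; the remaining n - 1 are all equal to sigma**2.
print(np.round(np.linalg.eigvalsh(cov), 6))  # approx. [0, 1.69, 1.69, 1.69, 1.69]

# Normalization constant of the degenerate MvNormal vs. the one implied by
# summing n independent Normal(0, sigma) log-densities.
psdet_degenerate = (n - 1) * (0.5 * np.log(2 * np.pi) + np.log(sigma))
psdet_naive = n * (0.5 * np.log(2 * np.pi) + np.log(sigma))
print(psdet_naive - psdet_degenerate)  # excess of 0.5*log(2*pi) + log(sigma) per zero-sum axis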
Review comment: Note that zerosum_dims is not used in __init__, but if I don't put it in the signature here, it doesn't seem to be passed on to __new__: TypeError: __init__() got an unexpected keyword argument 'zerosum_dims'. Not sure we can do it otherwise though. If someone has a better idea, I'm all ears.
Review comment: I think the zerosum_dims is probably still in kwargs from line 949? We could just remove it there.
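
For the record, a toy sketch of the kwargs flow being discussed; FakeDistribution below is a deliberately simplified stand-in for the pymc3 Distribution machinery (which forwards kwargs to __init__ itself via cls.dist and returns a model variable rather than relying on type.__call__), so the names and structure here are assumptions, not the real code:

class FakeDistribution:
    # Simplified stand-in: __new__ forwards its local kwargs to __init__ via
    # dist(), and returns a non-instance so Python does not call __init__ again.
    @classmethod
    def dist(cls, *args, **kwargs):
        instance = object.__new__(cls)
        instance.__init__(*args, **kwargs)
        return instance

    def __new__(cls, name, *args, **kwargs):
        return (name, cls.dist(*args, **kwargs))


class ZeroSumNormalSketch(FakeDistribution):
    def __new__(cls, name, *args, **kwargs):
        # Popping here removes zerosum_dims from the kwargs that the base
        # __new__ forwards on, so __init__ no longer has to accept it.
        zerosum_dims = kwargs.pop("zerosum_dims", None)
        if zerosum_dims is not None:
            kwargs["zerosum_axes"] = (-1,)  # placeholder for the dims -> axes translation
        return super().__new__(cls, name, *args, **kwargs)

    def __init__(self, sigma=1, zerosum_axes=None):
        self.sigma = sigma
        self.zerosum_axes = zerosum_axes


# With kwargs.get(...) instead of kwargs.pop(...) in __new__ above, this call would
# raise TypeError: __init__() got an unexpected keyword argument 'zerosum_dims'.
rv = ZeroSumNormalSketch("effects", sigma=1.0, zerosum_dims="group")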