Commits
67 commits
51afb9f
changed laplace approx to return MvNormal
Michal-Novomestsky Jul 2, 2025
c326525
added separate line for evaluating Q-hess
Michal-Novomestsky Jul 2, 2025
61d4d89
WIP: minor refactor
Michal-Novomestsky Jul 4, 2025
1960cb9
started writing fit_INLA routine
Michal-Novomestsky Jul 6, 2025
6a1d523
changed minimizer tol to 1e-8
Michal-Novomestsky Jul 6, 2025
674d813
WIP: MarginalLaplaceRV
Michal-Novomestsky Jul 16, 2025
3b5d49c
WIP: Minimize inside logp
Michal-Novomestsky Jul 19, 2025
22d2ef1
tidied up MarginalLaplaceRV
Michal-Novomestsky Aug 9, 2025
c49de10
refactor: variable name change
Michal-Novomestsky Aug 9, 2025
54e394d
jesse minimize testing
Michal-Novomestsky Aug 10, 2025
f02e652
end-to-end implementation
Michal-Novomestsky Aug 11, 2025
9fb860e
refactor: changed boolean logic
Michal-Novomestsky Aug 11, 2025
68b87ee
refactor: changed distributions in test case
Michal-Novomestsky Aug 12, 2025
de2d1fc
removed jesse's debug notebook
Michal-Novomestsky Aug 12, 2025
787a39e
added WIP warning to pmx.fit
Michal-Novomestsky Aug 12, 2025
71f8642
refactor: added TODO
Michal-Novomestsky Aug 12, 2025
18747d5
refactor: re-ran notebook
Michal-Novomestsky Aug 12, 2025
c6010f3
refactor: temporarily changed gitignore
Michal-Novomestsky Aug 12, 2025
a473e87
refactor: rolled gitignore back to default
Michal-Novomestsky Aug 12, 2025
31072ef
refactor: reworded list comprehension in log_likelihood
Michal-Novomestsky Aug 12, 2025
6630675
refactor: uncommented import
Michal-Novomestsky Aug 12, 2025
b275e10
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 13, 2025
a92ba8f
removed legacy code
Michal-Novomestsky Aug 13, 2025
f077250
refactor: restored missing assert
Michal-Novomestsky Aug 13, 2025
57a7935
refactor: changed test_inla.py location
Michal-Novomestsky Aug 13, 2025
8b94a99
refactor: moved _precision_mv_normal_logp into pmx
Michal-Novomestsky Aug 16, 2025
bfa4e12
Merge branch 'main' into implement-pmx.fit-option-for-INLA-+-marginal…
Michal-Novomestsky Aug 16, 2025
dd54a37
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 16, 2025
34dfdfa
set d automatically
Michal-Novomestsky Aug 17, 2025
8cb19d9
refactor: removed incorrect laplace.py and moved inla into separate f…
Michal-Novomestsky Aug 17, 2025
0ee1ec9
bugfix: laplace import/file location
Michal-Novomestsky Aug 17, 2025
b82c6a4
refactor: folder name change
Michal-Novomestsky Aug 17, 2025
c5f2bd8
bugfix: removed erroneous test case
Michal-Novomestsky Aug 17, 2025
9d7342d
bugfix: typo in INLA
Michal-Novomestsky Aug 17, 2025
12b109f
refactor: added more __init__s
Michal-Novomestsky Aug 17, 2025
92f6a0f
removed temp_kwargs, made Q amenable to RVs, removed dependency on Mv…
Michal-Novomestsky Aug 26, 2025
296ca39
removed checking for MvNormal
Michal-Novomestsky Aug 26, 2025
dccd9a6
error message reworded
Michal-Novomestsky Aug 26, 2025
d0aaae5
added comments explaining logp bottleneck
Michal-Novomestsky Aug 26, 2025
af61cf7
removed None default for minimizer_kwargs
Michal-Novomestsky Aug 26, 2025
0779b6e
added docstring for _precision_mv_normal_logp
Michal-Novomestsky Aug 26, 2025
d7b198a
added more documentation
Michal-Novomestsky Aug 26, 2025
2198465
added example 1 to example notebook
Michal-Novomestsky Aug 27, 2025
0c4fcd5
refactor: default return_latent_posteriors to false
Michal-Novomestsky Aug 27, 2025
d031008
Merge branch 'pymc-devs:main' into implement-pmx.fit-option-for-INLA-…
Michal-Novomestsky Aug 27, 2025
e7ccfe2
refactor: moved sample step inside if-block
Michal-Novomestsky Aug 27, 2025
ece57b1
added docstring
Michal-Novomestsky Aug 27, 2025
a675b37
added latex to docstring
Michal-Novomestsky Aug 27, 2025
3636e98
refactored unittest
Michal-Novomestsky Aug 27, 2025
47b8dae
refactor: moved laplace approx into separate function + more docstrings
Michal-Novomestsky Aug 27, 2025
fb39764
refactor: TensorLike typehint
Michal-Novomestsky Aug 27, 2025
065c6b2
refactor: labelling of p(x|y,params)
Michal-Novomestsky Aug 27, 2025
59b623d
refactor: text in example notebook
Michal-Novomestsky Aug 27, 2025
6cea8ba
removed old INLA notebook
Michal-Novomestsky Aug 27, 2025
609156e
refactor: local import
Michal-Novomestsky Aug 27, 2025
bc3f1c3
latex-friendly formatting
Michal-Novomestsky Aug 28, 2025
e032c25
getting Q as RV
Michal-Novomestsky Aug 28, 2025
e367958
updated inla docstring
Michal-Novomestsky Aug 28, 2025
8ed64fd
added warning (INLA experimental)
Michal-Novomestsky Aug 28, 2025
7ca496b
added AR1 testcase
Michal-Novomestsky Aug 28, 2025
154cc2c
added normals to notebook
Michal-Novomestsky Aug 28, 2025
fda71d6
refactor: changed test case atol to 0.2
Michal-Novomestsky Aug 28, 2025
04db7c3
refactor: add warning to d calculation
Michal-Novomestsky Aug 28, 2025
934d740
refactor: warning message
Michal-Novomestsky Aug 28, 2025
b3a3351
set vectorized jac flag to true
Michal-Novomestsky Sep 21, 2025
19bc44d
Merge branch 'main' into implement-pmx.fit-option-for-INLA-+-marginal…
Michal-Novomestsky Sep 21, 2025
176ca6b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 21, 2025
Files changed
511 changes: 511 additions & 0 deletions notebooks/INLA Example.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pymc_extras/inference/INLA/__init__.py
@@ -0,0 +1,3 @@
from pymc_extras.inference.INLA.inla import fit_INLA

__all__ = ["fit_INLA"]
73 changes: 73 additions & 0 deletions pymc_extras/inference/INLA/inla.py
@@ -0,0 +1,73 @@
import arviz as az
import pymc as pm

from pytensor.tensor import TensorLike, TensorVariable

from pymc_extras.model.marginal.marginal_model import marginalize


def fit_INLA(
x: TensorVariable,
Q: TensorLike,
minimizer_seed: int = 42,
model: pm.Model | None = None,
minimizer_kwargs: dict = {"method": "L-BFGS-B", "optimizer_kwargs": {"tol": 1e-8}},
return_latent_posteriors: bool = False,
**sampler_kwargs,
) -> az.InferenceData:
r"""
Performs inference over a linear mixed model using Integrated Nested Laplace Approximations (INLA). Assumes a model of the form:

\begin{equation}
\theta \rightarrow x \rightarrow y
\end{equation}

where the prior on the hyperparameters $\pi(\theta)$ is arbitrary, the prior on the latent field is Gaussian and expressed in precision form, $\pi(x) = N(\mu, Q^{-1})$, and the latent field is linked to the observables $y$ through some linear map.

As it stands, INLA in PyMC Extras has three main limitations:

- Does not support inference over the latent field, only over the hyperparameters.
- Optimisation for the mode $\mu^*$ is bottlenecked by the call to `minimize`, and to a lesser extent by computing the Hessian $f''(x)$.
- Does not offer sparse support, which could provide significant speedups.

Parameters
----------
x: TensorVariable
The latent Gaussian field to marginalize out.
Q: TensorLike
Precision matrix of the latent field.
minimizer_seed: int
Seed for the random initialisation of the minimizer used to find the mode x*.
model: pm.Model
PyMC model.
minimizer_kwargs: dict
Kwargs to pass to pytensor.tensor.optimize.minimize during the optimization step that maximizes logp(x | y, params).
return_latent_posteriors: bool
If True, also return posteriors for the latent Gaussian field (currently unsupported).
sampler_kwargs:
Kwargs to pass to pm.sample.
"""
model = pm.modelcontext(model)

# TODO: is there a better way to check whether Q is an RV?
# print(vars(Q.owner))
# if isinstance(Q, TensorVariable) and "module" in vars(Q.owner):
Q = model.rvs_to_values[Q] if isinstance(Q, TensorVariable) else Q

# Marginalize out the latent field
marginalize_kwargs = {
"Q": Q,
"minimizer_seed": minimizer_seed,
"minimizer_kwargs": minimizer_kwargs,
}
marginal_model = marginalize(model, x, use_laplace=True, **marginalize_kwargs)

# Sample over the hyperparameters
if not return_latent_posteriors:
idata = pm.sample(model=marginal_model, **sampler_kwargs)
return idata

# Recovering posteriors over the latent field ('unmarginalizing') is not yet implemented
raise NotImplementedError(
"Inference over the latent field with INLA is currently unsupported. Set return_latent_posteriors to False"
)
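To make the intended call pattern concrete, here is a minimal usage sketch for `fit_INLA`. The toy model below (a two-dimensional latent field with a fixed precision matrix `Q` and a single noise hyperparameter `sigma`) is an assumption made purely for illustration; only the `fit_INLA` signature comes from this PR.

```python
import numpy as np
import pymc as pm

from pymc_extras.inference.INLA import fit_INLA

rng = np.random.default_rng(0)
y_obs = rng.normal(loc=1.0, scale=0.5, size=50)

with pm.Model() as model:
    # Hyperparameter theta with an arbitrary prior
    sigma = pm.Exponential("sigma", 1.0)

    # Latent Gaussian field in precision form: x ~ N(0, Q^-1) (toy prior, assumed)
    Q = 2.0 * np.eye(2)
    x = pm.MvNormal("x", mu=np.zeros(2), tau=Q)

    # Observables linked to the latent field through a linear map
    pm.Normal("y", mu=x.sum(), sigma=sigma, observed=y_obs)

    # Marginalize x out via the Laplace approximation, then sample the hyperparameters
    idata = fit_INLA(x=x, Q=Q, return_latent_posteriors=False, draws=1000)
```

Any extra keyword such as `draws` is forwarded untouched to `pm.sample` through `**sampler_kwargs`.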
3 changes: 2 additions & 1 deletion pymc_extras/inference/__init__.py
@@ -13,8 +13,9 @@
# limitations under the License.

from pymc_extras.inference.fit import fit
from pymc_extras.inference.INLA.inla import fit_INLA
from pymc_extras.inference.laplace_approx.find_map import find_MAP
from pymc_extras.inference.laplace_approx.laplace import fit_laplace
from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder

__all__ = ["find_MAP", "fit", "fit_laplace", "fit_pathfinder"]
__all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP", "fit_INLA"]
14 changes: 12 additions & 2 deletions pymc_extras/inference/fit.py
@@ -36,7 +36,17 @@ def fit(method: str, **kwargs) -> az.InferenceData:

return fit_pathfinder(**kwargs)

if method == "laplace":
from pymc_extras.inference import fit_laplace
elif method == "laplace":
from pymc_extras.inference.laplace_approx import fit_laplace

return fit_laplace(**kwargs)

elif method == "INLA":
from pymc_extras.inference.INLA import fit_INLA

return fit_INLA(**kwargs)

else:
raise ValueError(
f"method '{method}' not supported. Use one of 'pathfinder', 'laplace' or 'INLA'."
)
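The dispatch change above only affects routing, so a sketch of the error path needs no model at all; it simply confirms that an unsupported method string now hits the new `ValueError` ('advi' is an arbitrary unsupported name chosen for illustration):

```python
from pymc_extras.inference.fit import fit

try:
    fit(method="advi")
except ValueError as err:
    # Prints: method 'advi' not supported. Use one of 'pathfinder', 'laplace' or 'INLA'.
    print(err)
```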
3 changes: 3 additions & 0 deletions pymc_extras/inference/laplace_approx/__init__.py
@@ -0,0 +1,3 @@
from pymc_extras.inference.laplace_approx.laplace import fit_laplace

__all__ = ["fit_laplace"]
101 changes: 0 additions & 101 deletions pymc_extras/inference/laplace_approx/laplace.py
@@ -15,27 +15,22 @@

import logging

from collections.abc import Callable
from functools import partial
from typing import Literal
from typing import cast as type_cast

import arviz as az
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt
import xarray as xr

from better_optimize.constants import minimize_method
from numpy.typing import ArrayLike
from pymc.blocking import DictToArrayBijection
from pymc.model.transform.optimization import freeze_dims_and_data
from pymc.pytensorf import join_nonshared_inputs
from pymc.util import get_default_varnames
from pytensor.graph import vectorize_graph
from pytensor.tensor import TensorVariable
from pytensor.tensor.optimize import minimize
from pytensor.tensor.type import Variable

from pymc_extras.inference.laplace_approx.find_map import (
@@ -51,102 +51,6 @@
_log = logging.getLogger(__name__)


def get_conditional_gaussian_approximation(
x: TensorVariable,
Q: TensorVariable | ArrayLike,
mu: TensorVariable | ArrayLike,
args: list[TensorVariable] | None = None,
model: pm.Model | None = None,
method: minimize_method = "BFGS",
use_jac: bool = True,
use_hess: bool = False,
optimizer_kwargs: dict | None = None,
) -> Callable:
"""
Returns a function to estimate the a posteriori log probability of a latent Gaussian field x and its mode x0 using the Laplace approximation.

That is:
y | x, sigma ~ N(Ax, sigma^2 W)
x | params ~ N(mu, Q(params)^-1)

We seek to estimate log(p(x | y, params)):

log(p(x | y, params)) = log(p(y | x, params)) + log(p(x | params)) + const

Let f(x) = log(p(y | x, params)). From the definition of our model above, we have log(p(x | params)) = -0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q).

This gives log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q). We will estimate this using the Laplace approximation by Taylor expanding f(x) about the mode.

Thus:

1. Maximize log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) wrt x (note that logdet(Q) does not depend on x) to find the mode x0.

2. Substitute x0 into the Laplace approximation expanded about the mode: log(p(x | y, params)) ~= -0.5*x.T (-f''(x0) + Q) x + x.T (Q.mu + f'(x0) - f''(x0).x0) + 0.5*logdet(Q).

Parameters
----------
x: TensorVariable
The parameter to maximize with respect to (that is, the variable in which the mode is found). In INLA, this is the latent field x~N(mu,Q^-1).
Q: TensorVariable | ArrayLike
The precision matrix of the latent field x.
mu: TensorVariable | ArrayLike
The mean of the latent field x.
args: list[TensorVariable]
Args to supply to the compiled function. That is, (x0, logp) = f(x, *args). If set to None, assumes the model RVs are args.
model: Model
PyMC model to use.
method: minimize_method
Which minimization algorithm to use.
use_jac: bool
If true, the minimizer will compute the gradient of log(p(x | y, params)).
use_hess: bool
If true, the minimizer will compute the Hessian of log(p(x | y, params)).
optimizer_kwargs: dict
Kwargs to pass to scipy.optimize.minimize.

Returns
-------
f: Callable
A function which accepts a value of x and args and returns [x0, log(p(x | y, params))], where x0 is the mode. x is currently both the point at which to evaluate logp and the initial guess for the minimizer.
"""
model = pm.modelcontext(model)

if args is None:
args = model.continuous_value_vars + model.discrete_value_vars

# f = log(p(y | x, params))
f_x = model.logp()
jac = pytensor.gradient.grad(f_x, x)
hess = pytensor.gradient.jacobian(jac.flatten(), x)

# log(p(x | y, params)) only including terms that depend on x for the minimization step (logdet(Q) ignored as it is a constant wrt x)
log_x_posterior = f_x - 0.5 * (x - mu).T @ Q @ (x - mu)

# Maximize log(p(x | y, params)) wrt x to find mode x0
x0, _ = minimize(
objective=-log_x_posterior,
x=x,
method=method,
jac=use_jac,
hess=use_hess,
optimizer_kwargs=optimizer_kwargs,
)

# require f'(x0) and f''(x0) for Laplace approx
jac = pytensor.graph.replace.graph_replace(jac, {x: x0})
hess = pytensor.graph.replace.graph_replace(hess, {x: x0})

# Full log(p(x | y, params)) using the Laplace approximation (up to a constant)
_, logdetQ = pt.nlinalg.slogdet(Q)
conditional_gaussian_approx = (
-0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ
)

# Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial guess for x0. This may cause issues if the query point is
# far from the mode x0 or in a neighbourhood which results in poor convergence.
return pytensor.function(args, [x0, conditional_gaussian_approx])
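The docstring of the removed helper above spells out the two-step Laplace recipe (maximise to find the mode, then expand about it) that the marginalisation machinery now relies on. The following is a hedged numerical sketch of those two steps in plain NumPy/SciPy rather than PyTensor; the Gaussian observation model, the toy dimensions, and the closed-form derivatives are assumptions for illustration only.

```python
import numpy as np

from scipy.optimize import minimize

rng = np.random.default_rng(42)

# Toy setup: y | x ~ N(Ax, sigma^2 I), x ~ N(mu, Q^-1)
n_obs, n_latent = 20, 3
A = rng.normal(size=(n_obs, n_latent))
sigma = 0.5
mu = np.zeros(n_latent)
Q = 2.0 * np.eye(n_latent)
y = A @ rng.normal(size=n_latent) + sigma * rng.normal(size=n_obs)


def f(x):
    """f(x) = log p(y | x) up to an additive constant."""
    r = y - A @ x
    return -0.5 * (r @ r) / sigma**2


def neg_objective(x):
    """Negative of the x-dependent part of log p(x | y, params); logdet(Q) is constant in x."""
    d = x - mu
    return -(f(x) - 0.5 * d @ Q @ d)


# Step 1: maximise log p(x | y, params) wrt x to find the mode x0
x0 = minimize(neg_objective, x0=np.zeros(n_latent), method="BFGS").x

# Step 2: Taylor-expand f about the mode; for this Gaussian likelihood the derivatives are closed-form
jac = A.T @ (y - A @ x0) / sigma**2  # f'(x0)
hess = -(A.T @ A) / sigma**2  # f''(x0)
_, logdet_Q = np.linalg.slogdet(Q)


def laplace_log_posterior(x):
    """log p(x | y, params) under the Laplace approximation, up to a constant."""
    return -0.5 * x @ (-hess + Q) @ x + x @ (Q @ mu + jac - hess @ x0) + 0.5 * logdet_Q


print("mode:", x0)
print("approximate log p(x0 | y, params):", laplace_log_posterior(x0))
```

Because the likelihood here is Gaussian, `jac` and `hess` are exact; in the general case they come from automatic differentiation of `f`, as in the PyTensor version.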


def _unconstrained_vector_to_constrained_rvs(model):
outputs = get_default_varnames(model.unobserved_value_vars, include_transformed=True)
constrained_names = [