-
Notifications
You must be signed in to change notification settings - Fork 70
Make basic INLA interface and simple marginalisation routine #533
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Michal-Novomestsky
wants to merge
67
commits into
pymc-devs:main
Choose a base branch
from
Michal-Novomestsky:implement-pmx.fit-option-for-INLA-+-marginalisation-routine
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 54 commits
Commits
Show all changes
67 commits
Select commit
Hold shift + click to select a range
51afb9f
changed laplace approx to return MvNormal
Michal-Novomestsky c326525
added seperate line for evaluating Q-hess
Michal-Novomestsky 61d4d89
WIP: minor refactor
Michal-Novomestsky 1960cb9
started writing fit_INLA routine
Michal-Novomestsky 6a1d523
changed minimizer tol to 1e-8
Michal-Novomestsky 674d813
WIP: MarginalLaplaceRV
Michal-Novomestsky 3b5d49c
WIP: Minimize inside logp
Michal-Novomestsky 22d2ef1
tidied up MarginalLaplaceRV
Michal-Novomestsky c49de10
refactor: variable name change
Michal-Novomestsky 54e394d
jesse minimize testing
Michal-Novomestsky f02e652
end-to-end implementation
Michal-Novomestsky 9fb860e
refactor: changed boolean logic
Michal-Novomestsky 68b87ee
refactor: changed distributions in test case
Michal-Novomestsky de2d1fc
removed jesse's debug notebook
Michal-Novomestsky 787a39e
added WIP warning to pmx.fit
Michal-Novomestsky 71f8642
refactor: added TODO
Michal-Novomestsky 18747d5
refactor: re-ran notebook
Michal-Novomestsky c6010f3
refactor: temporarily changed gitignore
Michal-Novomestsky a473e87
refactor: rolled gitignore back to default
Michal-Novomestsky 31072ef
refactor: reworded list comprehension in log_likelihood
Michal-Novomestsky 6630675
refactor: uncommented import
Michal-Novomestsky b275e10
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] a92ba8f
removed legacy code
Michal-Novomestsky f077250
refactor: restored missing assert
Michal-Novomestsky 57a7935
refactor: changed test_inla.py location
Michal-Novomestsky 8b94a99
refactor: moved _precision_mv_normal_logp into pmx
Michal-Novomestsky bfa4e12
Merge branch 'main' into implement-pmx.fit-option-for-INLA-+-marginal…
Michal-Novomestsky dd54a37
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 34dfdfa
set d automatically
Michal-Novomestsky 8cb19d9
refactor: removed inccorect laplace.py and moved inla into seperate f…
Michal-Novomestsky 0ee1ec9
bugfix: laplace import/file location
Michal-Novomestsky b82c6a4
refactor: folder name change
Michal-Novomestsky c5f2bd8
bugfix: removed erroneous test case
Michal-Novomestsky 9d7342d
bugfix: typo in INLA
Michal-Novomestsky 12b109f
refactor: added more __init__s
Michal-Novomestsky 92f6a0f
removed temp_kwargs, made Q amenable to RVs, removed dependency on Mv…
Michal-Novomestsky 296ca39
removed checking for MvNormal
Michal-Novomestsky dccd9a6
error message reworded
Michal-Novomestsky d0aaae5
added comments explaining logp bottleneck
Michal-Novomestsky af61cf7
removed None default for minimizer_kwargs
Michal-Novomestsky 0779b6e
added docstring for _precision_mv_normal_logp
Michal-Novomestsky d7b198a
added more documentation
Michal-Novomestsky 2198465
added example 1 to example notebook
Michal-Novomestsky 0c4fcd5
refactor: default return_latent_posteriors to false
Michal-Novomestsky d031008
Merge branch 'pymc-devs:main' into implement-pmx.fit-option-for-INLA-…
Michal-Novomestsky e7ccfe2
refactor: moved sample step inside if-block
Michal-Novomestsky ece57b1
added docstring
Michal-Novomestsky a675b37
added latex to docstring
Michal-Novomestsky 3636e98
refactored unittest
Michal-Novomestsky 47b8dae
refactor: moved laplace approx into seperate function + more docstrings
Michal-Novomestsky fb39764
refactor: TensorLike typehint
Michal-Novomestsky 065c6b2
refactor: labelling of p(x|y,params)
Michal-Novomestsky 59b623d
refactor: text in example notebook
Michal-Novomestsky 6cea8ba
removed old INLA notebook
Michal-Novomestsky 609156e
refactor: local import
Michal-Novomestsky bc3f1c3
latex-friendly formatting
Michal-Novomestsky e032c25
getting Q as RV
Michal-Novomestsky e367958
updated inla docstring
Michal-Novomestsky 8ed64fd
added warning (INLA experimental)
Michal-Novomestsky 7ca496b
added AR1 testcase
Michal-Novomestsky 154cc2c
added normals to notebook
Michal-Novomestsky fda71d6
refactor: changed test case atol to 0.2
Michal-Novomestsky 04db7c3
refactor: add warning to d calculation
Michal-Novomestsky 934d740
refactor: warning message
Michal-Novomestsky b3a3351
set vectorized jac flag to true
Michal-Novomestsky 19bc44d
Merge branch 'main' into implement-pmx.fit-option-for-INLA-+-marginal…
Michal-Novomestsky 176ca6b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from pymc_extras.inference.INLA.inla import fit_INLA | ||
|
||
__all__ = ["fit_INLA"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import arviz as az | ||
import pymc as pm | ||
|
||
from pytensor.tensor import TensorLike, TensorVariable | ||
|
||
from pymc_extras.model.marginal.marginal_model import marginalize | ||
|
||
|
||
def fit_INLA( | ||
x: TensorVariable, | ||
Q: TensorLike, | ||
minimizer_seed: int = 42, | ||
model: pm.Model | None = None, | ||
minimizer_kwargs: dict = {"method": "L-BFGS-B", "optimizer_kwargs": {"tol": 1e-8}}, | ||
return_latent_posteriors: bool = False, | ||
**sampler_kwargs, | ||
) -> az.InferenceData: | ||
r""" | ||
Performs inference over a linear mixed model using Integrated Nested Laplace Approximations (INLA). Assumes a model of the form: | ||
|
||
\begin{equation} | ||
\theta \rightarrow x \rightarrow y | ||
\end{equation} | ||
|
||
Where the prior on the hyperparameters $\pi(\theta)$ is arbitrary, the prior on the latent field is Gaussian (and in precision form): $\pi(x) = N(\mu, Q^{-1})$ and the latent field is linked to the observables $y$ through some linear map. | ||
|
||
As it stands, INLA in PyMC Extras has three main limitations: | ||
|
||
- Does not support inference over the latent field, only the hyperparameters. | ||
- Optimisation for $\mu^*$ is bottlenecked by calling `minimize`, and to a lesser extent, computing the hessian $f^"(x)$. | ||
- Does not offer sparse support which can provide significant speedups. | ||
Michal-Novomestsky marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
Michal-Novomestsky marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Parameters | ||
---------- | ||
x: TensorVariable | ||
The latent gaussian to marginalize out. | ||
Q: TensorLike | ||
Precision matrix of the latent field. | ||
minimizer_seed: int | ||
Seed for random initialisation of the minimum point x*. | ||
model: pm.Model | ||
PyMC model. | ||
minimizer_kwargs: | ||
Kwargs to pass to pytensor.optimize.minimize during the optimization step maximizing logp(x | y, params). | ||
returned_latent_posteriors: | ||
If True, also return posteriors for the latent Gaussian field (currently unsupported). | ||
sampler_kwargs: | ||
Kwargs to pass to pm.sample. | ||
""" | ||
model = pm.modelcontext(model) | ||
Michal-Novomestsky marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# TODO is there a better way to check if it's a RV? | ||
# print(vars(Q.owner)) | ||
# if isinstance(Q, TensorVariable) and "module" in vars(Q.owner): | ||
Q = model.rvs_to_values[Q] if isinstance(Q, TensorVariable) else Q | ||
Michal-Novomestsky marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
# Marginalize out the latent field | ||
marginalize_kwargs = { | ||
"Q": Q, | ||
"minimizer_seed": minimizer_seed, | ||
"minimizer_kwargs": minimizer_kwargs, | ||
} | ||
marginal_model = marginalize(model, x, use_laplace=True, **marginalize_kwargs) | ||
|
||
# Sample over the hyperparameters | ||
if not return_latent_posteriors: | ||
idata = pm.sample(model=marginal_model, **sampler_kwargs) | ||
return idata | ||
|
||
# Unmarginalize stuff | ||
raise NotImplementedError( | ||
"Inference over the latent field with INLA is currently unsupported. Set return_latent_posteriors to False" | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from pymc_extras.inference.laplace_approx.laplace import fit_laplace | ||
|
||
__all__ = ["fit_laplace"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.