Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions colibri/analytic_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,15 @@ def analytic_evidence_uniform_prior(sol_covmat, sol_mean, max_logl, a_vec, b_vec
return log_evidence, log_occam_factor


@check_pdf_model_is_linear
def analytic_fit(
central_covmat_index,
_pred_data,
forward_map,
pdf_model,
analytic_settings,
prior_settings,
FIT_XGRID,
fast_kernel_arrays,
data,
):
"""
Analytic fits, for any *linear* PDF model.
Expand All @@ -106,8 +106,8 @@ def analytic_fit(
central_covmat_index: commondata_utils.CentralCovmatIndex
dataclass containing central values and covariance matrix.

_pred_data: @jax.jit CompiledFunction
Prediction function for the fit.
forward_map: @jax.jit CompiledFunction
Forward map function for the fit.

pdf_model: pdf_model.PDFModel
PDF model to fit.
Expand All @@ -124,22 +124,28 @@ def analytic_fit(

fast_kernel_arrays: tuple
Tuple containing the fast kernel arrays.

data: validphys.core.DataGroupSpec
The data group specification for the fit.
"""
# Ensure that the PDF model is linear before running the fit.
log.info("Checking that the PDF model is linear...")
check_pdf_model_is_linear(pdf_model, forward_map, FIT_XGRID, data)

log.warning("The prior is assumed to be flat in the parameters.")
log.warning(
"Assuming that the prior is wide enough to fully cover the gaussian likelihood."
)

parameters = pdf_model.param_names
pred_and_pdf = pdf_model.pred_and_pdf_func(FIT_XGRID, forward_map=_pred_data)

# Precompute predictions for the basis of the model
bases = jnp.identity(len(parameters))
pdf_grid = pdf_model.grid_values_func(FIT_XGRID)
predictions = jnp.array(
[pred_and_pdf(basis, fast_kernel_arrays)[0] for basis in bases]
[forward_map(pdf_grid, fast_kernel_arrays, basis)[0] for basis in bases]
)
intercept = pred_and_pdf(jnp.zeros(len(parameters)), fast_kernel_arrays)[0]
intercept = forward_map(pdf_grid, fast_kernel_arrays, jnp.zeros(len(parameters)))[0]

# Construct the analytic solution
central_values = central_covmat_index.central_values
Expand Down
1 change: 1 addition & 0 deletions colibri/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"colibri.param_initialisation",
"colibri.export_results",
"colibri.closure_test",
"colibri.forward_map",
"reportengine.report",
]

Expand Down
8 changes: 3 additions & 5 deletions colibri/bayes_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@
cast_to_numpy,
get_full_posterior,
)
from colibri.checks import check_pdf_models_equal
from colibri.core import BayesianPrior
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions


@check_pdf_models_equal
def bayesian_prior(prior_settings, pdf_model):
def bayesian_prior(prior_settings, forward_map):
"""
Produces a prior transform function.

Expand All @@ -31,8 +29,8 @@ def bayesian_prior(prior_settings, pdf_model):
prior_specs = prior_settings.prior_distribution_specs

if "bounds" in prior_specs:
# Use param names from the model to order bounds correctly
param_names = pdf_model.param_names
# Use param names from the forward map to order bounds correctly
param_names = forward_map.param_names
bounds_dict = prior_specs["bounds"]
missing = [p for p in param_names if p not in bounds_dict]
if missing:
Expand Down
12 changes: 6 additions & 6 deletions colibri/blackjax_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@


def blackjax_fit(
pdf_model,
forward_map,
bayesian_prior,
blackjax_settings,
log_likelihood,
Expand All @@ -47,8 +47,8 @@ def blackjax_fit(

Parameters
----------
pdf_model: pdf_model.PDFModel
The PDF model to fit.
forward_map: ForwardMap
The forward map whose ``param_names`` enumerate all fit parameters.

bayesian_prior: BayesianPrior, @jax.jit CompiledFunction
The prior function for the model.
Expand All @@ -70,7 +70,7 @@ def blackjax_fit(
# set the BlackJAX seed
rng_key = jax.random.PRNGKey(blackjax_settings["seed"])
log.info(f"BlackJAX initialisation seed: {rng_key}")
n_dims = pdf_model.n_parameters
n_dims = len(forward_map.param_names)
n_live = blackjax_settings["n_live"]
n_delete = int(blackjax_settings["delete_fraction"] * n_live)

Expand Down Expand Up @@ -141,7 +141,7 @@ def one_step(carry, xs):
data=final_states.particles,
logL=final_states.loglikelihood,
logL_birth=final_states.loglikelihood_birth,
columns=pdf_model.param_names,
columns=forward_map.param_names,
)
# write nested_samples.csv to blackjax_logs
log_dir = blackjax_settings["log_dir"]
Expand All @@ -167,7 +167,7 @@ def one_step(carry, xs):
"logZ_err": logzs.std(),
"ess": ess_value,
},
param_names=pdf_model.param_names,
param_names=forward_map.param_names,
resampled_posterior=resampled_posterior,
full_posterior_samples=full_samples,
bayesian_metrics={
Expand Down
26 changes: 13 additions & 13 deletions colibri/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,26 @@
import jax
from colibri.theory_predictions import make_pred_data, fast_kernel_arrays

from colibri.utils import get_fit_path, get_pdf_model, pdf_models_equal
from colibri.utils import get_fit_path, get_pdf_model


@make_argcheck
def check_pdf_models_equal(prior_settings, pdf_model, theoryid):
def check_pdf_models_equal(prior_settings, forward_map, theoryid):
"""
Decorator that can be added to functions to check that the
PDF model used as prior (eg when using prior_settings["type"] == "prior_from_gauss_posterior")
matches the PDF model used in the current fit (pdf_model).
matches the PDF model used in the current fit (via ``forward_map.pdf_param_names``).
"""

if prior_settings.prior_distribution == "prior_from_gauss_posterior":

prior_fit = prior_settings.prior_distribution_specs["prior_fit"]
prior_pdf_model = get_pdf_model(prior_fit)

if not pdf_models_equal(prior_pdf_model, pdf_model):
if prior_pdf_model.param_names != list(forward_map.pdf_param_names):
raise ValueError(
f"PDF model {pdf_model} does not match prior settings {prior_pdf_model}"
f"PDF param names from forward_map {list(forward_map.pdf_param_names)} "
f"do not match prior PDF model param names {prior_pdf_model.param_names}"
)

# load filter.yml runcard of the prior fit
Expand All @@ -41,8 +42,7 @@ def check_pdf_models_equal(prior_settings, pdf_model, theoryid):
)


@make_argcheck
def check_pdf_model_is_linear(pdf_model, FIT_XGRID, data):
def check_pdf_model_is_linear(pdf_model, forward_map, FIT_XGRID, data):
"""
Decorator that can be added to functions to check that the
PDF model is linear.
Expand All @@ -52,8 +52,8 @@ def check_pdf_model_is_linear(pdf_model, FIT_XGRID, data):
fk = fast_kernel_arrays(data, FIT_XGRID)

parameters = pdf_model.param_names
pred_and_pdf = pdf_model.pred_and_pdf_func(FIT_XGRID, forward_map=pred_data)
intercept = pred_and_pdf(jnp.zeros(len(parameters)), fk)[0]
pdf_grid = pdf_model.grid_values_func(FIT_XGRID)
intercept, _ = forward_map(pdf_grid, fk, jnp.zeros(len(parameters)))

# Run the check for 10 random points in the parameter space
for i in range(10):
Expand All @@ -65,16 +65,16 @@ def check_pdf_model_is_linear(pdf_model, FIT_XGRID, data):

# Test additivity
add_check = jnp.isclose(
pred_and_pdf(x1, fk)[0] + pred_and_pdf(x2, fk)[0],
pred_and_pdf(x1 + x2, fk)[0] + intercept,
forward_map(pdf_grid, fk, x1)[0] + forward_map(pdf_grid, fk, x2)[0],
forward_map(pdf_grid, fk, x1 + x2)[0] + intercept,
)

# Test homogeneity
c = jax.random.uniform(key, (1,))

homogeneity_check = jnp.isclose(
c * (pred_and_pdf(x1, fk)[0] - intercept),
pred_and_pdf(c * x1, fk)[0] - intercept,
c * (forward_map(pdf_grid, fk, x1)[0] - intercept),
forward_map(pdf_grid, fk, c * x1)[0] - intercept,
)

if not add_check.all() or not homogeneity_check.all():
Expand Down
3 changes: 2 additions & 1 deletion colibri/export_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,11 @@ def write_replicas(

# Create the exportgrid
lhapdf_interpolator = pdf_model.grid_values_func(xgrid)
n_pdf_params = len(pdf_model.param_names)

# Finish by writing the replicas to export grids, ready for evolution
for i in indices_per_process:
parameters = jnp.array(bayes_fit.resampled_posterior[i, :])
parameters = jnp.array(bayes_fit.resampled_posterior[i, :n_pdf_params])
grid_for_writing = np.array(lhapdf_interpolator(parameters))

replica_index = i + 1
Expand Down
Loading
Loading