Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
b8c33dd
Refactored slice_plot to support different projections (3D and contour)
shammeer-s Apr 12, 2025
8dcb9e2
Refactored codebase to decouple data processing logic from visualizat…
shammeer-s Apr 13, 2025
8e2bfe0
Enhance sliceplot with comprehensive docstrings and minor refactoring
shammeer-s Apr 13, 2025
8b4cc8c
Minor bug fixes
shammeer-s Apr 13, 2025
9316ea8
Merge branch 'optimagic-dev:main' into visualization
shammeer-s Apr 23, 2025
7782546
Slice plot 3D implementation in a sandbox version
shammeer-s Apr 24, 2025
7108c7d
Merge remote-tracking branch 'origin/visualization' into visualization
shammeer-s Apr 24, 2025
1fdde7f
Slice plot 3D implementation in a sandbox version
shammeer-s Apr 27, 2025
8087689
Slice plot 3D implementation in a sandbox version
shammeer-s Apr 27, 2025
834a779
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 27, 2025
1164bd4
Add notebook for visualizing 3D slice plots and update sandbox imports
shammeer-s Apr 28, 2025
e995aa3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 28, 2025
300a460
Enhance 3D slice plot visualization notebook and refactor plotting fu…
shammeer-s Apr 28, 2025
5ba721c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 28, 2025
1c8243f
Documentation strings are updated for all functions in slice_plot_3d.…
shammeer-s Apr 28, 2025
14d5a32
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 28, 2025
5cdfb56
Merge branch 'main' into visualization
timmens Apr 29, 2025
2a53a89
Minor changes according to previous slice_plot.py file logic
shammeer-s May 4, 2025
382cd15
Minor fixes with code login on evaluating kwargs
shammeer-s May 4, 2025
b3caa04
Minor fixes with code login on evaluating kwargs
shammeer-s May 6, 2025
091eb80
Merge remote-tracking branch 'origin/visualization' into visualization
shammeer-s May 6, 2025
a9405d3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 6, 2025
5c08b18
Merge branch 'main' into visualization
timmens May 6, 2025
40464a1
Minor fixes with code login on evaluating kwargs
shammeer-s May 7, 2025
e7b9efa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 7, 2025
6bcf72f
Minor fixes
shammeer-s May 9, 2025
5a8af1c
Refactor slice_plot_3d.py to improve parameter handling and streamlin…
shammeer-s May 12, 2025
0335529
Merge branch 'optimagic-dev:main' into visualization
shammeer-s May 12, 2025
c88f69d
Merge remote-tracking branch 'origin/visualization' into visualization
shammeer-s May 12, 2025
8d2d2e9
Test cases addition
shammeer-s May 12, 2025
072029d
Merge branch 'main' into visualization
timmens May 19, 2025
e1a838f
Merge branch 'main' into visualization
timmens May 19, 2025
da1055a
Enhance slice_plot_3d functionality with univariate and multivariate …
shammeer-s May 19, 2025
8011650
Merge remote-tracking branch 'origin/visualization' into visualization
shammeer-s May 19, 2025
a421d74
Minor Fixes
shammeer-s May 19, 2025
f6b9174
Minor document fixes
shammeer-s May 19, 2025
f05b5e4
Test cases fixes
shammeer-s May 20, 2025
93e849f
Type hints fix
shammeer-s May 22, 2025
6eca985
Merge branch 'main' into visualization
timmens Jul 1, 2025
917821f
Merge remote-tracking branch 'origin/visualization' into visualization
shammeer-s Jul 1, 2025
02bf915
Merge branch 'optimagic-dev:main' into visualization
shammeer-s Jul 23, 2025
4722d11
Merge remote-tracking branch 'origin/visualization' into visualization
shammeer-s Jul 23, 2025
c87cb2b
Merge branch 'main' into visualization
timmens Jul 24, 2025
99d8b40
Merge branch 'main' into visualization
timmens Jul 28, 2025
3bfb1ea
Refactor slice_plot_3d for improved documentation and functionality
shammeer-s Aug 6, 2025
d8fba71
Merge remote-tracking branch 'origin/visualization' into visualization
shammeer-s Aug 6, 2025
dfeebfb
Merge branch 'optimagic-dev:main' into visualization
shammeer-s Aug 6, 2025
dfd356c
Merge remote-tracking branch 'origin/visualization' into visualization
shammeer-s Aug 6, 2025
098fefd
Removing debugging statements
shammeer-s Aug 6, 2025
b1bfbb6
Remove unnecessary print statements from test_slice_plot_3d
shammeer-s Aug 7, 2025
8452d8a
Merge branch 'main' into visualization
timmens Sep 16, 2025
4a33202
Merge branch 'main' into visualization
timmens Sep 17, 2025
10583eb
Add exclusion for 'how_to_slice_plot_3d.ipynb' in documentation build
shammeer-s Sep 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 210 additions & 0 deletions src/optimagic/visualization/plot_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
# type: ignore

import warnings
from functools import partial

import numpy as np
import pandas as pd
from pybaum import tree_just_flatten

from optimagic import deprecations
from optimagic.batch_evaluators import process_batch_evaluator
from optimagic.deprecations import replace_and_warn_about_deprecated_bounds
from optimagic.optimization.fun_value import (
convert_fun_output_to_function_value,
enforce_return_type,
)
from optimagic.parameters.bounds import pre_process_bounds
from optimagic.parameters.tree_registry import get_registry
from optimagic.shared.process_user_function import infer_aggregation_level
from optimagic.typing import AggregationLevel


def evaluate_func(params, func, func_kwargs):
"""Evaluate a user-defined function, handling deprecated dictionary output.

Args:
params: Input parameters for the function.
func: The user-defined objective function.
func_kwargs: Optional dictionary of keyword arguments to pass to the function

Returns:
A tuple of (possibly wrapped) function and its evaluated output.

"""
if func_kwargs:
func = partial(func, **func_kwargs)

func_eval = func(params)

if deprecations.is_dict_output(func_eval):
warnings.warn(
"Functions that return dictionaries are deprecated and will "
"raise an error in future versions.",
FutureWarning,
)
func_eval = deprecations.convert_dict_to_function_value(func_eval)
func = deprecations.replace_dict_output(func)

problem_type = (
deprecations.infer_problem_type_from_dict_output(func_eval)
if deprecations.is_dict_output(func_eval)
else infer_aggregation_level(func)
)
func_eval = convert_fun_output_to_function_value(func_eval, problem_type)
func = enforce_return_type(problem_type)(func)
return func, func_eval


def process_bounds(bounds, lower_bounds, upper_bounds):
"""Process parameter bounds, replacing deprecated formats if necessary.

Args:
bounds: Bound object or structure.
lower_bounds: Deprecated lower bounds.
upper_bounds: Deprecated upper bounds.

Returns:
Processed and validated bounds.

"""
bounds = replace_and_warn_about_deprecated_bounds(
bounds=bounds, lower_bounds=lower_bounds, upper_bounds=upper_bounds
)
return pre_process_bounds(bounds)


def select_parameter_indices(converter, selector, n_params):
"""Select parameter indices using a selector function, or select all by default.

Args:
converter: Parameter converter.
selector: Callable to select specific parameters.
n_params: Total number of parameters.

Returns:
Array of selected parameter indices.

"""
if selector is None:
return np.arange(n_params, dtype=int)

helper = converter.params_from_internal(np.arange(n_params))
registry = get_registry(extended=True)
return np.array(tree_just_flatten(selector(helper), registry=registry), dtype=int)


def generate_grid_data(internal_params, selected, n_gridpoints):
"""Generate a grid of parameter values based on selection.

Args:
internal_params: Internal representation of parameters.
selected: List of indices for parameters to vary.
n_gridpoints: Number of values to generate per parameter.

Returns:
DataFrame containing the grid of parameter values.

"""
metadata = {
name: (
np.linspace(
internal_params.lower_bounds[pos],
internal_params.upper_bounds[pos],
n_gridpoints,
)
if pos in selected
else internal_params.values[pos]
)
for pos, name in enumerate(internal_params.names)
}
return pd.DataFrame(metadata)


def evaluate_function_values(func, evaluation_points, batch_evaluator, n_cores):
"""Evaluate function at multiple points using a batch evaluation strategy.

Args:
func: The function to evaluate.
evaluation_points: List of input parameter values.
batch_evaluator: Function or callable that evaluates in batch.
n_cores: Number of CPU cores for parallelism.

Returns:
List of evaluated function values or NaNs for failed evaluations.

"""
batch_evaluator = process_batch_evaluator(batch_evaluator)
results = batch_evaluator(
func=func,
arguments=evaluation_points,
error_handling="continue",
n_cores=n_cores,
)
return [
float("nan")
if isinstance(val, str)
else val.internal_value(AggregationLevel.SCALAR)
for val in results
]


def generate_eval_points(grid, params, param_names, fixed_vars, converter, projection):
"""Generate evaluation points based on a grid of selected parameters and fixed
variables.

This function supports two modes:
- If `projection` is not "slice",
a full 2D meshgrid of points is generated for the two selected parameters.
- If `projection` is "slice",
only the selected parameters are varied individually.

Args:
grid: DataFrame of generated parameter values.
params: Internal parameter structure.
param_names: Names of parameters to vary.
fixed_vars: Dictionary of fixed parameter values.
converter: Converter object to map to internal parameter format.
projection: Projection mode ("contour", "3d", or "slice").

Returns:
If projection is not "slice":
Tuple of meshgrid arrays (X, Y) and list of evaluation points.
If projection is "slice":
Tuple of selected input values (X) and list of evaluation points.

"""
evaluation_points = []

if projection != "slice":
x_vals = grid[param_names[0]].to_numpy()
y_vals = grid[param_names[1]].to_numpy()
x, y = np.meshgrid(x_vals, y_vals)

for a, b in zip(x.ravel(), y.ravel(), strict=False):
point_dict = {param_names[0]: a, param_names[1]: b, **fixed_vars}
internal_values = np.array(list(point_dict.values()))
evaluation_points.append(converter.params_from_internal(internal_values))

return x, y, evaluation_points

else:
x = grid[param_names].to_numpy()
for param_value in x:
point_dict = (
{**fixed_vars, param_names: param_value}
if isinstance(param_names, str)
else {
**fixed_vars,
**dict(zip(param_names, param_value, strict=False)),
}
)

internal_values = np.array(
[
point_dict.get(name, params.values[params.names.index(name)])
for name in params.names
]
)
evaluation_points.append(converter.params_from_internal(internal_values))
return x, evaluation_points
Loading