Skip to content
Merged
1 change: 1 addition & 0 deletions .tools/envs/testenv-linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies:
- numpy >= 2 # run, tests
- pandas # run, tests
- plotly>=6.2 # run, tests
- matplotlib # run, tests
- pybaum>=0.1.2 # run, tests
- scipy>=1.2.1 # run, tests
- sqlalchemy # run, tests
Expand Down
1 change: 1 addition & 0 deletions .tools/envs/testenv-nevergrad.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies:
- numpy >= 2 # run, tests
- pandas # run, tests
- plotly>=6.2 # run, tests
- matplotlib # run, tests
- pybaum>=0.1.2 # run, tests
- scipy>=1.2.1 # run, tests
- sqlalchemy # run, tests
Expand Down
1 change: 1 addition & 0 deletions .tools/envs/testenv-numpy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies:
- cloudpickle # run, tests
- joblib # run, tests
- plotly>=6.2 # run, tests
- matplotlib # run, tests
- pybaum>=0.1.2 # run, tests
- scipy>=1.2.1 # run, tests
- sqlalchemy # run, tests
Expand Down
1 change: 1 addition & 0 deletions .tools/envs/testenv-others.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies:
- numpy >= 2 # run, tests
- pandas # run, tests
- plotly>=6.2 # run, tests
- matplotlib # run, tests
- pybaum>=0.1.2 # run, tests
- scipy>=1.2.1 # run, tests
- sqlalchemy # run, tests
Expand Down
1 change: 1 addition & 0 deletions .tools/envs/testenv-pandas.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies:
- cloudpickle # run, tests
- joblib # run, tests
- plotly>=6.2 # run, tests
- matplotlib # run, tests
- pybaum>=0.1.2 # run, tests
- scipy>=1.2.1 # run, tests
- sqlalchemy # run, tests
Expand Down
1 change: 1 addition & 0 deletions .tools/envs/testenv-plotly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies:
- joblib # run, tests
- numpy >= 2 # run, tests
- pandas # run, tests
- matplotlib # run, tests
- pybaum>=0.1.2 # run, tests
- scipy>=1.2.1 # run, tests
- sqlalchemy # run, tests
Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies:
- numpy >= 2 # run, tests
- pandas # run, tests
- plotly>=6.2 # run, tests
- matplotlib # run, tests
- pybaum>=0.1.2 # run, tests
- scipy>=1.2.1 # run, tests
- sqlalchemy # run, tests
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies = [
"numpy",
"pandas",
"plotly",
"matplotlib",
"pybaum>=0.1.2",
"scipy>=1.2.1",
"sqlalchemy>=1.3",
Expand Down Expand Up @@ -290,6 +291,7 @@ module = [

"optimagic.visualization",
"optimagic.visualization.convergence_plot",
"optimagic.visualization.backends",
"optimagic.visualization.deviation_plot",
"optimagic.visualization.history_plots",
"optimagic.visualization.plotting_utilities",
Expand Down Expand Up @@ -346,6 +348,8 @@ module = [
"plotly.graph_objects",
"plotly.express",
"plotly.subplots",
"matplotlib",
"matplotlib.pyplot",
"cyipopt",
"nlopt",
"bokeh",
Expand Down
107 changes: 107 additions & 0 deletions src/optimagic/visualization/backends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import abc
from typing import Any

import matplotlib as mpl
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go

from optimagic.visualization.plotting_utilities import LineData


class PlotBackend(abc.ABC):
default_template: str
default_palette: list

@abc.abstractmethod
def __init__(self, template: str | None):
if template is None:
template = self.default_template

self.template = template
self.figure: Any = None

@abc.abstractmethod
def add_lines(self, lines: list[LineData]) -> None:
pass

@abc.abstractmethod
def set_labels(self, xlabel: str | None = None, ylabel: str | None = None) -> None:
pass

@abc.abstractmethod
def set_legend_props(self, legend_props: dict[str, Any]) -> None:
pass


class PlotlyBackend(PlotBackend):
default_template: str = "simple_white"
default_palette: list = px.colors.qualitative.Set2

def __init__(self, template: str | None):
super().__init__(template)
self._fig = go.Figure()
self._fig.update_layout(template=self.template)
self.figure = self._fig

def add_lines(self, lines: list[LineData]) -> None:
for line in lines:
trace = go.Scatter(
x=line.x,
y=line.y,
name=line.name,
mode="lines",
line_color=line.color,
showlegend=line.show_in_legend,
connectgaps=True,
)
self._fig.add_trace(trace)

def set_labels(self, xlabel: str | None = None, ylabel: str | None = None) -> None:
self._fig.update_layout(xaxis_title_text=xlabel, yaxis_title_text=ylabel)

def set_legend_props(self, legend_props: dict[str, Any]) -> None:
self._fig.update_layout(legend=legend_props)


class MatplotlibBackend(PlotBackend):
default_template: str = "default"
default_palette: list = list(mpl.colormaps["Set2"].colors)

def __init__(self, template: str | None):
super().__init__(template)
plt.style.use(self.template)
self._fig, self._ax = plt.subplots()
self.figure = self._fig

def add_lines(self, lines: list[LineData]) -> None:
for line in lines:
self._ax.plot(
line.x,
line.y,
color=line.color,
label=line.name if line.show_in_legend else None,
)

def set_labels(self, xlabel: str | None = None, ylabel: str | None = None) -> None:
self._ax.set(xlabel=xlabel, ylabel=ylabel)

def set_legend_props(self, legend_props: dict[str, Any]) -> None:
self._ax.legend(**legend_props)


PLOT_BACKEND_CLASSES = {
"plotly": PlotlyBackend,
"matplotlib": MatplotlibBackend,
}


def get_plot_backend_class(backend_name: str) -> type[PlotBackend]:
if backend_name not in PLOT_BACKEND_CLASSES:
msg = (
f"Invalid backend name '{backend_name}'. "
f"Supported backends are: {', '.join(PLOT_BACKEND_CLASSES.keys())}."
)
raise ValueError(msg)

return PLOT_BACKEND_CLASSES[backend_name]
138 changes: 40 additions & 98 deletions src/optimagic/visualization/history_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,36 +8,51 @@
import plotly.graph_objects as go
from pybaum import leaf_names, tree_flatten, tree_just_flatten, tree_unflatten

from optimagic.config import PLOTLY_PALETTE, PLOTLY_TEMPLATE
from optimagic.config import PLOTLY_TEMPLATE
from optimagic.logging.logger import LogReader, SQLiteLogOptions
from optimagic.optimization.algorithm import Algorithm
from optimagic.optimization.history import History
from optimagic.optimization.optimize_result import OptimizeResult
from optimagic.parameters.tree_registry import get_registry
from optimagic.typing import IterationHistory, PyTree
from optimagic.visualization.backends import get_plot_backend_class
from optimagic.visualization.plotting_utilities import LineData, get_palette_cycle

OptimizeResultOrPath = OptimizeResult | str | Path
BACKEND_TO_CRITERION_PLOT_LEGEND_PROPS: dict[str, dict[str, Any]] = {
"plotly": {
"yanchor": "top",
"xanchor": "right",
"y": 0.95,
"x": 0.95,
},
"matplotlib": {
"loc": "upper right",
},
}


ResultOrPath = OptimizeResult | str | Path


def criterion_plot(
results: OptimizeResultOrPath
| list[OptimizeResultOrPath]
| dict[str, OptimizeResultOrPath],
results: ResultOrPath | list[ResultOrPath] | dict[str, ResultOrPath],
names: list[str] | str | None = None,
max_evaluations: int | None = None,
template: str = PLOTLY_TEMPLATE,
palette: list[str] | str = PLOTLY_PALETTE,
backend: str = "plotly",
template: str | None = None,
palette: list[str] | str | None = None,
stack_multistart: bool = False,
monotone: bool = False,
show_exploration: bool = False,
) -> go.Figure:
) -> Any:
"""Plot the criterion history of an optimization.

Args:
results: A (list or dict of) optimization results with collected history.
If dict, then the key is used as the name in a legend.
names: Names corresponding to res or entries in res.
max_evaluations: Clip the criterion history after that many entries.
backend: The backend to use for plotting. Default is "plotly".
template: The template for the figure. Default is "plotly_white".
palette: The coloring palette for traces. Default is "qualitative.Set2".
stack_multistart: Whether to combine multistart histories into a single history.
Expand All @@ -51,12 +66,17 @@ def criterion_plot(
The figure object containing the criterion plot.

"""
# ==================================================================================
# Get Plot Backend class

plot_cls = get_plot_backend_class(backend)

# ==================================================================================
# Process inputs

if not isinstance(palette, list):
palette = [palette]
palette_cycle = itertools.cycle(palette)
if palette is None:
palette = plot_cls.default_palette
palette_cycle = get_palette_cycle(palette)

dict_of_optimize_results_or_paths = _harmonize_inputs_to_dict(results, names)

Expand All @@ -78,21 +98,19 @@ def criterion_plot(
)

# ==================================================================================
# Generate the plotly figure
# Generate the figure

plot_config = PlotConfig(
template=template,
legend={"yanchor": "top", "xanchor": "right", "y": 0.95, "x": 0.95},
)
plot = plot_cls(template)

fig = _plotly_line_plot(lines + multistart_lines, plot_config)
return fig
plot.add_lines(lines + multistart_lines)
plot.set_labels(xlabel="No. of criterion evaluations", ylabel="Criterion value")
plot.set_legend_props(BACKEND_TO_CRITERION_PLOT_LEGEND_PROPS[backend])

return plot.figure


def _harmonize_inputs_to_dict(
results: OptimizeResultOrPath
| list[OptimizeResultOrPath]
| dict[str, OptimizeResultOrPath],
results: ResultOrPath | list[ResultOrPath] | dict[str, ResultOrPath],
names: list[str] | str | None,
) -> dict[str, OptimizeResult | str | Path]:
"""Convert all valid inputs for results and names to dict[str, OptimizeResult]."""
Expand Down Expand Up @@ -462,26 +480,6 @@ def _get_stacked_local_histories(
)


@dataclass(frozen=True)
class LineData:
"""Data of a single line.

Attributes:
x: The x-coordinates of the points.
y: The y-coordinates of the points.
color: The color of the line. Default is None.
name: The name of the line. Default is None.
show_in_legend: Whether to show the line in the legend. Default is True.

"""

x: np.ndarray
y: np.ndarray
color: str | None = None
name: str | None = None
show_in_legend: bool = True


def _extract_criterion_plot_lines(
data: list[_PlottingMultistartHistory],
max_evaluations: int | None,
Expand Down Expand Up @@ -543,69 +541,13 @@ def _extract_criterion_plot_lines(
if max_evaluations is not None and len(history) > max_evaluations:
history = history[:max_evaluations]

_color = next(palette_cycle)
if not isinstance(_color, str):
msg = "highlight_palette needs to be a string or list of strings, but its "
f"entry is of type {type(_color)}."
raise TypeError(msg)

line_data = LineData(
x=np.arange(len(history)),
y=history,
color=_color,
color=next(palette_cycle),
name="best result" if plot_multistart else _data.name,
show_in_legend=not plot_multistart,
)
lines.append(line_data)

return lines, multistart_lines


@dataclass(frozen=True)
class PlotConfig:
"""Configuration settings for figure.

Attributes:
template: The template for the figure.
legend: Configuration for the legend.

"""

template: str
legend: dict[str, Any]


def _plotly_line_plot(lines: list[LineData], plot_config: PlotConfig) -> go.Figure:
"""Create a plotly line plot from the given lines and plot configuration.

Args:
lines: Data for lines to be plotted.
plot_config: Configuration for the plot.

Returns:
The figure object containing the lines.

"""

fig = go.Figure()

for line in lines:
trace = go.Scatter(
x=line.x,
y=line.y,
name=line.name,
mode="lines",
line_color=line.color,
showlegend=line.show_in_legend,
connectgaps=True,
)
fig.add_trace(trace)

fig.update_layout(
template=plot_config.template,
xaxis_title_text="No. of criterion evaluations",
yaxis_title_text="Criterion value",
legend=plot_config.legend,
)

return fig
Loading
Loading