Skip to content

Commit acd7852

Browse files
committed
implementation of matplotlib backend for criterion plot
1 parent fe0dcf7 commit acd7852

File tree

8 files changed

+101
-29
lines changed

8 files changed

+101
-29
lines changed

.tools/envs/testenv-linux.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ dependencies:
1919
- numpy >= 2 # run, tests
2020
- pandas # run, tests
2121
- plotly<6.0.0 # run, tests
22+
- matplotlib # run, tests
2223
- pybaum>=0.1.2 # run, tests
2324
- scipy>=1.2.1 # run, tests
2425
- sqlalchemy # run, tests

.tools/envs/testenv-numpy.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ dependencies:
1717
- cloudpickle # run, tests
1818
- joblib # run, tests
1919
- plotly<6.0.0 # run, tests
20+
- matplotlib # run, tests
2021
- pybaum>=0.1.2 # run, tests
2122
- scipy>=1.2.1 # run, tests
2223
- sqlalchemy # run, tests

.tools/envs/testenv-others.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ dependencies:
1717
- numpy >= 2 # run, tests
1818
- pandas # run, tests
1919
- plotly<6.0.0 # run, tests
20+
- matplotlib # run, tests
2021
- pybaum>=0.1.2 # run, tests
2122
- scipy>=1.2.1 # run, tests
2223
- sqlalchemy # run, tests

.tools/envs/testenv-pandas.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ dependencies:
1717
- cloudpickle # run, tests
1818
- joblib # run, tests
1919
- plotly<6.0.0 # run, tests
20+
- matplotlib # run, tests
2021
- pybaum>=0.1.2 # run, tests
2122
- scipy>=1.2.1 # run, tests
2223
- sqlalchemy # run, tests

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ dependencies:
2121
- numpy >= 2 # run, tests
2222
- pandas # run, tests
2323
- plotly<6.0.0 # run, tests
24+
- matplotlib # run, tests
2425
- pybaum>=0.1.2 # run, tests
2526
- scipy>=1.2.1 # run, tests
2627
- sqlalchemy # run, tests

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dependencies = [
1111
"numpy",
1212
"pandas",
1313
"plotly<6.0.0",
14+
"matplotlib",
1415
"pybaum>=0.1.2",
1516
"scipy>=1.2.1",
1617
"sqlalchemy>=1.3",
@@ -347,6 +348,8 @@ module = [
347348
"plotly.graph_objects",
348349
"plotly.express",
349350
"plotly.subplots",
351+
"matplotlib",
352+
"matplotlib.pyplot",
350353
"cyipopt",
351354
"nlopt",
352355
"bokeh",

src/optimagic/config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from pathlib import Path
22

3+
import matplotlib as mpl
34
import pandas as pd
45
import plotly.express as px
56
from packaging import version
@@ -10,6 +11,14 @@
1011
PLOTLY_TEMPLATE = "simple_white"
1112
PLOTLY_PALETTE = px.colors.qualitative.Set2
1213

14+
PLOT_DEFAULTS = {
15+
"plotly": {"template": "simple_white", "palette": px.colors.qualitative.Set2},
16+
"matplotlib": {
17+
"template": "default",
18+
"palette": mpl.colormaps["Set2"],
19+
},
20+
}
21+
1322
DEFAULT_N_CORES = 1
1423

1524
CRITERION_PENALTY_SLOPE = 0.1

src/optimagic/visualization/history_plots.py

Lines changed: 84 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
from pathlib import Path
44
from typing import Any
55

6+
import matplotlib as mpl
7+
import matplotlib.pyplot as plt
68
import numpy as np
79
import plotly.graph_objects as go
810
from pybaum import leaf_names, tree_flatten, tree_just_flatten, tree_unflatten
911

10-
from optimagic.config import PLOTLY_PALETTE, PLOTLY_TEMPLATE
12+
from optimagic.config import PLOT_DEFAULTS, PLOTLY_TEMPLATE
1113
from optimagic.logging.logger import LogReader, SQLiteLogOptions
1214
from optimagic.optimization.algorithm import Algorithm
1315
from optimagic.optimization.history import History
@@ -19,9 +21,10 @@
1921
def criterion_plot(
2022
results,
2123
names=None,
24+
backend="plotly",
2225
max_evaluations=None,
23-
template=PLOTLY_TEMPLATE,
24-
palette=PLOTLY_PALETTE,
26+
template=None,
27+
palette=None,
2528
stack_multistart=False,
2629
monotone=False,
2730
show_exploration=False,
@@ -33,6 +36,7 @@ def criterion_plot(
3336
dict of) optimization results with collected history. If dict, then the
3437
key is used as the name in a legend.
3538
names (Union[List[str], str]): Names corresponding to res or entries in res.
39+
backend (str): The backend to use for plotting. Default is "plotly".
3640
max_evaluations (int): Clip the criterion history after that many entries.
3741
template (str): The template for the figure. Default is "plotly_white".
3842
palette (Union[List[str], str]): The coloring palette for traces. Default is
@@ -46,7 +50,7 @@ def criterion_plot(
4650
optimization are visualized. Default is False.
4751
4852
Returns:
49-
plotly.graph_objs._figure.Figure: The figure.
53+
Figure object returned by the chosen backend.
5054
5155
"""
5256
# ==================================================================================
@@ -55,6 +59,13 @@ def criterion_plot(
5559

5660
results = _harmonize_inputs_to_dict(results, names)
5761

62+
if template is None:
63+
template = PLOT_DEFAULTS[backend]["template"]
64+
if palette is None:
65+
palette = PLOT_DEFAULTS[backend]["palette"]
66+
67+
if isinstance(palette, mpl.colors.Colormap):
68+
palette = [palette(i) for i in range(palette.N)]
5869
if not isinstance(palette, list):
5970
palette = [palette]
6071
palette = itertools.cycle(palette)
@@ -87,7 +98,7 @@ def criterion_plot(
8798
# Create figure
8899
# ==================================================================================
89100

90-
fig = go.Figure()
101+
fig, plot_func, label_func = _get_plot_backend(backend)
91102

92103
plot_multistart = (
93104
len(data) == 1 and data[0]["is_multistart"] and not stack_multistart
@@ -102,21 +113,20 @@ def criterion_plot(
102113
"showlegend": False,
103114
}
104115

105-
for i, local_history in enumerate(data[0]["local_histories"]):
116+
for local_history in data[0]["local_histories"]:
106117
history = getattr(local_history, fun_or_monotone_fun)
107118

108119
if max_evaluations is not None and len(history) > max_evaluations:
109120
history = history[:max_evaluations]
110121

111-
trace = go.Scatter(
122+
plot_func(
123+
fig,
112124
x=np.arange(len(history)),
113125
y=history,
114-
mode="lines",
115-
name=str(i),
116-
line_color="#bab0ac",
117-
**scatter_kws,
126+
name=None,
127+
color="#bab0ac",
128+
plotly_scatter_kws=scatter_kws,
118129
)
119-
fig.add_trace(trace)
120130

121131
# ==================================================================================
122132
# Plot main optimization objects
@@ -138,31 +148,25 @@ def criterion_plot(
138148
}
139149

140150
_color = next(palette)
141-
if not isinstance(_color, str):
142-
msg = "highlight_palette needs to be a string or list of strings, but its "
143-
f"entry is of type {type(_color)}."
144-
raise TypeError(msg)
145151

146-
line_kws = {
147-
"color": _color,
148-
}
149-
150-
trace = go.Scatter(
152+
plot_func(
153+
fig,
151154
x=np.arange(len(history)),
152155
y=history,
153-
mode="lines",
154156
name="best result" if plot_multistart else _data["name"],
155-
line=line_kws,
156-
**scatter_kws,
157+
color=_color,
158+
plotly_scatter_kws=scatter_kws,
157159
)
158-
fig.add_trace(trace)
159160

160-
fig.update_layout(
161+
label_func(
162+
fig,
161163
template=template,
162-
xaxis_title_text="No. of criterion evaluations",
163-
yaxis_title_text="Criterion value",
164-
legend={"yanchor": "top", "xanchor": "right", "y": 0.95, "x": 0.95},
164+
xlabel="No. of criterion evaluations",
165+
ylabel="Criterion value",
166+
plotly_legend={"yanchor": "top", "xanchor": "right", "y": 0.95, "x": 0.95},
167+
matplotlib_legend={"loc": "upper right"},
165168
)
169+
166170
return fig
167171

168172

@@ -457,3 +461,54 @@ def _get_stacked_local_histories(local_histories, direction, history=None):
457461
task=len(stacked["criterion"]) * [None],
458462
batches=list(range(len(stacked["criterion"]))),
459463
)
464+
465+
466+
def _get_plot_backend(backend):
467+
backends = {
468+
"plotly": (
469+
go.Figure(),
470+
_plot_plotly,
471+
_label_plotly,
472+
),
473+
"matplotlib": (
474+
plt.subplots()[1],
475+
_plot_matplotlib,
476+
_label_matplotlib,
477+
),
478+
}
479+
480+
if backend not in backends:
481+
msg = (
482+
f"Backend '{backend}' is not supported. "
483+
f"Supported backends are: {', '.join(backends.keys())}."
484+
)
485+
raise ValueError(msg)
486+
487+
return backends[backend]
488+
489+
490+
def _plot_plotly(fig, *, x, y, name, color, plotly_scatter_kws, **kwargs):
491+
trace = go.Scatter(
492+
x=x, y=y, mode="lines", name=name, line_color=color, **plotly_scatter_kws
493+
)
494+
fig.add_trace(trace)
495+
return fig
496+
497+
498+
def _label_plotly(fig, *, template, xlabel, ylabel, plotly_legend, **kwargs):
499+
fig.update_layout(
500+
template=template,
501+
xaxis_title_text=xlabel,
502+
yaxis_title_text=ylabel,
503+
legend=plotly_legend,
504+
)
505+
506+
507+
def _plot_matplotlib(ax, *, x, y, name, color, **kwargs):
508+
ax.plot(x, y, label=name, color=color)
509+
return ax
510+
511+
512+
def _label_matplotlib(ax, *, xlabel, ylabel, matplotlib_legend, **kwargs):
513+
ax.set(xlabel=xlabel, ylabel=ylabel)
514+
ax.legend(**matplotlib_legend)

0 commit comments

Comments
 (0)