Skip to content

Commit f780bf9

Browse files
committed
barebones implementation of matplotlib backend for criterion plot
1 parent fe0dcf7 commit f780bf9

File tree

8 files changed

+100
-29
lines changed

8 files changed

+100
-29
lines changed

.tools/envs/testenv-linux.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ dependencies:
1919
- numpy >= 2 # run, tests
2020
- pandas # run, tests
2121
- plotly<6.0.0 # run, tests
22+
- matplotlib # run, tests
2223
- pybaum>=0.1.2 # run, tests
2324
- scipy>=1.2.1 # run, tests
2425
- sqlalchemy # run, tests

.tools/envs/testenv-numpy.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ dependencies:
1717
- cloudpickle # run, tests
1818
- joblib # run, tests
1919
- plotly<6.0.0 # run, tests
20+
- matplotlib # run, tests
2021
- pybaum>=0.1.2 # run, tests
2122
- scipy>=1.2.1 # run, tests
2223
- sqlalchemy # run, tests

.tools/envs/testenv-others.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ dependencies:
1717
- numpy >= 2 # run, tests
1818
- pandas # run, tests
1919
- plotly<6.0.0 # run, tests
20+
- matplotlib # run, tests
2021
- pybaum>=0.1.2 # run, tests
2122
- scipy>=1.2.1 # run, tests
2223
- sqlalchemy # run, tests

.tools/envs/testenv-pandas.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ dependencies:
1717
- cloudpickle # run, tests
1818
- joblib # run, tests
1919
- plotly<6.0.0 # run, tests
20+
- matplotlib # run, tests
2021
- pybaum>=0.1.2 # run, tests
2122
- scipy>=1.2.1 # run, tests
2223
- sqlalchemy # run, tests

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ dependencies:
2121
- numpy >= 2 # run, tests
2222
- pandas # run, tests
2323
- plotly<6.0.0 # run, tests
24+
- matplotlib # run, tests
2425
- pybaum>=0.1.2 # run, tests
2526
- scipy>=1.2.1 # run, tests
2627
- sqlalchemy # run, tests

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dependencies = [
1111
"numpy",
1212
"pandas",
1313
"plotly<6.0.0",
14+
"matplotlib",
1415
"pybaum>=0.1.2",
1516
"scipy>=1.2.1",
1617
"sqlalchemy>=1.3",
@@ -347,6 +348,8 @@ module = [
347348
"plotly.graph_objects",
348349
"plotly.express",
349350
"plotly.subplots",
351+
"matplotlib",
352+
"matplotlib.pyplot",
350353
"cyipopt",
351354
"nlopt",
352355
"bokeh",

src/optimagic/config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from pathlib import Path
22

3+
import matplotlib as mpl
34
import pandas as pd
45
import plotly.express as px
56
from packaging import version
@@ -10,6 +11,14 @@
1011
PLOTLY_TEMPLATE = "simple_white"
1112
PLOTLY_PALETTE = px.colors.qualitative.Set2
1213

14+
PLOT_DEFAULTS = {
15+
"plotly": {"template": "simple_white", "palette": px.colors.qualitative.Set2},
16+
"matplotlib": {
17+
"template": "default",
18+
"palette": [mpl.colormaps["Set2"](i) for i in range(mpl.colormaps["Set2"].N)],
19+
},
20+
}
21+
1322
DEFAULT_N_CORES = 1
1423

1524
CRITERION_PENALTY_SLOPE = 0.1

src/optimagic/visualization/history_plots.py

Lines changed: 83 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
from pathlib import Path
44
from typing import Any
55

6+
import matplotlib.pyplot as plt
67
import numpy as np
78
import plotly.graph_objects as go
89
from pybaum import leaf_names, tree_flatten, tree_just_flatten, tree_unflatten
910

10-
from optimagic.config import PLOTLY_PALETTE, PLOTLY_TEMPLATE
11+
from optimagic.config import PLOT_DEFAULTS, PLOTLY_TEMPLATE
1112
from optimagic.logging.logger import LogReader, SQLiteLogOptions
1213
from optimagic.optimization.algorithm import Algorithm
1314
from optimagic.optimization.history import History
@@ -19,9 +20,10 @@
1920
def criterion_plot(
2021
results,
2122
names=None,
23+
backend="plotly",
2224
max_evaluations=None,
23-
template=PLOTLY_TEMPLATE,
24-
palette=PLOTLY_PALETTE,
25+
template=None,
26+
palette=None,
2527
stack_multistart=False,
2628
monotone=False,
2729
show_exploration=False,
@@ -33,6 +35,7 @@ def criterion_plot(
3335
dict of) optimization results with collected history. If dict, then the
3436
key is used as the name in a legend.
3537
names (Union[List[str], str]): Names corresponding to res or entries in res.
38+
backend (str): The backend to use for plotting. Default is "plotly".
3639
max_evaluations (int): Clip the criterion history after that many entries.
3740
template (str): The template for the figure. Default is "plotly_white".
3841
palette (Union[List[str], str]): The coloring palette for traces. Default is
@@ -46,7 +49,7 @@ def criterion_plot(
4649
optimization are visualized. Default is False.
4750
4851
Returns:
49-
plotly.graph_objs._figure.Figure: The figure.
52+
Figure object returned by the chosen backend.
5053
5154
"""
5255
# ==================================================================================
@@ -55,6 +58,11 @@ def criterion_plot(
5558

5659
results = _harmonize_inputs_to_dict(results, names)
5760

61+
if template is None:
62+
template = PLOT_DEFAULTS[backend]["template"]
63+
if palette is None:
64+
palette = PLOT_DEFAULTS[backend]["palette"]
65+
5866
if not isinstance(palette, list):
5967
palette = [palette]
6068
palette = itertools.cycle(palette)
@@ -87,7 +95,7 @@ def criterion_plot(
8795
# Create figure
8896
# ==================================================================================
8997

90-
fig = go.Figure()
98+
fig, plot_func, label_func = _get_plot_backend(backend)
9199

92100
plot_multistart = (
93101
len(data) == 1 and data[0]["is_multistart"] and not stack_multistart
@@ -102,21 +110,20 @@ def criterion_plot(
102110
"showlegend": False,
103111
}
104112

105-
for i, local_history in enumerate(data[0]["local_histories"]):
113+
for local_history in data[0]["local_histories"]:
106114
history = getattr(local_history, fun_or_monotone_fun)
107115

108116
if max_evaluations is not None and len(history) > max_evaluations:
109117
history = history[:max_evaluations]
110118

111-
trace = go.Scatter(
119+
plot_func(
120+
fig,
112121
x=np.arange(len(history)),
113122
y=history,
114-
mode="lines",
115-
name=str(i),
116-
line_color="#bab0ac",
117-
**scatter_kws,
123+
name=None,
124+
color="#bab0ac",
125+
scatter_kws=scatter_kws,
118126
)
119-
fig.add_trace(trace)
120127

121128
# ==================================================================================
122129
# Plot main optimization objects
@@ -138,31 +145,26 @@ def criterion_plot(
138145
}
139146

140147
_color = next(palette)
141-
if not isinstance(_color, str):
142-
msg = "highlight_palette needs to be a string or list of strings, but its "
143-
f"entry is of type {type(_color)}."
144-
raise TypeError(msg)
145148

146-
line_kws = {
147-
"color": _color,
148-
}
149-
150-
trace = go.Scatter(
149+
plot_func(
150+
fig,
151151
x=np.arange(len(history)),
152152
y=history,
153-
mode="lines",
154153
name="best result" if plot_multistart else _data["name"],
155-
line=line_kws,
156-
**scatter_kws,
154+
color=_color,
155+
mode="lines",
156+
scatter_kws=scatter_kws,
157157
)
158-
fig.add_trace(trace)
159158

160-
fig.update_layout(
159+
label_func(
160+
fig,
161161
template=template,
162-
xaxis_title_text="No. of criterion evaluations",
163-
yaxis_title_text="Criterion value",
164-
legend={"yanchor": "top", "xanchor": "right", "y": 0.95, "x": 0.95},
162+
xlabel="No. of criterion evaluations",
163+
ylabel="Criterion value",
164+
plotly_legend={"yanchor": "top", "xanchor": "right", "y": 0.95, "x": 0.95},
165+
matplotlib_legend="upper right",
165166
)
167+
166168
return fig
167169

168170

@@ -457,3 +459,55 @@ def _get_stacked_local_histories(local_histories, direction, history=None):
457459
task=len(stacked["criterion"]) * [None],
458460
batches=list(range(len(stacked["criterion"]))),
459461
)
462+
463+
464+
def _get_plot_backend(backend):
465+
backends = {
466+
"plotly": (
467+
go.Figure(),
468+
_plot_plotly,
469+
_label_plotly,
470+
),
471+
"matplotlib": (
472+
plt,
473+
_plot_matplotlib,
474+
_label_matplotlib,
475+
),
476+
}
477+
478+
if backend not in backends:
479+
msg = (
480+
f"Backend '{backend}' is not supported. "
481+
f"Supported backends are: {', '.join(backends.keys())}."
482+
)
483+
raise ValueError(msg)
484+
485+
return backends[backend]
486+
487+
488+
def _plot_plotly(fig, *, x, y, name, color, scatter_kws, **kwargs):
489+
trace = go.Scatter(
490+
x=x, y=y, mode="lines", name=name, line_color=color, **scatter_kws
491+
)
492+
fig.add_trace(trace)
493+
return fig
494+
495+
496+
def _label_plotly(fig, *, template, xlabel, ylabel, plotly_legend, **kwargs):
497+
fig.update_layout(
498+
template=template,
499+
xaxis_title_text=xlabel,
500+
yaxis_title_text=ylabel,
501+
legend=plotly_legend,
502+
)
503+
504+
505+
def _plot_matplotlib(plt, *, x, y, name, color, **kwargs):
506+
plt.plot(x, y, label=name, color=color)
507+
return plt
508+
509+
510+
def _label_matplotlib(plt, *, xlabel, ylabel, matplotlib_legend, **kwargs):
511+
plt.gca().set_xlabel(xlabel)
512+
plt.gca().set_ylabel(ylabel)
513+
plt.legend(loc=matplotlib_legend)

0 commit comments

Comments
 (0)