Skip to content

Commit d0c6856

Browse files
committed
implement BackendWrapper registry for unified usage across backends.
1 parent acd7852 commit d0c6856

File tree

4 files changed

+156
-87
lines changed

4 files changed

+156
-87
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,7 @@ module = [
291291
"optimagic.shared.process_user_function",
292292

293293
"optimagic.visualization",
294+
"optimagic.visualization.backends",
294295
"optimagic.visualization.convergence_plot",
295296
"optimagic.visualization.deviation_plot",
296297
"optimagic.visualization.history_plots",
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import abc
2+
from dataclasses import dataclass
3+
from typing import Any
4+
5+
import matplotlib.pyplot as plt
6+
import plotly.graph_objects as go
7+
8+
9+
@dataclass(frozen=True)
10+
class PlotConfig:
11+
template: str
12+
plotly_legend: dict[str, Any]
13+
matplotlib_legend: dict[str, Any]
14+
15+
16+
class BackendWrapper(abc.ABC):
17+
def __init__(self, plot_config):
18+
self.plot_config = plot_config
19+
20+
@abc.abstractmethod
21+
def create_figure(self):
22+
pass
23+
24+
@abc.abstractmethod
25+
def lineplot(self, **kwargs):
26+
pass
27+
28+
@abc.abstractmethod
29+
def post_plot(self, **kwargs):
30+
pass
31+
32+
@abc.abstractmethod
33+
def return_obj(self):
34+
pass
35+
36+
37+
class BackendRegistry:
38+
_registry: dict[str, BackendWrapper] = {}
39+
40+
@classmethod
41+
def register(cls, backend_name):
42+
def decorator(backend_wrapper):
43+
cls._registry[backend_name] = backend_wrapper
44+
return backend_wrapper
45+
46+
return decorator
47+
48+
@classmethod
49+
def get_backend_wrapper(cls, backend_name):
50+
if backend_name not in cls._registry:
51+
raise ValueError(
52+
f"Backend '{backend_name}' is not supported. "
53+
f"Supported backends are: {', '.join(cls._registry.keys())}."
54+
)
55+
return cls._registry.get(backend_name)
56+
57+
58+
@BackendRegistry.register("plotly")
59+
class PlotlyBackend(BackendWrapper):
60+
def __init__(self, plot_config):
61+
super().__init__(plot_config)
62+
self.fig = self.create_figure()
63+
64+
def create_figure(self):
65+
fig = go.Figure()
66+
return fig
67+
68+
def lineplot(self, *, x, y, color, name=None, plotly_scatter_kws=None, **kwargs):
69+
if plotly_scatter_kws is None:
70+
plotly_scatter_kws = {}
71+
72+
trace = go.Scatter(
73+
x=x, y=y, mode="lines", line_color=color, name=name, **plotly_scatter_kws
74+
)
75+
self.fig.add_trace(trace)
76+
77+
def post_plot(self, *, xlabel=None, ylabel=None, **kwargs):
78+
self.fig.update_layout(
79+
template=self.plot_config.template,
80+
xaxis_title_text=xlabel,
81+
yaxis_title_text=ylabel,
82+
legend=self.plot_config.plotly_legend,
83+
)
84+
85+
def return_obj(self):
86+
return self.fig
87+
88+
89+
@BackendRegistry.register("matplotlib")
90+
class MatplotlibBackend(BackendWrapper):
91+
def __init__(self, plot_config):
92+
super().__init__(plot_config)
93+
self.fig, self.ax = self.create_figure()
94+
95+
def create_figure(self):
96+
plt.style.use(self.plot_config.template)
97+
fig, ax = plt.subplots()
98+
return fig, ax
99+
100+
def lineplot(self, *, x, y, color, name=None, **kwargs):
101+
self.ax.plot(x, y, color=color, label=name)
102+
103+
def post_plot(self, *, xlabel=None, ylabel=None, **kwargs):
104+
self.ax.set(xlabel=xlabel, ylabel=ylabel)
105+
self.ax.legend(**self.plot_config.matplotlib_legend)
106+
107+
def return_obj(self):
108+
return self.fig

src/optimagic/visualization/history_plots.py

Lines changed: 25 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,27 @@
11
import inspect
2-
import itertools
32
from pathlib import Path
43
from typing import Any
54

6-
import matplotlib as mpl
7-
import matplotlib.pyplot as plt
85
import numpy as np
96
import plotly.graph_objects as go
107
from pybaum import leaf_names, tree_flatten, tree_just_flatten, tree_unflatten
118

12-
from optimagic.config import PLOT_DEFAULTS, PLOTLY_TEMPLATE
9+
from optimagic.config import PLOTLY_TEMPLATE
1310
from optimagic.logging.logger import LogReader, SQLiteLogOptions
1411
from optimagic.optimization.algorithm import Algorithm
1512
from optimagic.optimization.history import History
1613
from optimagic.optimization.optimize_result import OptimizeResult
1714
from optimagic.parameters.tree_registry import get_registry
1815
from optimagic.typing import Direction
16+
from optimagic.visualization.backends import BackendRegistry, PlotConfig
17+
from optimagic.visualization.plotting_utilities import get_palette, get_template
1918

2019

2120
def criterion_plot(
2221
results,
2322
names=None,
24-
backend="plotly",
2523
max_evaluations=None,
24+
backend="plotly",
2625
template=None,
2726
palette=None,
2827
stack_multistart=False,
@@ -36,8 +35,8 @@ def criterion_plot(
3635
dict of) optimization results with collected history. If dict, then the
3736
key is used as the name in a legend.
3837
names (Union[List[str], str]): Names corresponding to res or entries in res.
39-
backend (str): The backend to use for plotting. Default is "plotly".
4038
max_evaluations (int): Clip the criterion history after that many entries.
39+
backend (str): The backend to use for plotting. Default is "plotly".
4140
template (str): The template for the figure. Default is "plotly_white".
4241
palette (Union[List[str], str]): The coloring palette for traces. Default is
4342
"qualitative.Plotly".
@@ -50,7 +49,7 @@ def criterion_plot(
5049
optimization are visualized. Default is False.
5150
5251
Returns:
53-
Figure object returned by the chosen backend.
52+
Native figure object returned by the chosen backend.
5453
5554
"""
5655
# ==================================================================================
@@ -59,16 +58,8 @@ def criterion_plot(
5958

6059
results = _harmonize_inputs_to_dict(results, names)
6160

62-
if template is None:
63-
template = PLOT_DEFAULTS[backend]["template"]
64-
if palette is None:
65-
palette = PLOT_DEFAULTS[backend]["palette"]
66-
67-
if isinstance(palette, mpl.colors.Colormap):
68-
palette = [palette(i) for i in range(palette.N)]
69-
if not isinstance(palette, list):
70-
palette = [palette]
71-
palette = itertools.cycle(palette)
61+
template = get_template(backend, template)
62+
palette = get_palette(backend, palette)
7263

7364
fun_or_monotone_fun = "monotone_fun" if monotone else "fun"
7465

@@ -98,15 +89,22 @@ def criterion_plot(
9889
# Create figure
9990
# ==================================================================================
10091

101-
fig, plot_func, label_func = _get_plot_backend(backend)
102-
103-
plot_multistart = (
104-
len(data) == 1 and data[0]["is_multistart"] and not stack_multistart
92+
plot_config = PlotConfig(
93+
template=template,
94+
plotly_legend={"yanchor": "top", "xanchor": "right", "y": 0.95, "x": 0.95},
95+
matplotlib_legend={"loc": "upper right"},
10596
)
10697

98+
_backend_wrapper = BackendRegistry.get_backend_wrapper(backend)
99+
fig = _backend_wrapper(plot_config)
100+
107101
# ==================================================================================
108102
# Plot multistart paths
109103

104+
plot_multistart = (
105+
len(data) == 1 and data[0]["is_multistart"] and not stack_multistart
106+
)
107+
110108
if plot_multistart:
111109
scatter_kws = {
112110
"connectgaps": True,
@@ -119,8 +117,7 @@ def criterion_plot(
119117
if max_evaluations is not None and len(history) > max_evaluations:
120118
history = history[:max_evaluations]
121119

122-
plot_func(
123-
fig,
120+
fig.lineplot(
124121
x=np.arange(len(history)),
125122
y=history,
126123
name=None,
@@ -144,30 +141,23 @@ def criterion_plot(
144141

145142
scatter_kws = {
146143
"connectgaps": True,
147-
"showlegend": not plot_multistart,
144+
"showlegend": True,
148145
}
149146

150-
_color = next(palette)
151-
152-
plot_func(
153-
fig,
147+
fig.lineplot(
154148
x=np.arange(len(history)),
155149
y=history,
156150
name="best result" if plot_multistart else _data["name"],
157-
color=_color,
151+
color=next(palette),
158152
plotly_scatter_kws=scatter_kws,
159153
)
160154

161-
label_func(
162-
fig,
163-
template=template,
155+
fig.post_plot(
164156
xlabel="No. of criterion evaluations",
165157
ylabel="Criterion value",
166-
plotly_legend={"yanchor": "top", "xanchor": "right", "y": 0.95, "x": 0.95},
167-
matplotlib_legend={"loc": "upper right"},
168158
)
169159

170-
return fig
160+
return fig.return_obj()
171161

172162

173163
def _harmonize_inputs_to_dict(results, names):
@@ -461,54 +451,3 @@ def _get_stacked_local_histories(local_histories, direction, history=None):
461451
task=len(stacked["criterion"]) * [None],
462452
batches=list(range(len(stacked["criterion"]))),
463453
)
464-
465-
466-
def _get_plot_backend(backend):
467-
backends = {
468-
"plotly": (
469-
go.Figure(),
470-
_plot_plotly,
471-
_label_plotly,
472-
),
473-
"matplotlib": (
474-
plt.subplots()[1],
475-
_plot_matplotlib,
476-
_label_matplotlib,
477-
),
478-
}
479-
480-
if backend not in backends:
481-
msg = (
482-
f"Backend '{backend}' is not supported. "
483-
f"Supported backends are: {', '.join(backends.keys())}."
484-
)
485-
raise ValueError(msg)
486-
487-
return backends[backend]
488-
489-
490-
def _plot_plotly(fig, *, x, y, name, color, plotly_scatter_kws, **kwargs):
491-
trace = go.Scatter(
492-
x=x, y=y, mode="lines", name=name, line_color=color, **plotly_scatter_kws
493-
)
494-
fig.add_trace(trace)
495-
return fig
496-
497-
498-
def _label_plotly(fig, *, template, xlabel, ylabel, plotly_legend, **kwargs):
499-
fig.update_layout(
500-
template=template,
501-
xaxis_title_text=xlabel,
502-
yaxis_title_text=ylabel,
503-
legend=plotly_legend,
504-
)
505-
506-
507-
def _plot_matplotlib(ax, *, x, y, name, color, **kwargs):
508-
ax.plot(x, y, label=name, color=color)
509-
return ax
510-
511-
512-
def _label_matplotlib(ax, *, xlabel, ylabel, matplotlib_legend, **kwargs):
513-
ax.set(xlabel=xlabel, ylabel=ylabel)
514-
ax.legend(**matplotlib_legend)

src/optimagic/visualization/plotting_utilities.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import itertools
22
from copy import deepcopy
33

4+
import matplotlib as mpl
45
import numpy as np
56
import plotly.graph_objects as go
67
from plotly.subplots import make_subplots
78

8-
from optimagic.config import PLOTLY_TEMPLATE
9+
from optimagic.config import PLOT_DEFAULTS, PLOTLY_TEMPLATE
910

1011

1112
def combine_plots(
@@ -328,3 +329,23 @@ def get_layout_kwargs(layout_kwargs, legend_kwargs, title_kwargs, template, show
328329
if layout_kwargs:
329330
default_kwargs.update(layout_kwargs)
330331
return default_kwargs
332+
333+
334+
def get_template(backend, template):
335+
if template is None:
336+
template = PLOT_DEFAULTS[backend]["template"]
337+
338+
return template
339+
340+
341+
def get_palette(backend, palette):
342+
if palette is None:
343+
palette = PLOT_DEFAULTS[backend]["palette"]
344+
345+
if isinstance(palette, mpl.colors.Colormap):
346+
palette = [palette(i) for i in range(palette.N)]
347+
if not isinstance(palette, list):
348+
palette = [palette]
349+
palette = itertools.cycle(palette)
350+
351+
return palette

0 commit comments

Comments
 (0)