Skip to content

Commit 0d28554

Browse files
committed
Refactor to functional approach for backend plotting. Use hardcoded default palette.
1 parent f15d982 commit 0d28554

File tree

4 files changed

+146
-130
lines changed

4 files changed

+146
-130
lines changed

src/optimagic/config.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,20 @@
1111
PLOTLY_TEMPLATE = "simple_white"
1212
PLOTLY_PALETTE = px.colors.qualitative.Set2
1313

14+
# The hex strings are obtained from the Plotly D3 qualitative palette.
15+
DEFAULT_PALETTE = [
16+
"#1F77B4",
17+
"#FF7F0E",
18+
"#2CA02C",
19+
"#D62728",
20+
"#9467BD",
21+
"#8C564B",
22+
"#E377C2",
23+
"#7F7F7F",
24+
"#BCBD22",
25+
"#17BECF",
26+
]
27+
1428
DEFAULT_N_CORES = 1
1529

1630
CRITERION_PENALTY_SLOPE = 0.1
Lines changed: 120 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
import abc
2-
from typing import Any
1+
from typing import Any, Literal
32

4-
import plotly.express as px
53
import plotly.graph_objects as go
64

75
from optimagic.config import IS_MATPLOTLIB_INSTALLED
@@ -20,126 +18,138 @@
2018
plt.ioff()
2119

2220

23-
class PlotBackend(abc.ABC):
24-
is_available: bool
25-
default_template: str
26-
27-
@classmethod
28-
@abc.abstractmethod
29-
def get_default_palette(cls) -> list:
30-
pass
31-
32-
@abc.abstractmethod
33-
def __init__(self, template: str | None):
34-
if template is None:
35-
template = self.default_template
36-
37-
self.template = template
38-
self.figure: Any = None
39-
40-
@abc.abstractmethod
41-
def add_lines(self, lines: list[LineData]) -> None:
42-
pass
43-
44-
@abc.abstractmethod
45-
def set_labels(self, xlabel: str | None = None, ylabel: str | None = None) -> None:
46-
pass
47-
48-
@abc.abstractmethod
49-
def set_legend_properties(self, legend_properties: dict[str, Any]) -> None:
50-
pass
51-
52-
53-
class PlotlyBackend(PlotBackend):
54-
is_available: bool = True
55-
default_template: str = "simple_white"
56-
57-
@classmethod
58-
def get_default_palette(cls) -> list:
59-
return px.colors.qualitative.Set2
60-
61-
def __init__(self, template: str | None):
62-
super().__init__(template)
63-
self._fig = go.Figure()
64-
self._fig.update_layout(template=self.template)
65-
self.figure = self._fig
66-
67-
def add_lines(self, lines: list[LineData]) -> None:
68-
for line in lines:
69-
trace = go.Scatter(
70-
x=line.x,
71-
y=line.y,
72-
name=line.name,
73-
mode="lines",
74-
line_color=line.color,
75-
showlegend=line.show_in_legend,
76-
connectgaps=True,
77-
)
78-
self._fig.add_trace(trace)
79-
80-
def set_labels(self, xlabel: str | None = None, ylabel: str | None = None) -> None:
81-
self._fig.update_layout(xaxis_title_text=xlabel, yaxis_title_text=ylabel)
82-
83-
def set_legend_properties(self, legend_properties: dict[str, Any]) -> None:
84-
self._fig.update_layout(legend=legend_properties)
85-
86-
87-
class MatplotlibBackend(PlotBackend):
88-
is_available: bool = IS_MATPLOTLIB_INSTALLED
89-
default_template: str = "default"
90-
91-
@classmethod
92-
def get_default_palette(cls) -> list:
93-
return [mpl.colormaps["Set2"](i) for i in range(mpl.colormaps["Set2"].N)]
94-
95-
def __init__(self, template: str | None):
96-
super().__init__(template)
97-
plt.style.use(self.template)
98-
self._fig, self._ax = plt.subplots()
99-
self.figure = self._fig
100-
101-
def add_lines(self, lines: list[LineData]) -> None:
102-
for line in lines:
103-
self._ax.plot(
104-
line.x,
105-
line.y,
106-
color=line.color,
107-
label=line.name if line.show_in_legend else None,
108-
)
21+
def _line_plot_plotly(
22+
lines: list[LineData],
23+
*,
24+
title: str | None,
25+
xlabel: str | None,
26+
ylabel: str | None,
27+
template: str | None,
28+
height: int | None,
29+
width: int | None,
30+
legend_properties: dict[str, Any] | None,
31+
) -> go.Figure:
32+
fig = go.Figure()
33+
34+
for line in lines:
35+
trace = go.Scatter(
36+
x=line.x,
37+
y=line.y,
38+
name=line.name,
39+
line_color=line.color,
40+
mode="lines",
41+
)
42+
fig.add_trace(trace)
43+
44+
fig.update_layout(
45+
title=title,
46+
xaxis_title=xlabel,
47+
yaxis_title=ylabel,
48+
template=template,
49+
height=height,
50+
width=width,
51+
)
52+
53+
if legend_properties:
54+
fig.update_layout(legend=legend_properties)
55+
56+
return fig
57+
58+
59+
def _line_plot_matplotlib(
60+
lines: list[LineData],
61+
*,
62+
title: str | None,
63+
xlabel: str | None,
64+
ylabel: str | None,
65+
template: str | None,
66+
height: int | None,
67+
width: int | None,
68+
legend_properties: dict[str, Any] | None,
69+
) -> "plt.Figure":
70+
if template is not None:
71+
plt.style.use(template)
72+
fig, ax = plt.subplots(figsize=(width, height) if width and height else None)
73+
74+
for line in lines:
75+
ax.plot(
76+
line.x,
77+
line.y,
78+
label=line.name if line.show_in_legend else None,
79+
color=line.color,
80+
)
10981

110-
def set_labels(self, xlabel: str | None = None, ylabel: str | None = None) -> None:
111-
self._ax.set(xlabel=xlabel, ylabel=ylabel)
82+
ax.set(title=title, xlabel=xlabel, ylabel=ylabel)
83+
if legend_properties:
84+
ax.legend(**legend_properties)
11285

113-
def set_legend_properties(self, legend_properties: dict[str, Any]) -> None:
114-
self._ax.legend(**legend_properties)
86+
return fig
11587

11688

117-
PLOT_BACKEND_CLASSES = {
118-
"plotly": PlotlyBackend,
119-
"matplotlib": MatplotlibBackend,
89+
BACKEND_AVAILABILITY_AND_LINE_PLOT_FUNCTION = {
90+
"plotly": (True, _line_plot_plotly),
91+
"matplotlib": (IS_MATPLOTLIB_INSTALLED, _line_plot_matplotlib),
12092
}
12193

12294

123-
def get_plot_backend_class(backend_name: str) -> type[PlotBackend]:
124-
if backend_name not in PLOT_BACKEND_CLASSES:
95+
def line_plot(
96+
lines: list[LineData],
97+
backend: Literal["plotly", "matplotlib"] = "plotly",
98+
*,
99+
title: str | None = None,
100+
xlabel: str | None = None,
101+
ylabel: str | None = None,
102+
template: str | None = None,
103+
height: int | None = None,
104+
width: int | None = None,
105+
legend_properties: dict[str, Any] | None = None,
106+
) -> Any:
107+
"""Create a line plot corresponding to the specified backend.
108+
109+
Args:
110+
lines: List of objects each containing data for a line in the plot.
111+
backend: The backend to use for plotting.
112+
title: Title of the plot.
113+
xlabel: Label for the x-axis.
114+
ylabel: Label for the y-axis.
115+
template: Backend-specific template for styling the plot.
116+
height: Height of the plot (in pixels).
117+
width: Width of the plot (in pixels).
118+
legend_properties: Backend-specific properties for the legend.
119+
120+
Returns:
121+
A figure object corresponding to the specified backend.
122+
123+
"""
124+
if backend not in BACKEND_AVAILABILITY_AND_LINE_PLOT_FUNCTION:
125125
msg = (
126-
f"Invalid backend name '{backend_name}'. "
127-
f"Supported backends are: {', '.join(PLOT_BACKEND_CLASSES.keys())}."
126+
f"Invalid plotting backend '{backend}'. "
127+
f"Available backends: "
128+
f"{', '.join(BACKEND_AVAILABILITY_AND_LINE_PLOT_FUNCTION.keys())}"
128129
)
129130
raise InvalidPlottingBackendError(msg)
130131

131-
return _get_backend_if_installed(backend_name)
132-
133-
134-
def _get_backend_if_installed(backend_name: str) -> type[PlotBackend]:
135-
plot_cls = PLOT_BACKEND_CLASSES[backend_name]
132+
_is_backend_available, _line_plot_backend_function = (
133+
BACKEND_AVAILABILITY_AND_LINE_PLOT_FUNCTION[backend]
134+
)
136135

137-
if not plot_cls.is_available:
136+
if not _is_backend_available:
138137
msg = (
139-
f"The '{backend_name}' backend is not installed. "
140-
f"Install the package using either 'pip install {backend_name}' or "
141-
f"'conda install -c conda-forge {backend_name}'"
138+
f"The {backend} backend is not installed. "
139+
f"Install the package using either 'pip install {backend}' or "
140+
f"'conda install -c conda-forge {backend}'"
142141
)
143142
raise NotInstalledError(msg)
144143

145-
return plot_cls
144+
fig = _line_plot_backend_function(
145+
lines,
146+
title=title,
147+
xlabel=xlabel,
148+
ylabel=ylabel,
149+
template=template,
150+
height=height,
151+
width=width,
152+
legend_properties=legend_properties,
153+
)
154+
155+
return fig

src/optimagic/visualization/history_plots.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@
88
import plotly.graph_objects as go
99
from pybaum import leaf_names, tree_flatten, tree_just_flatten, tree_unflatten
1010

11-
from optimagic.config import PLOTLY_TEMPLATE
11+
from optimagic.config import DEFAULT_PALETTE, PLOTLY_TEMPLATE
1212
from optimagic.logging.logger import LogReader, SQLiteLogOptions
1313
from optimagic.optimization.algorithm import Algorithm
1414
from optimagic.optimization.history import History
1515
from optimagic.optimization.optimize_result import OptimizeResult
1616
from optimagic.parameters.tree_registry import get_registry
1717
from optimagic.typing import IterationHistory, PyTree
18-
from optimagic.visualization.backends import get_plot_backend_class
18+
from optimagic.visualization.backends import line_plot
1919
from optimagic.visualization.plotting_utilities import LineData, get_palette_cycle
2020

2121
BACKEND_TO_CRITERION_PLOT_LEGEND_PROPERTIES: dict[str, dict[str, Any]] = {
@@ -40,7 +40,7 @@ def criterion_plot(
4040
max_evaluations: int | None = None,
4141
backend: Literal["plotly", "matplotlib"] = "plotly",
4242
template: str | None = None,
43-
palette: list[str] | str | None = None,
43+
palette: list[str] | str = DEFAULT_PALETTE,
4444
stack_multistart: bool = False,
4545
monotone: bool = False,
4646
show_exploration: bool = False,
@@ -66,16 +66,9 @@ def criterion_plot(
6666
The figure object containing the criterion plot.
6767
6868
"""
69-
# ==================================================================================
70-
# Get Plot Backend class
71-
72-
plot_cls = get_plot_backend_class(backend)
73-
7469
# ==================================================================================
7570
# Process inputs
7671

77-
if palette is None:
78-
palette = plot_cls.get_default_palette()
7972
palette_cycle = get_palette_cycle(palette)
8073

8174
dict_of_optimize_results_or_paths = _harmonize_inputs_to_dict(results, names)
@@ -100,13 +93,16 @@ def criterion_plot(
10093
# ==================================================================================
10194
# Generate the figure
10295

103-
plot = plot_cls(template)
104-
105-
plot.add_lines(lines + multistart_lines)
106-
plot.set_labels(xlabel="No. of criterion evaluations", ylabel="Criterion value")
107-
plot.set_legend_properties(BACKEND_TO_CRITERION_PLOT_LEGEND_PROPERTIES[backend])
96+
fig = line_plot(
97+
lines=lines + multistart_lines,
98+
backend=backend,
99+
xlabel="No. of criterion evaluations",
100+
ylabel="Criterion value",
101+
template=template,
102+
legend_properties=BACKEND_TO_CRITERION_PLOT_LEGEND_PROPERTIES[backend],
103+
)
108104

109-
return plot.figure
105+
return fig
110106

111107

112108
def _harmonize_inputs_to_dict(

tests/optimagic/visualization/test_history_plots.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from numpy.testing import assert_array_equal
77

88
import optimagic as om
9-
from optimagic.exceptions import InvalidPlottingBackendError
109
from optimagic.logging import SQLiteLogOptions
1110
from optimagic.optimization.optimize import minimize
1211
from optimagic.parameters.bounds import Bounds
@@ -145,9 +144,6 @@ def test_criterion_plot_wrong_inputs():
145144
with pytest.raises(ValueError):
146145
criterion_plot(["bla", "bla"], names="blub")
147146

148-
with pytest.raises(InvalidPlottingBackendError):
149-
criterion_plot("bla", backend="blub")
150-
151147

152148
@pytest.mark.parametrize("backend", ["plotly", "matplotlib"])
153149
def test_criterion_plot_different_backends(minimize_result, backend):

0 commit comments

Comments
 (0)