3
3
from pathlib import Path
4
4
from typing import Any
5
5
6
+ import matplotlib .pyplot as plt
6
7
import numpy as np
7
8
import plotly .graph_objects as go
8
9
from pybaum import leaf_names , tree_flatten , tree_just_flatten , tree_unflatten
9
10
10
- from optimagic .config import PLOTLY_PALETTE , PLOTLY_TEMPLATE
11
+ from optimagic .config import PLOT_DEFAULTS , PLOTLY_TEMPLATE
11
12
from optimagic .logging .logger import LogReader , SQLiteLogOptions
12
13
from optimagic .optimization .algorithm import Algorithm
13
14
from optimagic .optimization .history import History
19
20
def criterion_plot (
20
21
results ,
21
22
names = None ,
23
+ backend = "plotly" ,
22
24
max_evaluations = None ,
23
- template = PLOTLY_TEMPLATE ,
24
- palette = PLOTLY_PALETTE ,
25
+ template = None ,
26
+ palette = None ,
25
27
stack_multistart = False ,
26
28
monotone = False ,
27
29
show_exploration = False ,
@@ -33,6 +35,7 @@ def criterion_plot(
33
35
dict of) optimization results with collected history. If dict, then the
34
36
key is used as the name in a legend.
35
37
names (Union[List[str], str]): Names corresponding to res or entries in res.
38
+ backend (str): The backend to use for plotting. Default is "plotly".
36
39
max_evaluations (int): Clip the criterion history after that many entries.
37
40
template (str): The template for the figure. Default is "plotly_white".
38
41
palette (Union[List[str], str]): The coloring palette for traces. Default is
@@ -46,7 +49,7 @@ def criterion_plot(
46
49
optimization are visualized. Default is False.
47
50
48
51
Returns:
49
- plotly.graph_objs._figure. Figure: The figure .
52
+ Figure object returned by the chosen backend .
50
53
51
54
"""
52
55
# ==================================================================================
@@ -55,6 +58,11 @@ def criterion_plot(
55
58
56
59
results = _harmonize_inputs_to_dict (results , names )
57
60
61
+ if template is None :
62
+ template = PLOT_DEFAULTS [backend ]["template" ]
63
+ if palette is None :
64
+ palette = PLOT_DEFAULTS [backend ]["palette" ]
65
+
58
66
if not isinstance (palette , list ):
59
67
palette = [palette ]
60
68
palette = itertools .cycle (palette )
@@ -87,7 +95,7 @@ def criterion_plot(
87
95
# Create figure
88
96
# ==================================================================================
89
97
90
- fig = go . Figure ( )
98
+ fig , plot_func , label_func = _get_plot_backend ( backend )
91
99
92
100
plot_multistart = (
93
101
len (data ) == 1 and data [0 ]["is_multistart" ] and not stack_multistart
@@ -102,21 +110,20 @@ def criterion_plot(
102
110
"showlegend" : False ,
103
111
}
104
112
105
- for i , local_history in enumerate ( data [0 ]["local_histories" ]) :
113
+ for local_history in data [0 ]["local_histories" ]:
106
114
history = getattr (local_history , fun_or_monotone_fun )
107
115
108
116
if max_evaluations is not None and len (history ) > max_evaluations :
109
117
history = history [:max_evaluations ]
110
118
111
- trace = go .Scatter (
119
+ plot_func (
120
+ fig ,
112
121
x = np .arange (len (history )),
113
122
y = history ,
114
- mode = "lines" ,
115
- name = str (i ),
116
- line_color = "#bab0ac" ,
117
- ** scatter_kws ,
123
+ name = None ,
124
+ color = "#bab0ac" ,
125
+ scatter_kws = scatter_kws ,
118
126
)
119
- fig .add_trace (trace )
120
127
121
128
# ==================================================================================
122
129
# Plot main optimization objects
@@ -138,31 +145,26 @@ def criterion_plot(
138
145
}
139
146
140
147
_color = next (palette )
141
- if not isinstance (_color , str ):
142
- msg = "highlight_palette needs to be a string or list of strings, but its "
143
- f"entry is of type { type (_color )} ."
144
- raise TypeError (msg )
145
148
146
- line_kws = {
147
- "color" : _color ,
148
- }
149
-
150
- trace = go .Scatter (
149
+ plot_func (
150
+ fig ,
151
151
x = np .arange (len (history )),
152
152
y = history ,
153
- mode = "lines" ,
154
153
name = "best result" if plot_multistart else _data ["name" ],
155
- line = line_kws ,
156
- ** scatter_kws ,
154
+ color = _color ,
155
+ mode = "lines" ,
156
+ scatter_kws = scatter_kws ,
157
157
)
158
- fig .add_trace (trace )
159
158
160
- fig .update_layout (
159
+ label_func (
160
+ fig ,
161
161
template = template ,
162
- xaxis_title_text = "No. of criterion evaluations" ,
163
- yaxis_title_text = "Criterion value" ,
164
- legend = {"yanchor" : "top" , "xanchor" : "right" , "y" : 0.95 , "x" : 0.95 },
162
+ xlabel = "No. of criterion evaluations" ,
163
+ ylabel = "Criterion value" ,
164
+ plotly_legend = {"yanchor" : "top" , "xanchor" : "right" , "y" : 0.95 , "x" : 0.95 },
165
+ matplotlib_legend = "upper right" ,
165
166
)
167
+
166
168
return fig
167
169
168
170
@@ -457,3 +459,55 @@ def _get_stacked_local_histories(local_histories, direction, history=None):
457
459
task = len (stacked ["criterion" ]) * [None ],
458
460
batches = list (range (len (stacked ["criterion" ]))),
459
461
)
462
+
463
+
464
+ def _get_plot_backend (backend ):
465
+ backends = {
466
+ "plotly" : (
467
+ go .Figure (),
468
+ _plot_plotly ,
469
+ _label_plotly ,
470
+ ),
471
+ "matplotlib" : (
472
+ plt ,
473
+ _plot_matplotlib ,
474
+ _label_matplotlib ,
475
+ ),
476
+ }
477
+
478
+ if backend not in backends :
479
+ msg = (
480
+ f"Backend '{ backend } ' is not supported. "
481
+ f"Supported backends are: { ', ' .join (backends .keys ())} ."
482
+ )
483
+ raise ValueError (msg )
484
+
485
+ return backends [backend ]
486
+
487
+
488
+ def _plot_plotly (fig , * , x , y , name , color , scatter_kws , ** kwargs ):
489
+ trace = go .Scatter (
490
+ x = x , y = y , mode = "lines" , name = name , line_color = color , ** scatter_kws
491
+ )
492
+ fig .add_trace (trace )
493
+ return fig
494
+
495
+
496
+ def _label_plotly (fig , * , template , xlabel , ylabel , plotly_legend , ** kwargs ):
497
+ fig .update_layout (
498
+ template = template ,
499
+ xaxis_title_text = xlabel ,
500
+ yaxis_title_text = ylabel ,
501
+ legend = plotly_legend ,
502
+ )
503
+
504
+
505
+ def _plot_matplotlib (plt , * , x , y , name , color , ** kwargs ):
506
+ plt .plot (x , y , label = name , color = color )
507
+ return plt
508
+
509
+
510
+ def _label_matplotlib (plt , * , xlabel , ylabel , matplotlib_legend , ** kwargs ):
511
+ plt .gca ().set_xlabel (xlabel )
512
+ plt .gca ().set_ylabel (ylabel )
513
+ plt .legend (loc = matplotlib_legend )
0 commit comments