3
3
from pathlib import Path
4
4
from typing import Any
5
5
6
+ import matplotlib as mpl
7
+ import matplotlib .pyplot as plt
6
8
import numpy as np
7
9
import plotly .graph_objects as go
8
10
from pybaum import leaf_names , tree_flatten , tree_just_flatten , tree_unflatten
9
11
10
- from optimagic .config import PLOTLY_PALETTE , PLOTLY_TEMPLATE
12
+ from optimagic .config import PLOT_DEFAULTS , PLOTLY_TEMPLATE
11
13
from optimagic .logging .logger import LogReader , SQLiteLogOptions
12
14
from optimagic .optimization .algorithm import Algorithm
13
15
from optimagic .optimization .history import History
19
21
def criterion_plot (
20
22
results ,
21
23
names = None ,
24
+ backend = "plotly" ,
22
25
max_evaluations = None ,
23
- template = PLOTLY_TEMPLATE ,
24
- palette = PLOTLY_PALETTE ,
26
+ template = None ,
27
+ palette = None ,
25
28
stack_multistart = False ,
26
29
monotone = False ,
27
30
show_exploration = False ,
@@ -33,6 +36,7 @@ def criterion_plot(
33
36
dict of) optimization results with collected history. If dict, then the
34
37
key is used as the name in a legend.
35
38
names (Union[List[str], str]): Names corresponding to res or entries in res.
39
+ backend (str): The backend to use for plotting. Default is "plotly".
36
40
max_evaluations (int): Clip the criterion history after that many entries.
37
41
template (str): The template for the figure. Default is "plotly_white".
38
42
palette (Union[List[str], str]): The coloring palette for traces. Default is
@@ -46,7 +50,7 @@ def criterion_plot(
46
50
optimization are visualized. Default is False.
47
51
48
52
Returns:
49
- plotly.graph_objs._figure. Figure: The figure .
53
+ Figure object returned by the chosen backend .
50
54
51
55
"""
52
56
# ==================================================================================
@@ -55,6 +59,13 @@ def criterion_plot(
55
59
56
60
results = _harmonize_inputs_to_dict (results , names )
57
61
62
+ if template is None :
63
+ template = PLOT_DEFAULTS [backend ]["template" ]
64
+ if palette is None :
65
+ palette = PLOT_DEFAULTS [backend ]["palette" ]
66
+
67
+ if isinstance (palette , mpl .colors .Colormap ):
68
+ palette = [palette (i ) for i in range (palette .N )]
58
69
if not isinstance (palette , list ):
59
70
palette = [palette ]
60
71
palette = itertools .cycle (palette )
@@ -87,7 +98,7 @@ def criterion_plot(
87
98
# Create figure
88
99
# ==================================================================================
89
100
90
- fig = go . Figure ( )
101
+ fig , plot_func , label_func = _get_plot_backend ( backend )
91
102
92
103
plot_multistart = (
93
104
len (data ) == 1 and data [0 ]["is_multistart" ] and not stack_multistart
@@ -102,21 +113,20 @@ def criterion_plot(
102
113
"showlegend" : False ,
103
114
}
104
115
105
- for i , local_history in enumerate ( data [0 ]["local_histories" ]) :
116
+ for local_history in data [0 ]["local_histories" ]:
106
117
history = getattr (local_history , fun_or_monotone_fun )
107
118
108
119
if max_evaluations is not None and len (history ) > max_evaluations :
109
120
history = history [:max_evaluations ]
110
121
111
- trace = go .Scatter (
122
+ plot_func (
123
+ fig ,
112
124
x = np .arange (len (history )),
113
125
y = history ,
114
- mode = "lines" ,
115
- name = str (i ),
116
- line_color = "#bab0ac" ,
117
- ** scatter_kws ,
126
+ name = None ,
127
+ color = "#bab0ac" ,
128
+ plotly_scatter_kws = scatter_kws ,
118
129
)
119
- fig .add_trace (trace )
120
130
121
131
# ==================================================================================
122
132
# Plot main optimization objects
@@ -138,31 +148,25 @@ def criterion_plot(
138
148
}
139
149
140
150
_color = next (palette )
141
- if not isinstance (_color , str ):
142
- msg = "highlight_palette needs to be a string or list of strings, but its "
143
- f"entry is of type { type (_color )} ."
144
- raise TypeError (msg )
145
151
146
- line_kws = {
147
- "color" : _color ,
148
- }
149
-
150
- trace = go .Scatter (
152
+ plot_func (
153
+ fig ,
151
154
x = np .arange (len (history )),
152
155
y = history ,
153
- mode = "lines" ,
154
156
name = "best result" if plot_multistart else _data ["name" ],
155
- line = line_kws ,
156
- ** scatter_kws ,
157
+ color = _color ,
158
+ plotly_scatter_kws = scatter_kws ,
157
159
)
158
- fig .add_trace (trace )
159
160
160
- fig .update_layout (
161
+ label_func (
162
+ fig ,
161
163
template = template ,
162
- xaxis_title_text = "No. of criterion evaluations" ,
163
- yaxis_title_text = "Criterion value" ,
164
- legend = {"yanchor" : "top" , "xanchor" : "right" , "y" : 0.95 , "x" : 0.95 },
164
+ xlabel = "No. of criterion evaluations" ,
165
+ ylabel = "Criterion value" ,
166
+ plotly_legend = {"yanchor" : "top" , "xanchor" : "right" , "y" : 0.95 , "x" : 0.95 },
167
+ matplotlib_legend = {"loc" : "upper right" },
165
168
)
169
+
166
170
return fig
167
171
168
172
@@ -457,3 +461,54 @@ def _get_stacked_local_histories(local_histories, direction, history=None):
457
461
task = len (stacked ["criterion" ]) * [None ],
458
462
batches = list (range (len (stacked ["criterion" ]))),
459
463
)
464
+
465
+
466
+ def _get_plot_backend (backend ):
467
+ backends = {
468
+ "plotly" : (
469
+ go .Figure (),
470
+ _plot_plotly ,
471
+ _label_plotly ,
472
+ ),
473
+ "matplotlib" : (
474
+ plt .subplots ()[1 ],
475
+ _plot_matplotlib ,
476
+ _label_matplotlib ,
477
+ ),
478
+ }
479
+
480
+ if backend not in backends :
481
+ msg = (
482
+ f"Backend '{ backend } ' is not supported. "
483
+ f"Supported backends are: { ', ' .join (backends .keys ())} ."
484
+ )
485
+ raise ValueError (msg )
486
+
487
+ return backends [backend ]
488
+
489
+
490
+ def _plot_plotly (fig , * , x , y , name , color , plotly_scatter_kws , ** kwargs ):
491
+ trace = go .Scatter (
492
+ x = x , y = y , mode = "lines" , name = name , line_color = color , ** plotly_scatter_kws
493
+ )
494
+ fig .add_trace (trace )
495
+ return fig
496
+
497
+
498
+ def _label_plotly (fig , * , template , xlabel , ylabel , plotly_legend , ** kwargs ):
499
+ fig .update_layout (
500
+ template = template ,
501
+ xaxis_title_text = xlabel ,
502
+ yaxis_title_text = ylabel ,
503
+ legend = plotly_legend ,
504
+ )
505
+
506
+
507
+ def _plot_matplotlib (ax , * , x , y , name , color , ** kwargs ):
508
+ ax .plot (x , y , label = name , color = color )
509
+ return ax
510
+
511
+
512
+ def _label_matplotlib (ax , * , xlabel , ylabel , matplotlib_legend , ** kwargs ):
513
+ ax .set (xlabel = xlabel , ylabel = ylabel )
514
+ ax .legend (** matplotlib_legend )
0 commit comments