1
1
import inspect
2
- import itertools
3
2
from pathlib import Path
4
3
from typing import Any
5
4
6
- import matplotlib as mpl
7
- import matplotlib .pyplot as plt
8
5
import numpy as np
9
6
import plotly .graph_objects as go
10
7
from pybaum import leaf_names , tree_flatten , tree_just_flatten , tree_unflatten
11
8
12
- from optimagic .config import PLOT_DEFAULTS , PLOTLY_TEMPLATE
9
+ from optimagic .config import PLOTLY_TEMPLATE
13
10
from optimagic .logging .logger import LogReader , SQLiteLogOptions
14
11
from optimagic .optimization .algorithm import Algorithm
15
12
from optimagic .optimization .history import History
16
13
from optimagic .optimization .optimize_result import OptimizeResult
17
14
from optimagic .parameters .tree_registry import get_registry
18
15
from optimagic .typing import Direction
16
+ from optimagic .visualization .backends import BackendRegistry , PlotConfig
17
+ from optimagic .visualization .plotting_utilities import get_palette , get_template
19
18
20
19
21
20
def criterion_plot (
22
21
results ,
23
22
names = None ,
24
- backend = "plotly" ,
25
23
max_evaluations = None ,
24
+ backend = "plotly" ,
26
25
template = None ,
27
26
palette = None ,
28
27
stack_multistart = False ,
@@ -36,8 +35,8 @@ def criterion_plot(
36
35
dict of) optimization results with collected history. If dict, then the
37
36
key is used as the name in a legend.
38
37
names (Union[List[str], str]): Names corresponding to res or entries in res.
39
- backend (str): The backend to use for plotting. Default is "plotly".
40
38
max_evaluations (int): Clip the criterion history after that many entries.
39
+ backend (str): The backend to use for plotting. Default is "plotly".
41
40
template (str): The template for the figure. Default is "plotly_white".
42
41
palette (Union[List[str], str]): The coloring palette for traces. Default is
43
42
"qualitative.Plotly".
@@ -50,7 +49,7 @@ def criterion_plot(
50
49
optimization are visualized. Default is False.
51
50
52
51
Returns:
53
- Figure object returned by the chosen backend.
52
+ Native figure object returned by the chosen backend.
54
53
55
54
"""
56
55
# ==================================================================================
@@ -59,16 +58,8 @@ def criterion_plot(
59
58
60
59
results = _harmonize_inputs_to_dict (results , names )
61
60
62
- if template is None :
63
- template = PLOT_DEFAULTS [backend ]["template" ]
64
- if palette is None :
65
- palette = PLOT_DEFAULTS [backend ]["palette" ]
66
-
67
- if isinstance (palette , mpl .colors .Colormap ):
68
- palette = [palette (i ) for i in range (palette .N )]
69
- if not isinstance (palette , list ):
70
- palette = [palette ]
71
- palette = itertools .cycle (palette )
61
+ template = get_template (backend , template )
62
+ palette = get_palette (backend , palette )
72
63
73
64
fun_or_monotone_fun = "monotone_fun" if monotone else "fun"
74
65
@@ -98,15 +89,22 @@ def criterion_plot(
98
89
# Create figure
99
90
# ==================================================================================
100
91
101
- fig , plot_func , label_func = _get_plot_backend ( backend )
102
-
103
- plot_multistart = (
104
- len ( data ) == 1 and data [ 0 ][ "is_multistart" ] and not stack_multistart
92
+ plot_config = PlotConfig (
93
+ template = template ,
94
+ plotly_legend = { "yanchor" : "top" , "xanchor" : "right" , "y" : 0.95 , "x" : 0.95 },
95
+ matplotlib_legend = { "loc" : "upper right" },
105
96
)
106
97
98
+ _backend_wrapper = BackendRegistry .get_backend_wrapper (backend )
99
+ fig = _backend_wrapper (plot_config )
100
+
107
101
# ==================================================================================
108
102
# Plot multistart paths
109
103
104
+ plot_multistart = (
105
+ len (data ) == 1 and data [0 ]["is_multistart" ] and not stack_multistart
106
+ )
107
+
110
108
if plot_multistart :
111
109
scatter_kws = {
112
110
"connectgaps" : True ,
@@ -119,8 +117,7 @@ def criterion_plot(
119
117
if max_evaluations is not None and len (history ) > max_evaluations :
120
118
history = history [:max_evaluations ]
121
119
122
- plot_func (
123
- fig ,
120
+ fig .lineplot (
124
121
x = np .arange (len (history )),
125
122
y = history ,
126
123
name = None ,
@@ -144,30 +141,23 @@ def criterion_plot(
144
141
145
142
scatter_kws = {
146
143
"connectgaps" : True ,
147
- "showlegend" : not plot_multistart ,
144
+ "showlegend" : True ,
148
145
}
149
146
150
- _color = next (palette )
151
-
152
- plot_func (
153
- fig ,
147
+ fig .lineplot (
154
148
x = np .arange (len (history )),
155
149
y = history ,
156
150
name = "best result" if plot_multistart else _data ["name" ],
157
- color = _color ,
151
+ color = next ( palette ) ,
158
152
plotly_scatter_kws = scatter_kws ,
159
153
)
160
154
161
- label_func (
162
- fig ,
163
- template = template ,
155
+ fig .post_plot (
164
156
xlabel = "No. of criterion evaluations" ,
165
157
ylabel = "Criterion value" ,
166
- plotly_legend = {"yanchor" : "top" , "xanchor" : "right" , "y" : 0.95 , "x" : 0.95 },
167
- matplotlib_legend = {"loc" : "upper right" },
168
158
)
169
159
170
- return fig
160
+ return fig . return_obj ()
171
161
172
162
173
163
def _harmonize_inputs_to_dict (results , names ):
@@ -461,54 +451,3 @@ def _get_stacked_local_histories(local_histories, direction, history=None):
461
451
task = len (stacked ["criterion" ]) * [None ],
462
452
batches = list (range (len (stacked ["criterion" ]))),
463
453
)
464
-
465
-
466
- def _get_plot_backend (backend ):
467
- backends = {
468
- "plotly" : (
469
- go .Figure (),
470
- _plot_plotly ,
471
- _label_plotly ,
472
- ),
473
- "matplotlib" : (
474
- plt .subplots ()[1 ],
475
- _plot_matplotlib ,
476
- _label_matplotlib ,
477
- ),
478
- }
479
-
480
- if backend not in backends :
481
- msg = (
482
- f"Backend '{ backend } ' is not supported. "
483
- f"Supported backends are: { ', ' .join (backends .keys ())} ."
484
- )
485
- raise ValueError (msg )
486
-
487
- return backends [backend ]
488
-
489
-
490
- def _plot_plotly (fig , * , x , y , name , color , plotly_scatter_kws , ** kwargs ):
491
- trace = go .Scatter (
492
- x = x , y = y , mode = "lines" , name = name , line_color = color , ** plotly_scatter_kws
493
- )
494
- fig .add_trace (trace )
495
- return fig
496
-
497
-
498
- def _label_plotly (fig , * , template , xlabel , ylabel , plotly_legend , ** kwargs ):
499
- fig .update_layout (
500
- template = template ,
501
- xaxis_title_text = xlabel ,
502
- yaxis_title_text = ylabel ,
503
- legend = plotly_legend ,
504
- )
505
-
506
-
507
- def _plot_matplotlib (ax , * , x , y , name , color , ** kwargs ):
508
- ax .plot (x , y , label = name , color = color )
509
- return ax
510
-
511
-
512
- def _label_matplotlib (ax , * , xlabel , ylabel , matplotlib_legend , ** kwargs ):
513
- ax .set (xlabel = xlabel , ylabel = ylabel )
514
- ax .legend (** matplotlib_legend )
0 commit comments