1
1
import abc
2
2
from typing import Any
3
3
4
- import matplotlib as mpl
5
- import matplotlib .pyplot as plt
6
4
import plotly .express as px
7
5
import plotly .graph_objects as go
8
6
7
+ from optimagic .config import IS_MATPLOTLIB_INSTALLED
8
+ from optimagic .exceptions import NotInstalledError
9
9
from optimagic .visualization .plotting_utilities import LineData
10
10
11
+ if IS_MATPLOTLIB_INSTALLED :
12
+ import matplotlib as mpl
13
+ import matplotlib .pyplot as plt
14
+
11
15
12
16
class PlotBackend (abc .ABC ):
13
17
default_template : str
@@ -30,7 +34,7 @@ def set_labels(self, xlabel: str | None = None, ylabel: str | None = None) -> No
30
34
pass
31
35
32
36
@abc .abstractmethod
33
- def set_legend_props (self , legend_props : dict [str , Any ]) -> None :
37
+ def set_legend_properties (self , legend_properties : dict [str , Any ]) -> None :
34
38
pass
35
39
36
40
@@ -60,39 +64,45 @@ def add_lines(self, lines: list[LineData]) -> None:
60
64
def set_labels (self , xlabel : str | None = None , ylabel : str | None = None ) -> None :
61
65
self ._fig .update_layout (xaxis_title_text = xlabel , yaxis_title_text = ylabel )
62
66
63
- def set_legend_props (self , legend_props : dict [str , Any ]) -> None :
64
- self ._fig .update_layout (legend = legend_props )
67
+ def set_legend_properties (self , legend_properties : dict [str , Any ]) -> None :
68
+ self ._fig .update_layout (legend = legend_properties )
65
69
66
70
67
- class MatplotlibBackend (PlotBackend ):
68
- default_template : str = "default"
69
- default_palette : list = list (mpl .colormaps ["Set2" ].colors )
71
+ if IS_MATPLOTLIB_INSTALLED :
70
72
71
- def __init__ ( self , template : str | None ):
72
- super (). __init__ ( template )
73
- plt . style . use ( self . template )
74
- self . _fig , self . _ax = plt . subplots ( )
75
- self . figure = self . _fig
73
+ class MatplotlibBackend ( PlotBackend ):
74
+ default_template : str = "default"
75
+ default_palette : list = [
76
+ mpl . colormaps [ "Set2" ]( i ) for i in range ( mpl . colormaps [ "Set2" ]. N )
77
+ ]
76
78
77
- def add_lines (self , lines : list [LineData ]) -> None :
78
- for line in lines :
79
- self ._ax .plot (
80
- line .x ,
81
- line .y ,
82
- color = line .color ,
83
- label = line .name if line .show_in_legend else None ,
84
- )
79
+ def __init__ (self , template : str | None ):
80
+ super ().__init__ (template )
81
+ plt .style .use (self .template )
82
+ self ._fig , self ._ax = plt .subplots ()
83
+ self .figure = self ._fig
85
84
86
- def set_labels (self , xlabel : str | None = None , ylabel : str | None = None ) -> None :
87
- self ._ax .set (xlabel = xlabel , ylabel = ylabel )
85
+ def add_lines (self , lines : list [LineData ]) -> None :
86
+ for line in lines :
87
+ self ._ax .plot (
88
+ line .x ,
89
+ line .y ,
90
+ color = line .color ,
91
+ label = line .name if line .show_in_legend else None ,
92
+ )
88
93
89
- def set_legend_props (self , legend_props : dict [str , Any ]) -> None :
90
- self ._ax .legend (** legend_props )
94
+ def set_labels (
95
+ self , xlabel : str | None = None , ylabel : str | None = None
96
+ ) -> None :
97
+ self ._ax .set (xlabel = xlabel , ylabel = ylabel )
98
+
99
+ def set_legend_properties (self , legend_properties : dict [str , Any ]) -> None :
100
+ self ._ax .legend (** legend_properties )
91
101
92
102
93
103
PLOT_BACKEND_CLASSES = {
94
104
"plotly" : PlotlyBackend ,
95
- "matplotlib" : MatplotlibBackend ,
105
+ "matplotlib" : MatplotlibBackend if IS_MATPLOTLIB_INSTALLED else None ,
96
106
}
97
107
98
108
@@ -104,4 +114,18 @@ def get_plot_backend_class(backend_name: str) -> type[PlotBackend]:
104
114
)
105
115
raise ValueError (msg )
106
116
107
- return PLOT_BACKEND_CLASSES [backend_name ]
117
+ return _get_backend_if_installed (backend_name )
118
+
119
+
120
+ def _get_backend_if_installed (backend_name : str ) -> type [PlotBackend ]:
121
+ plot_cls = PLOT_BACKEND_CLASSES [backend_name ]
122
+
123
+ if plot_cls is None :
124
+ msg = (
125
+ f"The '{ backend_name } ' backend is not installed. "
126
+ f"Install the package using either 'pip install { backend_name } ' or "
127
+ f"'conda install -c conda-forge { backend_name } '"
128
+ )
129
+ raise NotInstalledError (msg )
130
+
131
+ return plot_cls
0 commit comments