|
1 |
| -import abc |
2 |
| -from typing import Any |
| 1 | +from typing import Any, Literal |
3 | 2 |
|
4 |
| -import plotly.express as px |
5 | 3 | import plotly.graph_objects as go
|
6 | 4 |
|
7 | 5 | from optimagic.config import IS_MATPLOTLIB_INSTALLED
|
|
20 | 18 | plt.ioff()
|
21 | 19 |
|
22 | 20 |
|
23 |
| -class PlotBackend(abc.ABC): |
24 |
| - is_available: bool |
25 |
| - default_template: str |
26 |
| - |
27 |
| - @classmethod |
28 |
| - @abc.abstractmethod |
29 |
| - def get_default_palette(cls) -> list: |
30 |
| - pass |
31 |
| - |
32 |
| - @abc.abstractmethod |
33 |
| - def __init__(self, template: str | None): |
34 |
| - if template is None: |
35 |
| - template = self.default_template |
36 |
| - |
37 |
| - self.template = template |
38 |
| - self.figure: Any = None |
39 |
| - |
40 |
| - @abc.abstractmethod |
41 |
| - def add_lines(self, lines: list[LineData]) -> None: |
42 |
| - pass |
43 |
| - |
44 |
| - @abc.abstractmethod |
45 |
| - def set_labels(self, xlabel: str | None = None, ylabel: str | None = None) -> None: |
46 |
| - pass |
47 |
| - |
48 |
| - @abc.abstractmethod |
49 |
| - def set_legend_properties(self, legend_properties: dict[str, Any]) -> None: |
50 |
| - pass |
51 |
| - |
52 |
| - |
53 |
| -class PlotlyBackend(PlotBackend): |
54 |
| - is_available: bool = True |
55 |
| - default_template: str = "simple_white" |
56 |
| - |
57 |
| - @classmethod |
58 |
| - def get_default_palette(cls) -> list: |
59 |
| - return px.colors.qualitative.Set2 |
60 |
| - |
61 |
| - def __init__(self, template: str | None): |
62 |
| - super().__init__(template) |
63 |
| - self._fig = go.Figure() |
64 |
| - self._fig.update_layout(template=self.template) |
65 |
| - self.figure = self._fig |
66 |
| - |
67 |
| - def add_lines(self, lines: list[LineData]) -> None: |
68 |
| - for line in lines: |
69 |
| - trace = go.Scatter( |
70 |
| - x=line.x, |
71 |
| - y=line.y, |
72 |
| - name=line.name, |
73 |
| - mode="lines", |
74 |
| - line_color=line.color, |
75 |
| - showlegend=line.show_in_legend, |
76 |
| - connectgaps=True, |
77 |
| - ) |
78 |
| - self._fig.add_trace(trace) |
79 |
| - |
80 |
| - def set_labels(self, xlabel: str | None = None, ylabel: str | None = None) -> None: |
81 |
| - self._fig.update_layout(xaxis_title_text=xlabel, yaxis_title_text=ylabel) |
82 |
| - |
83 |
| - def set_legend_properties(self, legend_properties: dict[str, Any]) -> None: |
84 |
| - self._fig.update_layout(legend=legend_properties) |
85 |
| - |
86 |
| - |
87 |
| -class MatplotlibBackend(PlotBackend): |
88 |
| - is_available: bool = IS_MATPLOTLIB_INSTALLED |
89 |
| - default_template: str = "default" |
90 |
| - |
91 |
| - @classmethod |
92 |
| - def get_default_palette(cls) -> list: |
93 |
| - return [mpl.colormaps["Set2"](i) for i in range(mpl.colormaps["Set2"].N)] |
94 |
| - |
95 |
| - def __init__(self, template: str | None): |
96 |
| - super().__init__(template) |
97 |
| - plt.style.use(self.template) |
98 |
| - self._fig, self._ax = plt.subplots() |
99 |
| - self.figure = self._fig |
100 |
| - |
101 |
| - def add_lines(self, lines: list[LineData]) -> None: |
102 |
| - for line in lines: |
103 |
| - self._ax.plot( |
104 |
| - line.x, |
105 |
| - line.y, |
106 |
| - color=line.color, |
107 |
| - label=line.name if line.show_in_legend else None, |
108 |
| - ) |
| 21 | +def _line_plot_plotly( |
| 22 | + lines: list[LineData], |
| 23 | + *, |
| 24 | + title: str | None, |
| 25 | + xlabel: str | None, |
| 26 | + ylabel: str | None, |
| 27 | + template: str | None, |
| 28 | + height: int | None, |
| 29 | + width: int | None, |
| 30 | + legend_properties: dict[str, Any] | None, |
| 31 | +) -> go.Figure: |
| 32 | + fig = go.Figure() |
| 33 | + |
| 34 | + for line in lines: |
| 35 | + trace = go.Scatter( |
| 36 | + x=line.x, |
| 37 | + y=line.y, |
| 38 | + name=line.name, |
| 39 | + line_color=line.color, |
| 40 | + mode="lines", |
| 41 | + ) |
| 42 | + fig.add_trace(trace) |
| 43 | + |
| 44 | + fig.update_layout( |
| 45 | + title=title, |
| 46 | + xaxis_title=xlabel, |
| 47 | + yaxis_title=ylabel, |
| 48 | + template=template, |
| 49 | + height=height, |
| 50 | + width=width, |
| 51 | + ) |
| 52 | + |
| 53 | + if legend_properties: |
| 54 | + fig.update_layout(legend=legend_properties) |
| 55 | + |
| 56 | + return fig |
| 57 | + |
| 58 | + |
| 59 | +def _line_plot_matplotlib( |
| 60 | + lines: list[LineData], |
| 61 | + *, |
| 62 | + title: str | None, |
| 63 | + xlabel: str | None, |
| 64 | + ylabel: str | None, |
| 65 | + template: str | None, |
| 66 | + height: int | None, |
| 67 | + width: int | None, |
| 68 | + legend_properties: dict[str, Any] | None, |
| 69 | +) -> "plt.Figure": |
| 70 | + if template is not None: |
| 71 | + plt.style.use(template) |
| 72 | + fig, ax = plt.subplots(figsize=(width, height) if width and height else None) |
| 73 | + |
| 74 | + for line in lines: |
| 75 | + ax.plot( |
| 76 | + line.x, |
| 77 | + line.y, |
| 78 | + label=line.name if line.show_in_legend else None, |
| 79 | + color=line.color, |
| 80 | + ) |
109 | 81 |
|
110 |
| - def set_labels(self, xlabel: str | None = None, ylabel: str | None = None) -> None: |
111 |
| - self._ax.set(xlabel=xlabel, ylabel=ylabel) |
| 82 | + ax.set(title=title, xlabel=xlabel, ylabel=ylabel) |
| 83 | + if legend_properties: |
| 84 | + ax.legend(**legend_properties) |
112 | 85 |
|
113 |
| - def set_legend_properties(self, legend_properties: dict[str, Any]) -> None: |
114 |
| - self._ax.legend(**legend_properties) |
| 86 | + return fig |
115 | 87 |
|
116 | 88 |
|
117 |
| -PLOT_BACKEND_CLASSES = { |
118 |
| - "plotly": PlotlyBackend, |
119 |
| - "matplotlib": MatplotlibBackend, |
| 89 | +BACKEND_AVAILABILITY_AND_LINE_PLOT_FUNCTION = { |
| 90 | + "plotly": (True, _line_plot_plotly), |
| 91 | + "matplotlib": (IS_MATPLOTLIB_INSTALLED, _line_plot_matplotlib), |
120 | 92 | }
|
121 | 93 |
|
122 | 94 |
|
123 |
| -def get_plot_backend_class(backend_name: str) -> type[PlotBackend]: |
124 |
| - if backend_name not in PLOT_BACKEND_CLASSES: |
| 95 | +def line_plot( |
| 96 | + lines: list[LineData], |
| 97 | + backend: Literal["plotly", "matplotlib"] = "plotly", |
| 98 | + *, |
| 99 | + title: str | None = None, |
| 100 | + xlabel: str | None = None, |
| 101 | + ylabel: str | None = None, |
| 102 | + template: str | None = None, |
| 103 | + height: int | None = None, |
| 104 | + width: int | None = None, |
| 105 | + legend_properties: dict[str, Any] | None = None, |
| 106 | +) -> Any: |
| 107 | + """Create a line plot corresponding to the specified backend. |
| 108 | +
|
| 109 | + Args: |
| 110 | + lines: List of objects each containing data for a line in the plot. |
| 111 | + backend: The backend to use for plotting. |
| 112 | + title: Title of the plot. |
| 113 | + xlabel: Label for the x-axis. |
| 114 | + ylabel: Label for the y-axis. |
| 115 | + template: Backend-specific template for styling the plot. |
| 116 | + height: Height of the plot (in pixels). |
| 117 | + width: Width of the plot (in pixels). |
| 118 | + legend_properties: Backend-specific properties for the legend. |
| 119 | +
|
| 120 | + Returns: |
| 121 | + A figure object corresponding to the specified backend. |
| 122 | +
|
| 123 | + """ |
| 124 | + if backend not in BACKEND_AVAILABILITY_AND_LINE_PLOT_FUNCTION: |
125 | 125 | msg = (
|
126 |
| - f"Invalid backend name '{backend_name}'. " |
127 |
| - f"Supported backends are: {', '.join(PLOT_BACKEND_CLASSES.keys())}." |
| 126 | + f"Invalid plotting backend '{backend}'. " |
| 127 | + f"Available backends: " |
| 128 | + f"{', '.join(BACKEND_AVAILABILITY_AND_LINE_PLOT_FUNCTION.keys())}" |
128 | 129 | )
|
129 | 130 | raise InvalidPlottingBackendError(msg)
|
130 | 131 |
|
131 |
| - return _get_backend_if_installed(backend_name) |
132 |
| - |
133 |
| - |
134 |
| -def _get_backend_if_installed(backend_name: str) -> type[PlotBackend]: |
135 |
| - plot_cls = PLOT_BACKEND_CLASSES[backend_name] |
| 132 | + _is_backend_available, _line_plot_backend_function = ( |
| 133 | + BACKEND_AVAILABILITY_AND_LINE_PLOT_FUNCTION[backend] |
| 134 | + ) |
136 | 135 |
|
137 |
| - if not plot_cls.is_available: |
| 136 | + if not _is_backend_available: |
138 | 137 | msg = (
|
139 |
| - f"The '{backend_name}' backend is not installed. " |
140 |
| - f"Install the package using either 'pip install {backend_name}' or " |
141 |
| - f"'conda install -c conda-forge {backend_name}'" |
| 138 | + f"The {backend} backend is not installed. " |
| 139 | + f"Install the package using either 'pip install {backend}' or " |
| 140 | + f"'conda install -c conda-forge {backend}'" |
142 | 141 | )
|
143 | 142 | raise NotInstalledError(msg)
|
144 | 143 |
|
145 |
| - return plot_cls |
| 144 | + fig = _line_plot_backend_function( |
| 145 | + lines, |
| 146 | + title=title, |
| 147 | + xlabel=xlabel, |
| 148 | + ylabel=ylabel, |
| 149 | + template=template, |
| 150 | + height=height, |
| 151 | + width=width, |
| 152 | + legend_properties=legend_properties, |
| 153 | + ) |
| 154 | + |
| 155 | + return fig |
0 commit comments