Skip to content

Commit e167b49

Browse files
committed
Make matplotlib an optional dependency and minor refactor for clarity.
1 parent 0159fe6 commit e167b49

File tree

11 files changed

+70
-41
lines changed

11 files changed

+70
-41
lines changed

.tools/envs/testenv-linux.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ dependencies:
1919
- numpy >= 2 # run, tests
2020
- pandas # run, tests
2121
- plotly>=6.2 # run, tests
22-
- matplotlib # run, tests
22+
- matplotlib # tests
2323
- pybaum>=0.1.2 # run, tests
2424
- scipy>=1.2.1 # run, tests
2525
- sqlalchemy # run, tests

.tools/envs/testenv-nevergrad.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ dependencies:
1717
- numpy >= 2 # run, tests
1818
- pandas # run, tests
1919
- plotly>=6.2 # run, tests
20-
- matplotlib # run, tests
20+
- matplotlib # tests
2121
- pybaum>=0.1.2 # run, tests
2222
- scipy>=1.2.1 # run, tests
2323
- sqlalchemy # run, tests

.tools/envs/testenv-numpy.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ dependencies:
1717
- cloudpickle # run, tests
1818
- joblib # run, tests
1919
- plotly>=6.2 # run, tests
20-
- matplotlib # run, tests
20+
- matplotlib # tests
2121
- pybaum>=0.1.2 # run, tests
2222
- scipy>=1.2.1 # run, tests
2323
- sqlalchemy # run, tests

.tools/envs/testenv-others.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ dependencies:
1717
- numpy >= 2 # run, tests
1818
- pandas # run, tests
1919
- plotly>=6.2 # run, tests
20-
- matplotlib # run, tests
20+
- matplotlib # tests
2121
- pybaum>=0.1.2 # run, tests
2222
- scipy>=1.2.1 # run, tests
2323
- sqlalchemy # run, tests

.tools/envs/testenv-pandas.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ dependencies:
1717
- cloudpickle # run, tests
1818
- joblib # run, tests
1919
- plotly>=6.2 # run, tests
20-
- matplotlib # run, tests
20+
- matplotlib # tests
2121
- pybaum>=0.1.2 # run, tests
2222
- scipy>=1.2.1 # run, tests
2323
- sqlalchemy # run, tests

.tools/envs/testenv-plotly.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ dependencies:
1717
- joblib # run, tests
1818
- numpy >= 2 # run, tests
1919
- pandas # run, tests
20-
- matplotlib # run, tests
20+
- matplotlib # tests
2121
- pybaum>=0.1.2 # run, tests
2222
- scipy>=1.2.1 # run, tests
2323
- sqlalchemy # run, tests

environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ dependencies:
2121
- numpy >= 2 # run, tests
2222
- pandas # run, tests
2323
- plotly>=6.2 # run, tests
24-
- matplotlib # run, tests
24+
- matplotlib # tests
2525
- pybaum>=0.1.2 # run, tests
2626
- scipy>=1.2.1 # run, tests
2727
- sqlalchemy # run, tests

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ dependencies = [
1111
"numpy",
1212
"pandas",
1313
"plotly",
14-
"matplotlib",
1514
"pybaum>=0.1.2",
1615
"scipy>=1.2.1",
1716
"sqlalchemy>=1.3",

src/optimagic/config.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def _is_installed(module_name: str) -> bool:
2323

2424

2525
# ======================================================================================
26-
# Check Available Packages
26+
# Check Available Optimization Packages
2727
# ======================================================================================
2828

2929
IS_PETSC4PY_INSTALLED = _is_installed("petsc4py")
@@ -40,6 +40,12 @@ def _is_installed(module_name: str) -> bool:
4040
IS_NEVERGRAD_INSTALLED = _is_installed("nevergrad")
4141
IS_BAYESOPT_INSTALLED = _is_installed("bayes_opt")
4242

43+
# ======================================================================================
44+
# Check Available Visualization Packages
45+
# ======================================================================================
46+
47+
IS_MATPLOTLIB_INSTALLED = _is_installed("matplotlib")
48+
4349

4450
# ======================================================================================
4551
# Check if pandas version is newer or equal to version 2.1.0

src/optimagic/visualization/backends.py

Lines changed: 51 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
import abc
22
from typing import Any
33

4-
import matplotlib as mpl
5-
import matplotlib.pyplot as plt
64
import plotly.express as px
75
import plotly.graph_objects as go
86

7+
from optimagic.config import IS_MATPLOTLIB_INSTALLED
8+
from optimagic.exceptions import NotInstalledError
99
from optimagic.visualization.plotting_utilities import LineData
1010

11+
if IS_MATPLOTLIB_INSTALLED:
12+
import matplotlib as mpl
13+
import matplotlib.pyplot as plt
14+
1115

1216
class PlotBackend(abc.ABC):
1317
default_template: str
@@ -30,7 +34,7 @@ def set_labels(self, xlabel: str | None = None, ylabel: str | None = None) -> No
3034
pass
3135

3236
@abc.abstractmethod
33-
def set_legend_props(self, legend_props: dict[str, Any]) -> None:
37+
def set_legend_properties(self, legend_properties: dict[str, Any]) -> None:
3438
pass
3539

3640

@@ -60,39 +64,45 @@ def add_lines(self, lines: list[LineData]) -> None:
6064
def set_labels(self, xlabel: str | None = None, ylabel: str | None = None) -> None:
6165
self._fig.update_layout(xaxis_title_text=xlabel, yaxis_title_text=ylabel)
6266

63-
def set_legend_props(self, legend_props: dict[str, Any]) -> None:
64-
self._fig.update_layout(legend=legend_props)
67+
def set_legend_properties(self, legend_properties: dict[str, Any]) -> None:
68+
self._fig.update_layout(legend=legend_properties)
6569

6670

67-
class MatplotlibBackend(PlotBackend):
68-
default_template: str = "default"
69-
default_palette: list = list(mpl.colormaps["Set2"].colors)
71+
if IS_MATPLOTLIB_INSTALLED:
7072

71-
def __init__(self, template: str | None):
72-
super().__init__(template)
73-
plt.style.use(self.template)
74-
self._fig, self._ax = plt.subplots()
75-
self.figure = self._fig
73+
class MatplotlibBackend(PlotBackend):
74+
default_template: str = "default"
75+
default_palette: list = [
76+
mpl.colormaps["Set2"](i) for i in range(mpl.colormaps["Set2"].N)
77+
]
7678

77-
def add_lines(self, lines: list[LineData]) -> None:
78-
for line in lines:
79-
self._ax.plot(
80-
line.x,
81-
line.y,
82-
color=line.color,
83-
label=line.name if line.show_in_legend else None,
84-
)
79+
def __init__(self, template: str | None):
80+
super().__init__(template)
81+
plt.style.use(self.template)
82+
self._fig, self._ax = plt.subplots()
83+
self.figure = self._fig
8584

86-
def set_labels(self, xlabel: str | None = None, ylabel: str | None = None) -> None:
87-
self._ax.set(xlabel=xlabel, ylabel=ylabel)
85+
def add_lines(self, lines: list[LineData]) -> None:
86+
for line in lines:
87+
self._ax.plot(
88+
line.x,
89+
line.y,
90+
color=line.color,
91+
label=line.name if line.show_in_legend else None,
92+
)
8893

89-
def set_legend_props(self, legend_props: dict[str, Any]) -> None:
90-
self._ax.legend(**legend_props)
94+
def set_labels(
95+
self, xlabel: str | None = None, ylabel: str | None = None
96+
) -> None:
97+
self._ax.set(xlabel=xlabel, ylabel=ylabel)
98+
99+
def set_legend_properties(self, legend_properties: dict[str, Any]) -> None:
100+
self._ax.legend(**legend_properties)
91101

92102

93103
PLOT_BACKEND_CLASSES = {
94104
"plotly": PlotlyBackend,
95-
"matplotlib": MatplotlibBackend,
105+
"matplotlib": MatplotlibBackend if IS_MATPLOTLIB_INSTALLED else None,
96106
}
97107

98108

@@ -104,4 +114,18 @@ def get_plot_backend_class(backend_name: str) -> type[PlotBackend]:
104114
)
105115
raise ValueError(msg)
106116

107-
return PLOT_BACKEND_CLASSES[backend_name]
117+
return _get_backend_if_installed(backend_name)
118+
119+
120+
def _get_backend_if_installed(backend_name: str) -> type[PlotBackend]:
121+
plot_cls = PLOT_BACKEND_CLASSES[backend_name]
122+
123+
if plot_cls is None:
124+
msg = (
125+
f"The '{backend_name}' backend is not installed. "
126+
f"Install the package using either 'pip install {backend_name}' or "
127+
f"'conda install -c conda-forge {backend_name}'"
128+
)
129+
raise NotInstalledError(msg)
130+
131+
return plot_cls

0 commit comments

Comments
 (0)