
Commit 84be118

Polish the algo selection tool (#550)
1 parent 058c79e commit 84be118

File tree: 14 files changed, +198 -72 lines changed


.tools/create_algo_selection_code.py

Lines changed: 25 additions & 20 deletions
@@ -109,7 +109,7 @@ def _get_algorithms_in_module(module: ModuleType) -> dict[str, Type[Algorithm]]:
     }
     algos = {}
     for candidate in candidate_dict.values():
-        name = candidate.__algo_info__.name
+        name = candidate.algo_info.name
         if issubclass(candidate, Algorithm) and candidate is not Algorithm:
             algos[name] = candidate
     return algos
@@ -119,47 +119,47 @@ def _get_algorithms_in_module(module: ModuleType) -> dict[str, Type[Algorithm]]:
 # Functions to filter algorithms by selectors
 # ======================================================================================
 def _is_gradient_based(algo: Type[Algorithm]) -> bool:
-    return algo.__algo_info__.needs_jac  # type: ignore
+    return algo.algo_info.needs_jac  # type: ignore
 
 
 def _is_gradient_free(algo: Type[Algorithm]) -> bool:
     return not _is_gradient_based(algo)
 
 
 def _is_global(algo: Type[Algorithm]) -> bool:
-    return algo.__algo_info__.is_global  # type: ignore
+    return algo.algo_info.is_global  # type: ignore
 
 
 def _is_local(algo: Type[Algorithm]) -> bool:
     return not _is_global(algo)
 
 
 def _is_bounded(algo: Type[Algorithm]) -> bool:
-    return algo.__algo_info__.supports_bounds  # type: ignore
+    return algo.algo_info.supports_bounds  # type: ignore
 
 
 def _is_linear_constrained(algo: Type[Algorithm]) -> bool:
-    return algo.__algo_info__.supports_linear_constraints  # type: ignore
+    return algo.algo_info.supports_linear_constraints  # type: ignore
 
 
 def _is_nonlinear_constrained(algo: Type[Algorithm]) -> bool:
-    return algo.__algo_info__.supports_nonlinear_constraints  # type: ignore
+    return algo.algo_info.supports_nonlinear_constraints  # type: ignore
 
 
 def _is_scalar(algo: Type[Algorithm]) -> bool:
-    return algo.__algo_info__.solver_type == AggregationLevel.SCALAR  # type: ignore
+    return algo.algo_info.solver_type == AggregationLevel.SCALAR  # type: ignore
 
 
 def _is_least_squares(algo: Type[Algorithm]) -> bool:
-    return algo.__algo_info__.solver_type == AggregationLevel.LEAST_SQUARES  # type: ignore
+    return algo.algo_info.solver_type == AggregationLevel.LEAST_SQUARES  # type: ignore
 
 
 def _is_likelihood(algo: Type[Algorithm]) -> bool:
-    return algo.__algo_info__.solver_type == AggregationLevel.LIKELIHOOD  # type: ignore
+    return algo.algo_info.solver_type == AggregationLevel.LIKELIHOOD  # type: ignore
 
 
 def _is_parallel(algo: Type[Algorithm]) -> bool:
-    return algo.__algo_info__.supports_parallelism  # type: ignore
+    return algo.algo_info.supports_parallelism  # type: ignore
 
 
 def _get_filters() -> dict[str, Callable[[Type[Algorithm]], bool]]:
@@ -385,27 +385,32 @@ def _all(self) -> list[Type[Algorithm]]:
     def _available(self) -> list[Type[Algorithm]]:
         _all = self._all()
         return [
-            a for a in _all if a.__algo_info__.is_available  # type: ignore
+            a for a in _all if a.algo_info.is_available  # type: ignore
         ]
 
     @property
-    def All(self) -> list[str]:
-        return [a.__algo_info__.name for a in self._all()]  # type: ignore
+    def All(self) -> list[Type[Algorithm]]:
+        return self._all()
 
     @property
-    def Available(self) -> list[str]:
-        return [a.__algo_info__.name for a in self._available()]  # type: ignore
+    def Available(self) -> list[Type[Algorithm]]:
+        return self._available()
+
+    @property
+    def AllNames(self) -> list[str]:
+        return [str(a.name) for a in self._all()]
+
+    @property
+    def AvailableNames(self) -> list[str]:
+        return [str(a.name) for a in self._available()]
 
     @property
     def _all_algorithms_dict(self) -> dict[str, Type[Algorithm]]:
-        return {a.__algo_info__.name: a for a in self._all()}  # type: ignore
+        return {str(a.name): a for a in self._all()}
 
     @property
     def _available_algorithms_dict(self) -> dict[str, Type[Algorithm]]:
-        return {
-            a.__algo_info__.name: a  # type: ignore
-            for a in self._available()
-        }
+        return {str(a.name): a for a in self._available()}
 
 """)
     return out
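The renamed filter helpers all read flags off the class-level `algo_info`, and the selection tree generated by this script boils down to combining such predicates over the collected algorithm classes. A minimal, self-contained sketch of that idea; the `Dummy*` classes and the stripped-down `AlgoInfo` below are hypothetical stand-ins, not optimagic's real objects:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AlgoInfo:
    # hypothetical, reduced stand-in for optimagic's real AlgoInfo
    name: str
    needs_jac: bool
    is_global: bool


class DummyLocalGradientBased:
    __algo_info__ = AlgoInfo(name="dummy_local_gd", needs_jac=True, is_global=False)


class DummyGlobalGradientFree:
    __algo_info__ = AlgoInfo(name="dummy_global_df", needs_jac=False, is_global=True)


def _is_gradient_based(algo) -> bool:
    return algo.__algo_info__.needs_jac


def _is_global(algo) -> bool:
    return algo.__algo_info__.is_global


candidates = [DummyLocalGradientBased, DummyGlobalGradientFree]

# combine predicates to select gradient-free global optimizers
selected = [a for a in candidates if not _is_gradient_based(a) and _is_global(a)]
print([a.__algo_info__.name for a in selected])  # ['dummy_global_df']
```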

CHANGES.md

Lines changed: 20 additions & 2 deletions
@@ -4,9 +4,27 @@ This is a record of all past optimagic releases and what went into them in rever
 chronological order. We follow [semantic versioning](https://semver.org/) and all
 releases are available on [Anaconda.org](https://anaconda.org/optimagic-dev/optimagic).
 
-Following the [scientific python guidelines](https://scientific-python.org/specs/spec-0000/)
-we drop the official support for Python 3.9.
 
+## 0.5.1
+
+This is a minor release that introduces the new algorithm selection tool and several
+small improvements.
+
+To learn more about the algorithm selection feature check out the following resources:
+
+- [How to specify and configure algorithms](https://optimagic.readthedocs.io/en/latest/how_to/how_to_specify_algorithm_and_algo_options.html)
+- [How to select local optimizers](https://optimagic.readthedocs.io/en/latest/how_to/how_to_algorithm_selection.html)
+
+- {gh}`549` Add support for Python 3.13 ({ghuser}`timmens`)
+- {gh}`550` and {gh}`534` implement the new algorithm selection tool ({ghuser}`janosg`)
+- {gh}`548` and {gh}`531` improve the documentation ({ghuser}`ChristianZimpelmann`)
+- {gh}`544` Adjusts the results processing of the nag optimizers to be compatible
+  with the latest releases ({ghuser}`timmens`)
+- {gh}`543` Adds support for numpy 2.x ({ghuser}`timmens`)
+- {gh}`536` Adds a how-to guide for choosing local optimizers ({ghuser}`mpetrosian`)
+- {gh}`535` Allows algorithm classes and instances in estimation functions
+  ({ghuser}`timmens`)
+- {gh}`532` Makes several small improvements to the documentation.
 
 ## 0.5.0
 
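For context, the two how-to guides linked in the changelog describe the user-facing side of the new tool: algorithm classes are discovered by chaining category selectors on `om.algos`. A hedged usage sketch; the selector chain and the `scipy_lbfgsb` name follow the linked documentation, so treat the exact attributes as illustrative rather than exhaustive:

```python
import numpy as np
import optimagic as om

# names of all algorithms installed in the current environment
print(om.algos.AvailableNames)

# chain categories to narrow the selection, e.g. bounded gradient-free optimizers
print(om.algos.GradientFree.Bounded.AvailableNames)

# algorithm classes (not just strings) can be passed to minimize
res = om.minimize(
    fun=lambda x: x @ x,
    params=np.arange(3.0),
    algorithm=om.algos.scipy_lbfgsb,
)
print(res.params)
```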

src/estimagic/__init__.py

Lines changed: 0 additions & 6 deletions
@@ -1,12 +1,6 @@
-import contextlib
 import warnings
 from dataclasses import dataclass
 
-try:
-    import pdbp  # noqa: F401
-except ImportError:
-    contextlib.suppress(Exception)
-
 from estimagic import utilities
 from estimagic.bootstrap import BootstrapResult, bootstrap
 from estimagic.estimate_ml import LikelihoodResult, estimate_ml

src/optimagic/__init__.py

Lines changed: 0 additions & 7 deletions
@@ -1,12 +1,5 @@
 from __future__ import annotations
 
-import contextlib
-
-try:
-    import pdbp  # noqa: F401
-except ImportError:
-    contextlib.suppress(Exception)
-
 from optimagic import constraints, mark, utilities
 from optimagic.algorithms import algos
 from optimagic.benchmarking.benchmark_reports import (
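A side note on the removed block: `contextlib.suppress(Exception)` used as a bare statement inside the `except` branch only creates a context manager and throws it away, so it never suppressed anything; the `except ImportError` clause alone did the work. A small sketch of the three variants for comparison:

```python
import contextlib

# Removed pattern: the suppress() call below is a no-op; the ImportError is
# already caught by the except clause and the context manager is never entered.
try:
    import pdbp  # noqa: F401
except ImportError:
    contextlib.suppress(Exception)

# Equivalent explicit form (used for the optional imports elsewhere in this commit):
try:
    import pdbp  # noqa: F401
except ImportError:
    pass

# suppress() only has an effect when actually entered as a context manager:
with contextlib.suppress(ImportError):
    import pdbp  # noqa: F401
```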

src/optimagic/algorithms.py

Lines changed: 15 additions & 10 deletions
@@ -90,27 +90,32 @@ def _available(self) -> list[Type[Algorithm]]:
         return [
             a
             for a in _all
-            if a.__algo_info__.is_available  # type: ignore
+            if a.algo_info.is_available  # type: ignore
         ]
 
     @property
-    def All(self) -> list[str]:
-        return [a.__algo_info__.name for a in self._all()]  # type: ignore
+    def All(self) -> list[Type[Algorithm]]:
+        return self._all()
 
     @property
-    def Available(self) -> list[str]:
-        return [a.__algo_info__.name for a in self._available()]  # type: ignore
+    def Available(self) -> list[Type[Algorithm]]:
+        return self._available()
+
+    @property
+    def AllNames(self) -> list[str]:
+        return [str(a.name) for a in self._all()]
+
+    @property
+    def AvailableNames(self) -> list[str]:
+        return [str(a.name) for a in self._available()]
 
     @property
     def _all_algorithms_dict(self) -> dict[str, Type[Algorithm]]:
-        return {a.__algo_info__.name: a for a in self._all()}  # type: ignore
+        return {str(a.name): a for a in self._all()}
 
     @property
     def _available_algorithms_dict(self) -> dict[str, Type[Algorithm]]:
-        return {
-            a.__algo_info__.name: a  # type: ignore
-            for a in self._available()
-        }
+        return {str(a.name): a for a in self._available()}
 
 
 @dataclass(frozen=True)
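The practical effect of this change is that `All` and `Available` now return algorithm classes, while the new `AllNames` and `AvailableNames` properties cover the old string-returning behavior. A brief, hedged usage sketch; the concrete entries depend on which optional dependencies are installed:

```python
import optimagic as om

# classes: can be configured and passed to minimize directly
available_classes = om.algos.Available
print(available_classes[0])  # repr such as om.algos.scipy_lbfgsb, via AlgorithmMeta

# strings: convenient for logging, tables, or documentation
available_names = om.algos.AvailableNames
print(available_names[:3])
```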

src/optimagic/optimization/algorithm.py

Lines changed: 32 additions & 2 deletions
@@ -1,6 +1,6 @@
 import typing
 import warnings
-from abc import ABC, abstractmethod
+from abc import ABC, ABCMeta, abstractmethod
 from dataclasses import dataclass, replace
 from typing import Any
 
@@ -143,8 +143,38 @@ def __post_init__(self) -> None:
             raise TypeError(msg)
 
 
+class AlgorithmMeta(ABCMeta):
+    """Metaclass to get repr, algo_info and name for classes, not just instances."""
+
+    def __repr__(self) -> str:
+        if hasattr(self, "__algo_info__") and self.__algo_info__ is not None:
+            out = f"om.algos.{self.__algo_info__.name}"
+        else:
+            out = self.__class__.__name__
+        return out
+
+    @property
+    def name(self) -> str:
+        if hasattr(self, "__algo_info__") and self.__algo_info__ is not None:
+            out = self.__algo_info__.name
+        else:
+            out = self.__class__.__name__
+        return out
+
+    @property
+    def algo_info(self) -> AlgoInfo:
+        if not hasattr(self, "__algo_info__") or self.__algo_info__ is None:
+            msg = (
+                f"The algorithm {self.name} does not have the __algo_info__ "
+                "attribute. Use the `mark.minimizer` decorator to add this attribute."
+            )
+            raise AttributeError(msg)
+
+        return self.__algo_info__
+
+
 @dataclass(frozen=True)
-class Algorithm(ABC):
+class Algorithm(ABC, metaclass=AlgorithmMeta):
     @abstractmethod
     def _solve_internal_problem(
         self, problem: InternalOptimizationProblem, x0: NDArray[np.float64]
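The metaclass is what makes `name`, `algo_info`, and the custom `repr` work on the algorithm classes themselves, not only on instances; a regular `@property` defined on `Algorithm` would only be reachable from instances. A self-contained sketch of the mechanism, with a hypothetical `Dummy` class and a plain string standing in for the real `__algo_info__`:

```python
from abc import ABC, ABCMeta


class Meta(ABCMeta):
    """Expose class-level metadata through properties defined on the metaclass."""

    @property
    def name(cls) -> str:
        # __algo_info__ would normally be set by a decorator such as mark.minimizer
        info = getattr(cls, "__algo_info__", None)
        return info if info is not None else cls.__name__

    def __repr__(cls) -> str:
        return f"algos.{cls.name}"


class Base(ABC, metaclass=Meta):
    pass


class Dummy(Base):
    __algo_info__ = "dummy_algorithm"


print(Dummy.name)   # 'dummy_algorithm' -- resolved on the class, no instance needed
print(repr(Dummy))  # 'algos.dummy_algorithm'
```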

src/optimagic/optimizers/pygmo_optimizers.py

Lines changed: 3 additions & 2 deletions
@@ -19,7 +19,6 @@
 
 from __future__ import annotations
 
-import contextlib
 import warnings
 from dataclasses import dataclass
 from typing import Any, List, Literal
@@ -48,8 +47,10 @@
 
 STOPPING_MAX_ITERATIONS_GENETIC = 250
 
-with contextlib.suppress(ImportError):
+try:
     import pygmo as pg
+except ImportError:
+    pass
 
 
 @mark.minimizer(

src/optimagic/optimizers/tao_optimizers.py

Lines changed: 3 additions & 2 deletions
@@ -1,6 +1,5 @@
 """This module implements the POUNDERs algorithm."""
 
-import contextlib
 import functools
 from dataclasses import dataclass
 
@@ -23,8 +22,10 @@
 from optimagic.typing import AggregationLevel, NonNegativeFloat, PositiveInt
 from optimagic.utilities import calculate_trustregion_initial_radius
 
-with contextlib.suppress(ImportError):
+try:
     from petsc4py import PETSc
+except ImportError:
+    pass
 
 
 @mark.minimizer(

src/optimagic/visualization/history_plots.py

Lines changed: 44 additions & 17 deletions
@@ -1,12 +1,15 @@
+import inspect
 import itertools
 from pathlib import Path
+from typing import Any
 
 import numpy as np
 import plotly.graph_objects as go
 from pybaum import leaf_names, tree_flatten, tree_just_flatten, tree_unflatten
 
 from optimagic.config import PLOTLY_PALETTE, PLOTLY_TEMPLATE
 from optimagic.logging.logger import LogReader, SQLiteLogOptions
+from optimagic.optimization.algorithm import Algorithm
 from optimagic.optimization.history_tools import get_history_arrays
 from optimagic.optimization.optimize_result import OptimizeResult
 from optimagic.parameters.tree_registry import get_registry
@@ -50,23 +53,7 @@ def criterion_plot(
     # Process inputs
     # ==================================================================================
 
-    if not isinstance(names, list) and names is not None:
-        names = [names]
-
-    if not isinstance(results, dict):
-        if isinstance(results, list):
-            names = range(len(results)) if names is None else names
-            if len(names) != len(results):
-                raise ValueError("len(results) needs to be equal to len(names).")
-            results = dict(zip(names, results, strict=False))
-        else:
-            name = 0 if names is None else names
-            if isinstance(name, list):
-                if len(name) > 1:
-                    raise ValueError("len(results) needs to be equal to len(names).")
-                else:
-                    name = name[0]
-            results = {name: results}
+    results = _harmonize_inputs_to_dict(results, names)
 
     if not isinstance(palette, list):
         palette = [palette]
@@ -180,6 +167,46 @@ def criterion_plot(
     return fig
 
 
+def _harmonize_inputs_to_dict(results, names):
+    """Convert all valid inputs for results and names to dict[str, OptimizeResult]."""
+    # convert scalar case to list case
+    if not isinstance(names, list) and names is not None:
+        names = [names]
+
+    if isinstance(results, OptimizeResult):
+        results = [results]
+
+    if names is not None and len(names) != len(results):
+        raise ValueError("len(results) needs to be equal to len(names).")
+
+    # handle dict case
+    if isinstance(results, dict):
+        if names is not None:
+            results_dict = dict(zip(names, list(results.values()), strict=False))
+        else:
+            results_dict = results
+
+    # unlabeled iterable of results
+    else:
+        names = range(len(results)) if names is None else names
+        results_dict = dict(zip(names, results, strict=False))
+
+    # convert keys to strings
+    results_dict = {_convert_key_to_str(k): v for k, v in results_dict.items()}
+
+    return results_dict
+
+
+def _convert_key_to_str(key: Any) -> str:
+    if inspect.isclass(key) and issubclass(key, Algorithm):
+        out = str(key.name)
+    elif isinstance(key, Algorithm):
+        out = str(key.name)
+    else:
+        out = str(key)
+    return out
+
+
 def params_plot(
     result,
     selector=None,
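With `_harmonize_inputs_to_dict` in place, `criterion_plot` accepts a single result, a list, or a dict, and `_convert_key_to_str` lets algorithm classes or instances serve as dict keys by mapping them to their names. A hedged usage sketch that assumes two short optimizations; the algorithm names are illustrative:

```python
import numpy as np
import optimagic as om


def sphere(x):
    return x @ x


res_a = om.minimize(fun=sphere, params=np.arange(3.0), algorithm="scipy_lbfgsb")
res_b = om.minimize(fun=sphere, params=np.arange(3.0), algorithm="scipy_neldermead")

# a single result, a list of results, or a dict of labeled results all work
om.criterion_plot(res_a)
om.criterion_plot([res_a, res_b], names=["lbfgsb", "neldermead"])

# algorithm classes as keys are converted to their string names for the legend
om.criterion_plot({om.algos.scipy_lbfgsb: res_a, om.algos.scipy_neldermead: res_b})
```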

tests/optimagic/optimization/test_history_collection.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 OPTIMIZERS = []
 BOUNDED = []
 for name, algo in AVAILABLE_ALGORITHMS.items():
-    info = algo.__algo_info__
+    info = algo.algo_info
     if not info.disable_history:
         if info.supports_parallelism:
             OPTIMIZERS.append(name)
