Skip to content

Commit d60b99d

Browse files
authored
Add parameter validation (#577)
* Add paramter validation * fix: pre-commit * fix: xi * fix: docstrings * add tests
1 parent 696fec6 commit d60b99d

File tree

3 files changed

+98
-0
lines changed

3 files changed

+98
-0
lines changed

bayes_opt/acquisition.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,14 @@ def __init__(
452452
if kappa < 0:
453453
error_msg = "kappa must be greater than or equal to 0."
454454
raise ValueError(error_msg)
455+
if exploration_decay is not None and not (0 < exploration_decay <= 1):
456+
error_msg = "exploration_decay must be greater than 0 and less than or equal to 1."
457+
raise ValueError(error_msg)
458+
if exploration_decay_delay is not None and (
459+
not isinstance(exploration_decay_delay, int) or exploration_decay_delay < 0
460+
):
461+
error_msg = "exploration_decay_delay must be an integer greater than or equal to 0."
462+
raise ValueError(error_msg)
455463

456464
super().__init__(random_state=random_state)
457465
self.kappa = kappa
@@ -604,6 +612,18 @@ def __init__(
604612
exploration_decay_delay: int | None = None,
605613
random_state: int | RandomState | None = None,
606614
) -> None:
615+
if xi < 0:
616+
error_msg = "xi must be greater than or equal to 0."
617+
raise ValueError(error_msg)
618+
if exploration_decay is not None and not (0 < exploration_decay <= 1):
619+
error_msg = "exploration_decay must be greater than 0 and less than or equal to 1."
620+
raise ValueError(error_msg)
621+
if exploration_decay_delay is not None and (
622+
not isinstance(exploration_decay_delay, int) or exploration_decay_delay < 0
623+
):
624+
error_msg = "exploration_decay_delay must be an integer greater than or equal to 0."
625+
raise ValueError(error_msg)
626+
607627
super().__init__(random_state=random_state)
608628
self.xi = xi
609629
self.exploration_decay = exploration_decay
@@ -766,6 +786,7 @@ class ExpectedImprovement(AcquisitionFunction):
766786
Decay rate for xi. If None, no decay is applied.
767787
768788
exploration_decay_delay : int, default None
789+
Delay for decay. If None, decay is applied from the start.
769790
770791
random_state : int, RandomState, default None
771792
Set the random state for reproducibility.
@@ -778,6 +799,18 @@ def __init__(
778799
exploration_decay_delay: int | None = None,
779800
random_state: int | RandomState | None = None,
780801
) -> None:
802+
if xi < 0:
803+
error_msg = "xi must be greater than or equal to 0."
804+
raise ValueError(error_msg)
805+
if exploration_decay is not None and not (0 < exploration_decay <= 1):
806+
error_msg = "exploration_decay must be greater than 0 and less than or equal to 1."
807+
raise ValueError(error_msg)
808+
if exploration_decay_delay is not None and (
809+
not isinstance(exploration_decay_delay, int) or exploration_decay_delay < 0
810+
):
811+
error_msg = "exploration_decay_delay must be an integer greater than or equal to 0."
812+
raise ValueError(error_msg)
813+
781814
super().__init__(random_state=random_state)
782815
self.xi = xi
783816
self.exploration_decay = exploration_decay

bayes_opt/bayesian_optimization.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ class BayesianOptimization:
5555
Dictionary with parameters names as keys and a tuple with minimum
5656
and maximum values.
5757
58+
acquisition_function: AcquisitionFunction, optional(default=None)
59+
The acquisition function to use for suggesting new points to evaluate.
60+
If None, defaults to UpperConfidenceBound for unconstrained problems
61+
and ExpectedImprovement for constrained problems.
62+
5863
constraint: NonlinearConstraint.
5964
Note that the names of arguments of the constraint function and of
6065
f need to be the same.

tests/test_acquisition.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,66 @@ def test_upper_confidence_bound_invalid_kappa_error(kappa: float):
377377
acquisition.UpperConfidenceBound(kappa=kappa)
378378

379379

380+
@pytest.mark.parametrize("exploration_decay", [-0.1, 0.0, 1.1, 2.0, np.inf])
381+
def test_upper_confidence_bound_invalid_exploration_decay_error(exploration_decay: float):
382+
with pytest.raises(
383+
ValueError, match="exploration_decay must be greater than 0 and less than or equal to 1."
384+
):
385+
acquisition.UpperConfidenceBound(kappa=1.0, exploration_decay=exploration_decay)
386+
387+
388+
@pytest.mark.parametrize("exploration_decay_delay", [-1, -10, "not_an_int", 1.5])
389+
def test_upper_confidence_bound_invalid_exploration_decay_delay_error(exploration_decay_delay):
390+
with pytest.raises(
391+
ValueError, match="exploration_decay_delay must be an integer greater than or equal to 0."
392+
):
393+
acquisition.UpperConfidenceBound(kappa=1.0, exploration_decay_delay=exploration_decay_delay)
394+
395+
396+
@pytest.mark.parametrize("xi", [-0.1, -1.0, -np.inf])
397+
def test_probability_of_improvement_invalid_xi_error(xi: float):
398+
with pytest.raises(ValueError, match="xi must be greater than or equal to 0."):
399+
acquisition.ProbabilityOfImprovement(xi=xi)
400+
401+
402+
@pytest.mark.parametrize("exploration_decay", [-0.1, 0.0, 1.1, 2.0, np.inf])
403+
def test_probability_of_improvement_invalid_exploration_decay_error(exploration_decay: float):
404+
with pytest.raises(
405+
ValueError, match="exploration_decay must be greater than 0 and less than or equal to 1."
406+
):
407+
acquisition.ProbabilityOfImprovement(xi=0.01, exploration_decay=exploration_decay)
408+
409+
410+
@pytest.mark.parametrize("exploration_decay_delay", [-1, -10, "not_an_int", 1.5])
411+
def test_probability_of_improvement_invalid_exploration_decay_delay_error(exploration_decay_delay):
412+
with pytest.raises(
413+
ValueError, match="exploration_decay_delay must be an integer greater than or equal to 0."
414+
):
415+
acquisition.ProbabilityOfImprovement(xi=0.01, exploration_decay_delay=exploration_decay_delay)
416+
417+
418+
@pytest.mark.parametrize("xi", [-0.1, -1.0, -np.inf])
419+
def test_expected_improvement_invalid_xi_error(xi: float):
420+
with pytest.raises(ValueError, match="xi must be greater than or equal to 0."):
421+
acquisition.ExpectedImprovement(xi=xi)
422+
423+
424+
@pytest.mark.parametrize("exploration_decay", [-0.1, 0.0, 1.1, 2.0, np.inf])
425+
def test_expected_improvement_invalid_exploration_decay_error(exploration_decay: float):
426+
with pytest.raises(
427+
ValueError, match="exploration_decay must be greater than 0 and less than or equal to 1."
428+
):
429+
acquisition.ExpectedImprovement(xi=0.01, exploration_decay=exploration_decay)
430+
431+
432+
@pytest.mark.parametrize("exploration_decay_delay", [-1, -10, "not_an_int", 1.5])
433+
def test_expected_improvement_invalid_exploration_decay_delay_error(exploration_decay_delay):
434+
with pytest.raises(
435+
ValueError, match="exploration_decay_delay must be an integer greater than or equal to 0."
436+
):
437+
acquisition.ExpectedImprovement(xi=0.01, exploration_decay_delay=exploration_decay_delay)
438+
439+
380440
def verify_optimizers_match(optimizer1, optimizer2):
381441
"""Helper function to verify two optimizers match."""
382442
assert len(optimizer1.space) == len(optimizer2.space)

0 commit comments

Comments
 (0)