@@ -532,7 +532,9 @@ def gen_candidates_torch(
532532 optimizer (Optimizer): The pytorch optimizer to use to perform
533533 candidate search.
534534 options: Options used to control the optimization. Includes
535- maxiter: Maximum number of iterations
535+ optimizer_options: Dict of additional options to pass to the optimizer
536+ (e.g. lr, weight_decay)
537+ stopping_criterion_options: Dict of options for the stopping criterion.
536538 callback: A callback function accepting the current iteration, loss,
537539 and gradients as arguments. This function is executed after computing
538540 the loss and gradients, but before calling the optimizer.
@@ -559,11 +561,11 @@ def gen_candidates_torch(
559561 >>> qEI, bounds, q=3, num_restarts=25, raw_samples=500
560562 >>> )
561563 >>> batch_candidates, batch_acq_values = gen_candidates_torch(
562- initial_conditions=Xinit,
563- acquisition_function=qEI,
564- lower_bounds=bounds[0],
565- upper_bounds=bounds[1],
566- )
564+ initial_conditions=Xinit,
565+ acquisition_function=qEI,
566+ lower_bounds=bounds[0],
567+ upper_bounds=bounds[1],
568+ )
567569 """
568570 start_time = time .monotonic ()
569571 options = options or {}
@@ -580,11 +582,17 @@ def gen_candidates_torch(
580582 [i for i in range (clamped_candidates .shape [- 1 ]) if i not in fixed_features ],
581583 ]
582584 clamped_candidates = clamped_candidates .requires_grad_ (True )
583- _optimizer = optimizer (params = [clamped_candidates ], lr = options .get ("lr" , 0.025 ))
585+
586+ # Extract optimizer-specific options from the options dict
587+ optimizer_options = options .pop ("optimizer_options" , {})
588+ stopping_criterion_options = options .pop ("stopping_criterion_options" , {})
589+
590+ optimizer_options ["lr" ] = optimizer_options .get ("lr" , 0.025 )
591+ _optimizer = optimizer (params = [clamped_candidates ], ** optimizer_options )
584592
585593 i = 0
586594 stop = False
587- stopping_criterion = ExpMAStoppingCriterion (** options )
595+ stopping_criterion = ExpMAStoppingCriterion (** stopping_criterion_options )
588596 while not stop :
589597 i += 1
590598 with torch .no_grad ():
0 commit comments