@@ -324,6 +324,37 @@ def test_gen_candidates_torch_timeout_behavior(self):
324324 self .assertFalse (any (issubclass (w .category , OptimizationWarning ) for w in ws ))
325325 self .assertTrue ("Optimization timed out" in logs .output [- 1 ])
326326
327+ def test_gen_candidates_torch_optimizer_with_optimizer_args (self ):
328+ """Test that Adam optimizer is created with the correct learning rate."""
329+ self ._setUp (double = False )
330+ qEI = qExpectedImprovement (self .model , best_f = self .f_best )
331+
332+ # Create a mock optimizer class
333+ mock_optimizer_class = mock .MagicMock ()
334+ mock_optimizer_instance = mock .MagicMock ()
335+ mock_optimizer_class .return_value = mock_optimizer_instance
336+
337+ gen_candidates_torch (
338+ initial_conditions = self .initial_conditions ,
339+ acquisition_function = qEI ,
340+ lower_bounds = 0 ,
341+ upper_bounds = 1 ,
342+ optimizer = mock_optimizer_class , # Pass the mock optimizer directly
343+ options = {
344+ "optimizer_options" : {"lr" : 0.02 , "weight_decay" : 1e-5 },
345+ "stopping_criterion_options" : {"maxiter" : 1 },
346+ },
347+ )
348+
349+ # Verify that the optimizer was called with the correct arguments
350+ mock_optimizer_class .assert_called_once ()
351+ call_args = mock_optimizer_class .call_args
352+ # Check that params argument is present
353+ self .assertIn ("params" , call_args .kwargs )
354+ # Check optimizer options
355+ self .assertEqual (call_args .kwargs ["lr" ], 0.02 )
356+ self .assertEqual (call_args .kwargs ["weight_decay" ], 1e-5 )
357+
327358 def test_gen_candidates_scipy_warns_opt_no_res (self ):
328359 ckwargs = {"dtype" : torch .float , "device" : self .device }
329360
0 commit comments