From 4beebd5a7d6efeca6c4f328f28699813b4a20272 Mon Sep 17 00:00:00 2001
From: Kyle Nakamura
Date: Fri, 4 Oct 2024 10:00:34 -0700
Subject: [PATCH] tests: 100% coverage for .algorithms.sa

---
 src/mlrose_ky/algorithms/sa.py   |  9 ++++----
 tests/test_algorithms/test_sa.py | 38 ++++++++++++++++++++++++++++----
 2 files changed, 39 insertions(+), 8 deletions(-)

diff --git a/src/mlrose_ky/algorithms/sa.py b/src/mlrose_ky/algorithms/sa.py
index 53397458..c90eab9d 100644
--- a/src/mlrose_ky/algorithms/sa.py
+++ b/src/mlrose_ky/algorithms/sa.py
@@ -99,8 +99,10 @@ def simulated_annealing(
 
     Notes
     -----
-    - The `state_fitness_callback` function is also called before the optimization loop starts (iteration 0) with the initial state and fitness values.
-    - The simulated annealing algorithm probabilistically accepts worse states as it explores the solution space, with the probability decreasing over time according to the `schedule`.
+    - The `state_fitness_callback` function is also called before the optimization loop starts (iteration 0) with the initial
+      state and fitness values.
+    - The simulated annealing algorithm probabilistically accepts worse states as it explores the solution
+      space, with the probability decreasing over time according to the `schedule`.
 
     References
     ----------
@@ -144,12 +146,11 @@ def simulated_annealing(
         )
         if not continue_iterating:
             # Early termination as per callback request
-            return (problem.get_state(), problem.get_maximize() * problem.get_fitness(), np.asarray(fitness_curve) if curve else None)
+            return problem.get_state(), problem.get_maximize() * problem.get_fitness(), np.asarray(fitness_curve) if curve else None
 
     # Main optimization loop
     attempts = 0
     iters = 0
-    continue_iterating = True
     while attempts < max_attempts and iters < max_iters:
         # Evaluate the temperature at the current iteration
         temp = schedule.evaluate(iters)
diff --git a/tests/test_algorithms/test_sa.py b/tests/test_algorithms/test_sa.py
index 5930f396..189311fe 100644
--- a/tests/test_algorithms/test_sa.py
+++ b/tests/test_algorithms/test_sa.py
@@ -84,6 +84,7 @@ def test_simulated_annealing_with_callback(self):
         problem = DiscreteOpt(5, OneMax())
 
         # Define a callback function
+        # noinspection PyMissingOrEmptyDocstring
         def callback_function(iteration, attempt, done, state, fitness, fitness_evaluations, curve, user_data):
             # Record the iteration number
             user_data["iterations"].append(iteration)
@@ -109,7 +110,6 @@ def evaluate(t):
                 return 0.0
 
         schedule = ZeroTempSchedule()
-
         best_state, best_fitness, fitness_curve = simulated_annealing(problem, schedule=schedule, random_state=SEED, curve=True)
 
         # Since temperature becomes zero, the loop should terminate early
@@ -147,7 +147,7 @@ def callback_function(iteration, attempt, done, state, fitness, fitness_evaluati
             return True
 
         max_attempts = 3
-        user_data = {"attempts": []}
+        callback_data = {"attempts": []}
 
         # Since the initial state is already optimal, no better state will be found,
         # so attempts will increase until max_attempts is reached.
@@ -157,8 +157,38 @@ def callback_function(iteration, attempt, done, state, fitness, fitness_evaluati
             max_attempts=max_attempts,
             random_state=SEED,
             state_fitness_callback=callback_function,
-            callback_user_info=user_data,
+            callback_user_info=callback_data,
         )
 
         # Check that max_attempts was reached
-        assert max(user_data["attempts"]) == max_attempts
+        assert max(callback_data["attempts"]) == max_attempts
+
+    def test_simulated_annealing_callback_early_termination(self):
+        """Test simulated_annealing with early termination via state_fitness_callback when callback_user_info is None"""
+        problem = DiscreteOpt(5, OneMax())
+
+        # noinspection PyMissingOrEmptyDocstring
+        def callback_function(iteration, attempt, done, state, fitness, fitness_evaluations, curve, user_data):
+            return False  # Terminate immediately
+
+        best_state, best_fitness, _ = simulated_annealing(problem, random_state=SEED, state_fitness_callback=callback_function)
+
+        # Verify that the algorithm terminated immediately
+        assert problem.current_iteration == 0
+        assert isinstance(best_state, np.ndarray)
+        assert isinstance(best_fitness, float)
+
+    def test_simulated_annealing_problem_can_stop(self):
+        """Test simulated_annealing where problem.can_stop() returns True"""
+
+        class TestProblem(DiscreteOpt):
+            def can_stop(self):
+                return True
+
+        problem = TestProblem(5, OneMax())
+        best_state, best_fitness, _ = simulated_annealing(problem, random_state=SEED)
+
+        # Verify that the algorithm terminated after the first iteration
+        assert problem.current_iteration == 1
+        assert isinstance(best_state, np.ndarray)
+        assert isinstance(best_fitness, float)
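
Reviewer note: the snippet below is a minimal sketch of the state_fitness_callback early-termination contract that the new tests exercise. It assumes DiscreteOpt, OneMax, and simulated_annealing are importable from the package root (as the test module appears to do); the stop_after_five callback and the seed value are illustrative and not part of this patch.

# Minimal sketch of the state_fitness_callback early-termination contract.
# The import path, callback name, and seed below are assumptions for illustration only.
import numpy as np

from mlrose_ky import DiscreteOpt, OneMax, simulated_annealing

SEED = 42  # illustrative seed; the test suite defines its own SEED constant


def stop_after_five(iteration, attempt, done, state, fitness, fitness_evaluations, curve, user_data):
    # Returning False asks simulated_annealing to stop early; True lets it continue.
    return iteration < 5


problem = DiscreteOpt(5, OneMax())
best_state, best_fitness, _ = simulated_annealing(
    problem,
    random_state=SEED,
    state_fitness_callback=stop_after_five,
)

assert isinstance(best_state, np.ndarray)
assert isinstance(best_fitness, float)

Returning False from the callback drives the same early-return branch in sa.py that this patch simplifies.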