Skip to content

Commit

Permalink
tests: 100% coverage for .algorithms.sa
Browse files Browse the repository at this point in the history
  • Loading branch information
knakamura13 committed Oct 4, 2024
1 parent 17a9d10 commit 4beebd5
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 8 deletions.
9 changes: 5 additions & 4 deletions src/mlrose_ky/algorithms/sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,10 @@ def simulated_annealing(
Notes
-----
- The `state_fitness_callback` function is also called before the optimization loop starts (iteration 0) with the initial state and fitness values.
- The simulated annealing algorithm probabilistically accepts worse states as it explores the solution space, with the probability decreasing over time according to the `schedule`.
- The `state_fitness_callback` function is also called before the optimization loop starts (iteration 0) with the initial
state and fitness values.
- The simulated annealing algorithm probabilistically accepts worse states as it explores the solution
space, with the probability decreasing over time according to the `schedule`.
References
----------
Expand Down Expand Up @@ -144,12 +146,11 @@ def simulated_annealing(
)
if not continue_iterating:
# Early termination as per callback request
return (problem.get_state(), problem.get_maximize() * problem.get_fitness(), np.asarray(fitness_curve) if curve else None)
return problem.get_state(), problem.get_maximize() * problem.get_fitness(), np.asarray(fitness_curve) if curve else None

# Main optimization loop
attempts = 0
iters = 0
continue_iterating = True
while attempts < max_attempts and iters < max_iters:
# Evaluate the temperature at the current iteration
temp = schedule.evaluate(iters)
Expand Down
38 changes: 34 additions & 4 deletions tests/test_algorithms/test_sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def test_simulated_annealing_with_callback(self):
problem = DiscreteOpt(5, OneMax())

# Define a callback function
# noinspection PyMissingOrEmptyDocstring
def callback_function(iteration, attempt, done, state, fitness, fitness_evaluations, curve, user_data):
# Record the iteration number
user_data["iterations"].append(iteration)
Expand All @@ -109,7 +110,6 @@ def evaluate(t):
return 0.0

schedule = ZeroTempSchedule()

best_state, best_fitness, fitness_curve = simulated_annealing(problem, schedule=schedule, random_state=SEED, curve=True)

# Since temperature becomes zero, the loop should terminate early
Expand Down Expand Up @@ -147,7 +147,7 @@ def callback_function(iteration, attempt, done, state, fitness, fitness_evaluati
return True

max_attempts = 3
user_data = {"attempts": []}
callback_data = {"attempts": []}

# Since the initial state is already optimal, no better state will be found,
# so attempts will increase until max_attempts is reached.
Expand All @@ -157,8 +157,38 @@ def callback_function(iteration, attempt, done, state, fitness, fitness_evaluati
max_attempts=max_attempts,
random_state=SEED,
state_fitness_callback=callback_function,
callback_user_info=user_data,
callback_user_info=callback_data,
)

# Check that max_attempts was reached
assert max(user_data["attempts"]) == max_attempts
assert max(callback_data["attempts"]) == max_attempts

def test_simulated_annealing_callback_early_termination(self):
"""Test simulated_annealing with early termination via state_fitness_callback when callback_user_info is None"""
problem = DiscreteOpt(5, OneMax())

# noinspection PyMissingOrEmptyDocstring
def callback_function(iteration, attempt, done, state, fitness, fitness_evaluations, curve, user_data):
return False # Terminate immediately

best_state, best_fitness, _ = simulated_annealing(problem, random_state=SEED, state_fitness_callback=callback_function)

# Verify that the algorithm terminated immediately
assert problem.current_iteration == 0
assert isinstance(best_state, np.ndarray)
assert isinstance(best_fitness, float)

def test_simulated_annealing_problem_can_stop(self):
"""Test simulated_annealing where problem.can_stop() returns True"""

class TestProblem(DiscreteOpt):
def can_stop(self):
return True

problem = TestProblem(5, OneMax())
best_state, best_fitness, _ = simulated_annealing(problem, random_state=SEED)

# Verify that the algorithm terminated after the first iteration
assert problem.current_iteration == 1
assert isinstance(best_state, np.ndarray)
assert isinstance(best_fitness, float)

0 comments on commit 4beebd5

Please sign in to comment.