Skip to content

Commit

Permalink
tests: 100% coverage for continuous_opt.py
Browse files Browse the repository at this point in the history
  • Loading branch information
knakamura13 committed Oct 5, 2024
1 parent 2c4d87e commit 91e6db6
Showing 2 changed files with 79 additions and 3 deletions.
3 changes: 1 addition & 2 deletions src/mlrose_ky/opt_probs/continuous_opt.py
Original file line number Diff line number Diff line change
@@ -53,9 +53,8 @@ def __init__(self, length: int, fitness_fn: Any, maximize: bool = True, min_val:
if self.fitness_fn.get_prob_type() not in {"continuous", "either"}:
raise ValueError(
"fitness_fn must have problem type 'continuous' or 'either'. "
"Define problem as DiscreteOpt or use an appropriate fitness function."
"Use an appropriate fitness function, or use DiscreteOpt instead."
)

if max_val <= min_val:
raise ValueError("max_val must be greater than min_val.")
if step <= 0:
79 changes: 78 additions & 1 deletion tests/test_opt_probs/test_continous_opt.py
Original file line number Diff line number Diff line change
@@ -3,17 +3,43 @@
# Author: Genevieve Hayes (modified by Kyle Nakamura)
# License: BSD 3-clause

import re

import numpy as np
import pytest

from mlrose_ky.opt_probs import ContinuousOpt
from mlrose_ky.fitness import OneMax
from mlrose_ky.fitness import OneMax, CustomFitness
from mlrose_ky.neural import NetworkWeights
from mlrose_ky.neural.activation import identity


class TestContinuousOpt:
"""Tests for ContinuousOpt class."""

def test_continuous_opt_invalid_parameters(self):
"""Test initialization with invalid parameters."""

# noinspection PyMissingOrEmptyDocstring
def custom_fitness_fn(_):
return 0

with pytest.raises(
ValueError,
match="fitness_fn must have problem type 'continuous' or 'either'. "
"Use an appropriate fitness function, or use DiscreteOpt instead.",
):
_ = ContinuousOpt(5, CustomFitness(custom_fitness_fn, problem_type="tsp"))

with pytest.raises(ValueError, match="max_val must be greater than min_val."):
_ = ContinuousOpt(5, CustomFitness(custom_fitness_fn), maximize=False, min_val=1, max_val=0)

with pytest.raises(ValueError, match="step size must be positive."):
_ = ContinuousOpt(5, CustomFitness(custom_fitness_fn), maximize=False, min_val=0, max_val=1, step=-0.1)

with pytest.raises(ValueError, match=re.escape("step size must be less than (max_val - min_val).")):
_ = ContinuousOpt(5, CustomFitness(custom_fitness_fn), step=100)

def test_calculate_updates(self):
"""Test calculate_updates method"""
X = np.array([[0, 1, 0, 1], [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1], [0, 0, 1, 1], [1, 0, 0, 0]])
@@ -130,3 +156,54 @@ def test_update_state_outside_range(self):
updated = problem.update_state(y)
z = np.array([2, 0, 5, 0, 5])
assert np.array_equal(updated, z)

def test_random_pop_invalid_pop_size(self):
"""Test random_pop method with invalid pop_size."""
problem = ContinuousOpt(5, OneMax())
with pytest.raises(ValueError, match="pop_size must be a positive integer."):
problem.random_pop(0)
with pytest.raises(ValueError, match="pop_size must be a positive integer."):
problem.random_pop(-10)
with pytest.raises(ValueError, match="pop_size must be a positive integer."):
# noinspection PyTypeChecker
problem.random_pop(2.5) # Non-integer pop_size

def test_reproduce_invalid_parent_lengths(self):
"""Test reproduce method with mismatched parent lengths."""
problem = ContinuousOpt(5, OneMax())
parent_1 = np.zeros(5)
parent_2 = np.zeros(4) # Invalid length
with pytest.raises(ValueError, match="Lengths of parents must match problem length."):
problem.reproduce(parent_1, parent_2)

def test_reproduce_invalid_mutation_prob(self):
"""Test reproduce method with invalid mutation_prob."""
problem = ContinuousOpt(5, OneMax())
parent_1 = np.zeros(5)
parent_2 = np.zeros(5)
with pytest.raises(ValueError, match="mutation_prob must be between 0 and 1."):
problem.reproduce(parent_1, parent_2, mutation_prob=-0.1)
with pytest.raises(ValueError, match="mutation_prob must be between 0 and 1."):
problem.reproduce(parent_1, parent_2, mutation_prob=1.1)

def test_update_state_invalid_updates_length(self):
"""Test update_state method with invalid length of updates."""
problem = ContinuousOpt(5, OneMax())
x = np.array([0, 1, 2, 3, 4])
problem.set_state(x)
updates = np.array([1, 2, 3]) # Invalid length
with pytest.raises(ValueError, match="Length of updates must match problem length."):
problem.update_state(updates)

def test_reproduce_length_one(self):
"""Test reproduce method when length of problem is 1."""
problem = ContinuousOpt(1, OneMax()) # Problem with length 1
parent_1 = np.array([0])
parent_2 = np.array([1])

# Since the problem length is 1, this will trigger the else block
child = problem.reproduce(parent_1, parent_2)

# Check if child is either parent_1 or parent_2 (since it's length 1)
assert len(child) == 1
assert np.array_equal(child, parent_1) or np.array_equal(child, parent_2)

0 comments on commit 91e6db6

Please sign in to comment.