Skip to content

Commit

Permalink
tests: 100% coverage for knapsack_opt; refactor exceptions for knapsack
Browse files Browse the repository at this point in the history
  • Loading branch information
knakamura13 committed Oct 4, 2024
1 parent 4beebd5 commit 39f5a64
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 47 deletions.
16 changes: 8 additions & 8 deletions src/mlrose_ky/generators/knapsack_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,19 @@ def generate(
If any parameter is not of the expected type or value.
"""
if not isinstance(seed, int):
raise ValueError(f"Seed must be an integer. Got {seed}")
raise ValueError(f"Seed must be an integer. Got {seed}.")
if not isinstance(number_of_item_types, int) or number_of_item_types <= 0:
raise ValueError(f"Number of item types must be a positive integer. Got {number_of_item_types}")
raise ValueError(f"Number of item types must be a positive integer. Got {number_of_item_types}.")
if not isinstance(max_item_count, int) or max_item_count <= 0:
raise ValueError(f"Max item count must be a positive integer. Got {max_item_count}")
raise ValueError(f"Max item count must be a positive integer. Got {max_item_count}.")
if not isinstance(max_weight_per_item, int) or max_weight_per_item <= 0:
raise ValueError(f"Max weight per item must be a positive integer. Got {max_weight_per_item}")
raise ValueError(f"Max weight per item must be a positive integer. Got {max_weight_per_item}.")
if not isinstance(max_value_per_item, int) or max_value_per_item <= 0:
raise ValueError(f"Max value per item must be a positive integer. Got {max_value_per_item}")
if not isinstance(max_weight_pct, float) or not (0 <= max_weight_pct <= 1):
raise ValueError(f"Max weight percentage must be a float between 0 and 1. Got {max_weight_pct}")
raise ValueError(f"Max value per item must be a positive integer. Got {max_value_per_item}.")
if not isinstance(max_weight_pct, float) or not (0 <= max_weight_pct < 1):
raise ValueError(f"max_weight_pct must be between 0 (inclusive) and 1 (exclusive), got {max_weight_pct}.")
if not isinstance(multiply_by_max_item_count, bool):
raise ValueError(f"multiply_by_max_item_count must be a boolean. Got {multiply_by_max_item_count}")
raise ValueError(f"multiply_by_max_item_count must be a boolean. Got {multiply_by_max_item_count}.")

np.random.seed(seed)

Expand Down
6 changes: 3 additions & 3 deletions src/mlrose_ky/opt_probs/knapsack_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,14 @@ def __init__(
mutator: "ChangeOneMutator" = None,
multiply_by_max_item_count: bool = False,
):
if fitness_fn is None and (weights is None and values is None):
if max_weight_pct <= 0 or max_weight_pct > 1.0:
raise ValueError(f"max_weight_pct must be between 0 (inclusive) and 1 (exclusive), got {max_weight_pct}.")
if fitness_fn is None and ((weights is None or not len(weights)) or (values is None or not len(values))):
raise ValueError("Either fitness_fn or both weights and values must be specified.")

if length is None:
if weights is not None:
length = len(weights)
elif values is not None:
length = len(values)
elif fitness_fn is not None:
length = len(fitness_fn.weights)

Expand Down
30 changes: 15 additions & 15 deletions tests/test_generators/test_knapsack_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,55 +37,55 @@ def test_generate_invalid_seed(self):
"""Test generate method raises ValueError when SEED is not an integer."""
with pytest.raises(ValueError) as excinfo:
KnapsackGenerator.generate(seed="not_an_int")
assert str(excinfo.value) == "Seed must be an integer. Got not_an_int"
assert str(excinfo.value) == "Seed must be an integer. Got not_an_int."

def test_generate_invalid_number_of_item_types(self):
"""Test generate method raises ValueError when number_of_item_types is invalid."""
with pytest.raises(ValueError) as excinfo:
KnapsackGenerator.generate(seed=SEED, number_of_item_types=0)
assert str(excinfo.value) == "Number of item types must be a positive integer. Got 0"
assert str(excinfo.value) == "Number of item types must be a positive integer. Got 0."

def test_generate_invalid_max_item_count(self):
"""Test generate method raises ValueError when max_item_count is invalid."""
with pytest.raises(ValueError) as excinfo:
KnapsackGenerator.generate(seed=SEED, max_item_count=-1)
assert str(excinfo.value) == "Max item count must be a positive integer. Got -1"
assert str(excinfo.value) == "Max item count must be a positive integer. Got -1."

def test_generate_invalid_max_weight_per_item(self):
"""Test generate method raises ValueError when max_weight_per_item is invalid."""
with pytest.raises(ValueError) as excinfo:
KnapsackGenerator.generate(seed=SEED, max_weight_per_item=0)
assert str(excinfo.value) == "Max weight per item must be a positive integer. Got 0"
assert str(excinfo.value) == "Max weight per item must be a positive integer. Got 0."

def test_generate_invalid_max_value_per_item(self):
"""Test generate method raises ValueError when max_value_per_item is invalid."""
with pytest.raises(ValueError) as excinfo:
KnapsackGenerator.generate(seed=SEED, max_value_per_item=-10)
assert str(excinfo.value) == "Max value per item must be a positive integer. Got -10"

def test_generate_default_parameters(self):
"""Test generate method with default parameters."""
problem = KnapsackGenerator.generate(seed=SEED)

assert problem.length == 10
assert problem.max_val == 5
assert str(excinfo.value) == "Max value per item must be a positive integer. Got -10."

def test_generate_invalid_max_weight_percentage(self):
"""Test generate method raises ValueError when max_weight_percentage is invalid."""
with pytest.raises(ValueError) as excinfo:
KnapsackGenerator.generate(seed=SEED, max_weight_pct=1.5)
assert str(excinfo.value) == "Max weight percentage must be a float between 0 and 1. Got 1.5"
assert str(excinfo.value) == f"max_weight_pct must be between 0 (inclusive) and 1 (exclusive), got 1.5."

def test_generate_invalid_multiply_by_max_item_count(self):
"""Test generate method raises ValueError when multiply_by_max_item_count is not a boolean."""
with pytest.raises(ValueError) as excinfo:
KnapsackGenerator.generate(seed=SEED, multiply_by_max_item_count="yes")
assert str(excinfo.value) == "multiply_by_max_item_count must be a boolean. Got yes"
assert str(excinfo.value) == "multiply_by_max_item_count must be a boolean. Got yes."

def test_generate_max_weight_percentage_zero(self):
"""Test generate method with max_weight_percentage set to 0"""
max_weight_percentage = 0.0

with pytest.raises(ValueError) as excinfo:
KnapsackGenerator.generate(seed=SEED, max_weight_pct=max_weight_percentage)
assert str(excinfo.value) == "max_weight_pct must be between 0 and 1."
assert str(excinfo.value) == "max_weight_pct must be between 0 (inclusive) and 1 (exclusive), got 0.0."

def test_generate_default_parameters(self):
"""Test generate method with default parameters."""
problem = KnapsackGenerator.generate(seed=SEED)

assert problem.length == 10
assert problem.max_val == 5
57 changes: 36 additions & 21 deletions tests/test_opt_probs/test_knapsack_opt.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,55 @@
"""Unit tests for opt_probs/knapsack_opt.py"""

import re

# Author: Genevieve Hayes (modified by Kyle Nakamura)
# License: BSD 3-clause

import numpy as np
import pytest

from mlrose_ky import FlipFlop, Knapsack
from mlrose_ky.opt_probs import KnapsackOpt


class TestKnapsackOpt:
"""Tests for KnapsackOpt class."""

def test_initialization(self):
"""Test initialization of KnapsackOpt"""
def test_knapsack_opt_invalid_weights_values(self):
"""Test that KnapsackOpt raises ValueError when either weights or values is None."""
weights = [10, 5, 2]
values = [1, 2, 3]
problem = KnapsackOpt(weights=weights, values=values, max_weight_pct=0.5)
with pytest.raises(ValueError, match="Either fitness_fn or both weights and values must be specified."):
_ = KnapsackOpt(weights=weights)
with pytest.raises(ValueError, match="Either fitness_fn or both weights and values must be specified."):
_ = KnapsackOpt(values=values)
with pytest.raises(ValueError, match="Either fitness_fn or both weights and values must be specified."):
_ = KnapsackOpt(weights=[], values=[])

def test_knapsack_opt_invalid_max_weight_pct(self):
"""Test that KnapsackOpt raises ValueError when max_weight_pct has an invalid value."""
max_weight_pct = -0.1
with pytest.raises(
ValueError, match=re.escape(f"max_weight_pct must be between 0 (inclusive) and 1 (exclusive), got {max_weight_pct}.")
):
_ = KnapsackOpt(weights=[1], values=[1], max_weight_pct=max_weight_pct)

def test_knapsack_opt_with_weights_and_values(self):
"""Test that KnapsackOpt can be initialized with weights and values."""
weights = [10, 5, 2]
values = [1, 2, 3]
problem = KnapsackOpt(weights=weights, values=values)
assert problem.length == 3
assert problem.max_val == 2

def test_set_state(self):
def test_knapsack_opt_initialization_with_fitness_fn(self):
"""Test that KnapsackOpt can be initialized with a Knapsack fitness function."""
fitness_fn = Knapsack(weights=[10, 5, 2], values=[1, 2, 3])
problem = KnapsackOpt(fitness_fn=fitness_fn)
assert problem.length == 3
assert problem.max_val == 2

def test_knapsack_opt_set_state(self):
"""Test set_state method"""
weights = [10, 5, 2]
values = [1, 2, 3]
Expand All @@ -28,7 +58,7 @@ def test_set_state(self):
problem.set_state(state)
assert np.array_equal(problem.get_state().tolist(), state)

def test_eval_fitness(self):
def test_knapsack_opt_eval_fitness(self):
"""Test eval_fitness method"""
weights = [10, 5, 2, 8, 15]
values = [1, 2, 3, 4, 5]
Expand All @@ -37,26 +67,11 @@ def test_eval_fitness(self):
fitness = problem.eval_fitness(state)
assert fitness == 11.0 # Assuming the fitness function calculates correctly

def test_set_population(self):
def test_knapsack_opt_set_population(self):
"""Test set_population method"""
weights = [10, 5, 2]
values = [1, 2, 3]
problem = KnapsackOpt(weights=weights, values=values, max_weight_pct=0.5)
pop = np.array([[1, 0, 1], [0, 1, 0], [1, 1, 0]])
problem.set_population(pop)
assert np.array_equal(problem.get_population().tolist(), pop)

def test_edge_cases(self):
"""Test edge cases for KnapsackOpt"""
# Test with empty weights and values
try:
KnapsackOpt(weights=[], values=[], max_weight_pct=0.5)
except Exception as e:
assert str(e) == "fitness_fn or both weights and values must be specified."

# Test with invalid max_weight_pct
try:
KnapsackOpt(weights=[1], values=[1], max_weight_pct=-0.1)
assert False, "Expected an exception for invalid max_weight_pct"
except Exception as e:
assert str(e) == "max_weight_pct must be between 0 and 1."

0 comments on commit 39f5a64

Please sign in to comment.