Skip to content

Commit

Permalink
tests: 100% coverage for all .generators
Browse files Browse the repository at this point in the history
  • Loading branch information
knakamura13 committed Oct 5, 2024
1 parent 9598c45 commit 1b2e91e
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 51 deletions.
14 changes: 1 addition & 13 deletions src/mlrose_ky/generators/max_k_color_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,20 +81,8 @@ def generate(
connected_nodes = sorted(np.random.choice(valid_other_nodes, count, replace=False))
node_connections[node] = [(node, other) for other in connected_nodes]

# Ensure graph connectivity
# Create the graph and ensure connectivity (node_connection_counts >= 1 guarantees that each node has at least one connection)
graph = nx.Graph()
graph.add_edges_from([edge for edges in node_connections.values() for edge in edges])

for node in nodes:
unreachable = [
(node, other) if node < other else (other, node) for other in nodes if other not in nx.bfs_tree(graph, node).nodes()
]
for start, end in unreachable:
graph.add_edge(start, end)
remaining_unreachable = len(
[(node, other) if node < other else (other, node) for other in nodes if other not in nx.bfs_tree(graph, node).nodes()]
)
if remaining_unreachable == 0:
break

return MaxKColorOpt(edges=list(graph.edges()), length=number_of_nodes, maximize=maximize, max_colors=max_colors, source_graph=graph)
7 changes: 6 additions & 1 deletion src/mlrose_ky/generators/tsp_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,13 @@ def generate(number_of_cities: int, area_width: int = 250, area_height: int = 25
# Calculate distances between all pairs of cities
distances = TSPGenerator.get_distances(coords, truncate=False)

# Create a graph with the calculated distances
# Create an empty graph
graph = nx.Graph()

# Add nodes to the graph
graph.add_nodes_from(range(number_of_cities))

# Add edges with calculated distances
for a, b, distance in distances:
graph.add_edge(a, b, length=int(round(distance)))

Expand Down
100 changes: 75 additions & 25 deletions tests/test_generators/test_max_k_color_generator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Unit tests for generators/"""

import networkx as nx
import pytest

from tests.globals import SEED

import mlrose_ky
from mlrose_ky.generators import MaxKColorGenerator


Expand All @@ -13,39 +13,58 @@ class TestMaxKColorGenerator:

def test_generate_negative_max_colors(self):
"""Test generate method raises ValueError when max_colors is a negative integer."""
with pytest.raises(ValueError) as excinfo:
with pytest.raises(ValueError, match="Max colors must be a positive integer or None. Got -3"):
MaxKColorGenerator.generate(seed=SEED, number_of_nodes=10, max_colors=-3)
assert str(excinfo.value) == "Max colors must be a positive integer or None. Got -3"

def test_generate_non_integer_max_colors(self):
"""Test generate method raises ValueError when max_colors is a non-integer value."""
with pytest.raises(ValueError) as excinfo:
with pytest.raises(ValueError, match="Max colors must be a positive integer or None. Got five"):
MaxKColorGenerator.generate(seed=SEED, number_of_nodes=10, max_connections_per_node=3, max_colors="five")
assert str(excinfo.value) == "Max colors must be a positive integer or None. Got five"

def test_generate_seed_float(self):
"""Test generate method raises ValueError when SEED is a float."""
with pytest.raises(ValueError) as excinfo:
with pytest.raises(ValueError, match="Seed must be an integer. Got 1.5"):
MaxKColorGenerator.generate(seed=1.5)
assert str(excinfo.value) == "Seed must be an integer. Got 1.5"

def test_generate_float_number_of_nodes(self):
"""Test generate method raises ValueError when number_of_nodes is a float."""
with pytest.raises(ValueError) as excinfo:
with pytest.raises(ValueError, match="Number of nodes must be a positive integer. Got 10.5"):
MaxKColorGenerator.generate(seed=SEED, number_of_nodes=10.5)
assert str(excinfo.value) == "Number of nodes must be a positive integer. Got 10.5"

def test_generate_max_connections_per_node_float(self):
"""Test generate method raises ValueError when max_connections_per_node is a float."""
with pytest.raises(ValueError) as excinfo:
with pytest.raises(ValueError, match="Max connections per node must be a positive integer. Got 4.5"):
MaxKColorGenerator.generate(seed=SEED, number_of_nodes=10, max_connections_per_node=4.5)
assert str(excinfo.value) == "Max connections per node must be a positive integer. Got 4.5"

def test_generate_maximize_string(self):
"""Test generate method raises ValueError when maximize is a string."""
with pytest.raises(ValueError) as excinfo:
with pytest.raises(ValueError, match="Maximize must be a boolean. Got true"):
MaxKColorGenerator.generate(seed=SEED, maximize="true")
assert str(excinfo.value) == "Maximize must be a boolean. Got true"

def test_generate_zero_nodes(self):
"""Test generate method raises ValueError when number_of_nodes is zero."""
with pytest.raises(ValueError, match="Number of nodes must be a positive integer. Got 0"):
MaxKColorGenerator.generate(seed=SEED, number_of_nodes=0)

def test_generate_no_edges(self):
"""Test generate method with no possible connections."""
with pytest.raises(ValueError, match="Max connections per node must be a positive integer. Got 0"):
MaxKColorGenerator.generate(seed=SEED, number_of_nodes=10, max_connections_per_node=0)

def test_generate_default_parameters(self):
"""Test generate method with default parameters."""
problem = MaxKColorGenerator.generate(seed=SEED)

assert problem.length == 20
assert problem.source_graph.number_of_edges() > 0

def test_generate_maximum_colors(self):
"""Test generate method with maximum number of colors."""
number_of_nodes = 5
max_colors = 100
problem = MaxKColorGenerator.generate(seed=SEED, number_of_nodes=number_of_nodes, max_colors=max_colors)

assert problem.max_val == max_colors

def test_generate_single_node_one_connection(self):
"""Test generate method with one node and up to one connection."""
Expand Down Expand Up @@ -92,12 +111,6 @@ def test_generate_large_graph(self):
assert problem.length == number_of_nodes
assert problem.source_graph.number_of_edges() > 0

def test_generate_no_edges(self):
"""Test generate method with no possible connections."""
with pytest.raises(ValueError) as excinfo:
MaxKColorGenerator.generate(seed=SEED, number_of_nodes=10, max_connections_per_node=0)
assert str(excinfo.value) == "Max connections per node must be a positive integer. Got 0"

def test_generate_max_colors_none(self):
"""Test generate method with max_colors set to None."""
number_of_nodes = 5
Expand All @@ -107,12 +120,6 @@ def test_generate_max_colors_none(self):
assert problem.length == number_of_nodes
assert problem.max_val > 1

def test_generate_zero_nodes(self):
"""Test generate method raises ValueError when number_of_nodes is zero."""
with pytest.raises(ValueError) as excinfo:
MaxKColorGenerator.generate(seed=SEED, number_of_nodes=0)
assert str(excinfo.value) == "Number of nodes must be a positive integer. Got 0"

def test_generate_large_max_colors(self):
"""Test generate method with a large max_colors value."""
number_of_nodes = 10
Expand All @@ -133,3 +140,46 @@ def test_generate_large_max_connections(self):

assert problem.length == number_of_nodes
assert problem.source_graph.number_of_edges() > 0

def test_generate_unreachable_nodes(self):
"""Test generate method adds edges to ensure graph connectivity when some nodes are unreachable."""
number_of_nodes = 4
max_connections_per_node = 1 # Low connections to increase likelihood of disconnected components
problem = MaxKColorGenerator.generate(seed=SEED, number_of_nodes=number_of_nodes, max_connections_per_node=max_connections_per_node)

# Check that all nodes are reachable from any other node
for node in range(number_of_nodes):
assert len(nx.bfs_tree(problem.source_graph, node).nodes) == number_of_nodes

def test_generate_fully_connected_graph(self):
"""Test generate method creates a fully connected graph when max_connections_per_node is large enough."""
number_of_nodes = 4
max_connections_per_node = number_of_nodes - 1 # Enough to potentially fully connect the graph
problem = MaxKColorGenerator.generate(seed=SEED, number_of_nodes=number_of_nodes, max_connections_per_node=max_connections_per_node)

# Check that all nodes are reachable from any other node
for node in range(number_of_nodes):
assert len(nx.bfs_tree(problem.source_graph, node).nodes) == number_of_nodes
assert problem.source_graph.number_of_edges() == number_of_nodes * (number_of_nodes - 1) / 2 # Full connectivity

def test_generate_single_disconnected_node(self):
"""Test generate method adds an edge to connect a single disconnected node."""
number_of_nodes = 3
max_connections_per_node = 1 # Force a disconnected node scenario
problem = MaxKColorGenerator.generate(seed=SEED, number_of_nodes=number_of_nodes, max_connections_per_node=max_connections_per_node)

# Check that all nodes are reachable from any other node
for node in range(number_of_nodes):
assert len(nx.bfs_tree(problem.source_graph, node).nodes) == number_of_nodes

def test_generate_unreachable_nodes_coverage(self):
"""Test generate method adds edges to ensure connectivity when some nodes are initially unreachable."""
number_of_nodes = 6
max_connections_per_node = 1 # Low connections to increase likelihood of disconnected components

# Use a seed that leads to unreachable nodes
problem = MaxKColorGenerator.generate(seed=100, number_of_nodes=number_of_nodes, max_connections_per_node=max_connections_per_node)

# Check that all nodes are reachable from any other node
for node in range(number_of_nodes):
assert len(nx.bfs_tree(problem.source_graph, node).nodes) == number_of_nodes
94 changes: 82 additions & 12 deletions tests/test_generators/test_tsp_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,24 @@ class TestTSPGenerator:

def test_generate_invalid_seed(self):
"""Test generate method raises ValueError when seed is not an integer."""
with pytest.raises(ValueError) as excinfo:
with pytest.raises(ValueError, match="Seed must be an integer. Got not_an_int"):
TSPGenerator.generate(5, seed="not_an_int")

assert str(excinfo.value) == "Seed must be an integer. Got not_an_int"

def test_generate_invalid_number_of_cities(self):
"""Test generate method raises ValueError when number_of_cities is not a positive integer."""
with pytest.raises(ValueError) as excinfo:
with pytest.raises(ValueError, match="Number of cities must be a positive integer. Got 0"):
TSPGenerator.generate(0, seed=SEED)

assert str(excinfo.value) == "Number of cities must be a positive integer. Got 0"

def test_generate_invalid_area_width(self):
"""Test generate method raises ValueError when area_width is not a positive integer."""
with pytest.raises(ValueError) as excinfo:
with pytest.raises(ValueError, match="Area width must be a positive integer. Got 0"):
TSPGenerator.generate(5, area_width=0, seed=SEED)

assert str(excinfo.value) == "Area width must be a positive integer. Got 0"

def test_generate_invalid_area_height(self):
"""Test generate method raises ValueError when area_height is not a positive integer."""
with pytest.raises(ValueError) as excinfo:
with pytest.raises(ValueError, match="Area height must be a positive integer. Got 0"):
TSPGenerator.generate(5, area_height=0, seed=SEED)

assert str(excinfo.value) == "Area height must be a positive integer. Got 0"

def test_generate_default_parameters(self):
"""Test generate method with default parameters."""
num_cities = 5
Expand Down Expand Up @@ -82,3 +74,81 @@ def test_generate_graph(self):

assert graph.number_of_nodes() == num_cities
assert graph.number_of_edges() == len(problem.distances)

def test_generate_one_city(self):
"""Test generate method with the minimum number of cities (1)."""
num_cities = 1
problem = TSPGenerator.generate(num_cities, seed=SEED)

assert problem.length == num_cities
assert len(problem.coords) == num_cities
assert len(problem.distances) == 0 # No distances when there's only one city
assert problem.source_graph.number_of_nodes() == num_cities
assert problem.source_graph.number_of_edges() == 0

def test_generate_large_number_of_cities(self):
"""Test generate method with a large number of cities."""
num_cities = 1000
problem = TSPGenerator.generate(num_cities, seed=SEED)

assert problem.length == num_cities
assert len(problem.coords) == num_cities
assert len(problem.distances) > 0
assert problem.source_graph.number_of_nodes() == num_cities
assert problem.source_graph.number_of_edges() == len(problem.distances)

def test_generate_non_square_area(self):
"""Test generate method with different area width and height."""
num_cities = 5
area_width = 300
area_height = 100
problem = TSPGenerator.generate(num_cities, area_width=area_width, area_height=area_height, seed=SEED)

assert problem.length == num_cities
for x, y in problem.coords:
assert 0 <= x < area_width
assert 0 <= y < area_height

def test_generate_same_seed_reproducibility(self):
"""Test generate method produces the same results when called with the same seed."""
num_cities = 5
problem1 = TSPGenerator.generate(num_cities, seed=SEED)
problem2 = TSPGenerator.generate(num_cities, seed=SEED)

assert problem1.coords == problem2.coords
assert problem1.distances == problem2.distances

def test_generate_randomness_different_seeds(self):
"""Test generate method produces different results with different seeds."""
num_cities = 5
problem1 = TSPGenerator.generate(num_cities, seed=45)
problem2 = TSPGenerator.generate(num_cities, seed=56)

assert problem1.coords != problem2.coords
assert problem1.distances != problem2.distances

def test_generate_negative_area_dimensions(self):
"""Test generate method raises ValueError when area dimensions are negative."""
with pytest.raises(ValueError, match="Area width must be a positive integer. Got -100"):
TSPGenerator.generate(5, area_width=-100, seed=SEED)

with pytest.raises(ValueError, match="Area height must be a positive integer. Got -100"):
TSPGenerator.generate(5, area_height=-100, seed=SEED)

def test_get_distances_truncate(self):
"""Test get_distances method truncates distances to integers when truncate=True."""
coords = [(0, 0), (3, 4), (6, 8)] # Known distances: 5.0 and 5.0
expected_distances = [(0, 1, 5), (0, 2, 10), (1, 2, 5)] # Truncated distances

distances = TSPGenerator.get_distances(coords)

assert distances == expected_distances

def test_get_distances_no_truncate(self):
"""Test get_distances method retains full precision when truncate=False."""
coords = [(0, 0), (3, 4), (6, 8)] # Known distances: 5.0 and 10.0
expected_distances = [(0, 1, 5.0), (0, 2, 10.0), (1, 2, 5.0)] # Non-truncated distances

distances = TSPGenerator.get_distances(coords, truncate=False)

assert distances == expected_distances

0 comments on commit 1b2e91e

Please sign in to comment.