Skip to content

Commit

Permalink
tests: 100% coverage for .samples
Browse files Browse the repository at this point in the history
  • Loading branch information
knakamura13 committed Sep 15, 2024
1 parent 9e993bb commit 164d466
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 21 deletions.
42 changes: 21 additions & 21 deletions tests/test_runners/test_runner_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

class TestRunnerBase:
@pytest.fixture
def test_runner(self):
def _test_runner_fixture(self):
"""Fixture to create a TestRunner instance for testing."""

# noinspection PyMissingOrEmptyDocstring
Expand All @@ -38,26 +38,26 @@ def _create_runner(**kwargs):

return _create_runner

def test_increment_spawn_count(self, test_runner):
def test_increment_spawn_count(self, _test_runner_fixture):
with patch("os.makedirs"), patch("os.path.exists", return_value=True):
runner = test_runner()
runner = _test_runner_fixture()
initial_count = runner._get_spawn_count()
runner._increment_spawn_count()

assert runner._get_spawn_count() == initial_count + 1

def test_decrement_spawn_count(self, test_runner):
def test_decrement_spawn_count(self, _test_runner_fixture):
with patch("os.makedirs"), patch("os.path.exists", return_value=True):
runner = test_runner()
runner = _test_runner_fixture()
runner._increment_spawn_count()
initial_count = runner._get_spawn_count()
runner._decrement_spawn_count()

assert runner._get_spawn_count() == initial_count - 1

def test_get_spawn_count(self, test_runner):
def test_get_spawn_count(self, _test_runner_fixture):
with patch("os.makedirs"), patch("os.path.exists", return_value=True):
runner = test_runner()
runner = _test_runner_fixture()
initial_spawn_count = runner._get_spawn_count()
runner._increment_spawn_count()
incremented_spawn_count = runner._get_spawn_count()
Expand All @@ -67,23 +67,23 @@ def test_get_spawn_count(self, test_runner):
decremented_spawn_count = runner._get_spawn_count()
assert decremented_spawn_count == initial_spawn_count

def test_abort_sets_abort_flag(self, test_runner):
def test_abort_sets_abort_flag(self, _test_runner_fixture):
with patch("os.makedirs"), patch("os.path.exists", return_value=True):
runner = test_runner()
runner = _test_runner_fixture()
runner.abort()

assert runner.has_aborted() is True

def test_has_aborted_after_abort_called(self, test_runner):
def test_has_aborted_after_abort_called(self, _test_runner_fixture):
with patch("os.makedirs"), patch("os.path.exists", return_value=True):
runner = test_runner(seed=SEED, iteration_list=[0])
runner = _test_runner_fixture(seed=SEED, iteration_list=[0])
runner.abort()

assert runner.has_aborted() is True

def test_set_replay_mode(self, test_runner):
def test_set_replay_mode(self, _test_runner_fixture):
with patch("os.makedirs"), patch("os.path.exists", return_value=True):
runner = test_runner()
runner = _test_runner_fixture()
assert not runner.replay_mode()

runner.set_replay_mode()
Expand All @@ -92,17 +92,17 @@ def test_set_replay_mode(self, test_runner):
runner.set_replay_mode(False)
assert not runner.replay_mode()

def test_replay_mode(self, test_runner):
def test_replay_mode(self, _test_runner_fixture):
with patch("os.makedirs"), patch("os.path.exists", return_value=True):
runner = test_runner(replay=True)
runner = _test_runner_fixture(replay=True)
assert runner.replay_mode() is True

runner.set_replay_mode(False)
assert runner.replay_mode() is False

def test_setup_method(self, test_runner):
def test_setup_method(self, _test_runner_fixture):
with patch("os.makedirs") as mock_makedirs, patch("os.path.exists", return_value=False):
runner = test_runner(problem="dummy_problem", seed=SEED, iteration_list=[0, 1, 2], output_directory="test_output")
runner = _test_runner_fixture(problem="dummy_problem", seed=SEED, iteration_list=[0, 1, 2], output_directory="test_output")
runner._setup()

assert runner._raw_run_stats == []
Expand All @@ -113,18 +113,18 @@ def test_setup_method(self, test_runner):
assert runner._current_logged_algorithm_args == {}
mock_makedirs.assert_called_once_with("test_output")

def test_tear_down_restores_original_sigint_handler(self, test_runner):
def test_tear_down_restores_original_sigint_handler(self, _test_runner_fixture):
with patch("os.makedirs"), patch("os.path.exists", return_value=True):
original_handler = signal.getsignal(signal.SIGINT)
runner = test_runner()
runner = _test_runner_fixture()
runner._tear_down()
restored_handler = signal.getsignal(signal.SIGINT)

assert restored_handler == original_handler

def test_log_current_argument(self, test_runner):
def test_log_current_argument(self, _test_runner_fixture):
with patch("os.makedirs"), patch("os.path.exists", return_value=True):
runner = test_runner(seed=SEED, iteration_list=[0, 1, 2])
runner = _test_runner_fixture(seed=SEED, iteration_list=[0, 1, 2])
arg_name = "test_arg"
arg_value = "test_value"
runner._log_current_argument(arg_name, arg_value)
Expand Down
95 changes: 95 additions & 0 deletions tests/test_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import matplotlib.pyplot as plt
import pytest
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier

from tests.globals import SEED
from mlrose_ky.samples import SyntheticData, plot_synthetic_dataset
Expand Down Expand Up @@ -63,3 +64,97 @@ def test_plot_synthetic_dataset(self, mock_show, generator):
plt.close()

assert mock_show.called

@patch("matplotlib.pyplot.show")
def test_plot_synthetic_dataset_with_predict_proba(self, mock_show, generator):
    """Test plotting synthetic dataset with a classifier that uses predict_proba.

    RandomForestClassifier exposes predict_proba, so this exercises the
    classifier-overlay path of plot_synthetic_dataset with a probabilistic model.
    The patched plt.show() confirms the plot was actually rendered.
    """
    data, _, _, _ = generator.get_synthetic_data()
    # Only the train/test splits are used for plotting; discard the unused x/y.
    _, _, x_train, x_test, y_train, y_test = generator.setup_synthetic_data_test_train(data)
    classifier = RandomForestClassifier(random_state=generator.seed).fit(x_train, y_train)
    plt.figure()
    plot_synthetic_dataset(x_train, x_test, y_train, y_test, classifier=classifier)
    plt.close()

    assert mock_show.called

def test_get_synthetic_data_with_noise(self, generator):
    """Test that generating synthetic data with add_noise enlarges the dataset."""
    df, feature_names, class_names, out_dir = generator.get_synthetic_data(add_noise=0.1)

    assert out_dir is None
    assert class_names == ["RED", "BLUE"]
    assert feature_names == ["(1) A", "(2) B"]
    # Noise injection appends extra rows beyond the baseline 400.
    assert df.shape[0] > 400

def test_get_synthetic_data_with_redundant_column(self, generator):
    """Test that requesting a redundant column adds the "(3) R" feature."""
    df, feature_names, class_names, out_dir = generator.get_synthetic_data(add_redundant_column=True)

    assert out_dir is None
    assert class_names == ["RED", "BLUE"]
    assert feature_names == ["(1) A", "(2) B", "(3) R"]
    # One extra column compared to the default (400, 3) layout.
    assert df.shape == (400, 4)

@patch("mlrose_ky.samples.synthetic_data.makedirs")
def test_get_synthetic_data_with_root_directory(self, mock_makedirs, generator):
    """Test that setting a root directory produces an output directory and creates it."""
    generator.root_directory = "/tmp/synthetic_data"

    df, feature_names, class_names, out_dir = generator.get_synthetic_data()

    # An output directory is derived from the root and created exactly once.
    assert out_dir is not None
    mock_makedirs.assert_called_once_with(out_dir)
    assert df.shape == (400, 3)
    assert feature_names == ["(1) A", "(2) B"]
    assert class_names == ["RED", "BLUE"]

@patch("mlrose_ky.samples.synthetic_data.makedirs")
def test_get_synthetic_data_with_root_directory_oserror(self, mock_makedirs, generator):
    """Test that data generation still succeeds when makedirs raises OSError."""
    # Simulate a filesystem failure during output-directory creation.
    mock_makedirs.side_effect = OSError("Test OSError")
    generator.root_directory = "/tmp/synthetic_data"

    df, feature_names, class_names, out_dir = generator.get_synthetic_data()

    # The OSError is tolerated: data and metadata are returned as usual,
    # and the directory creation was still attempted exactly once.
    assert df.shape == (400, 3)
    assert feature_names == ["(1) A", "(2) B"]
    assert class_names == ["RED", "BLUE"]
    assert out_dir is not None
    mock_makedirs.assert_called_once_with(out_dir)

@patch("matplotlib.pyplot.show")
def test_plot_synthetic_dataset_without_classifier(self, mock_show, generator):
    """Test plotting synthetic dataset without a classifier.

    Covers the no-classifier path of plot_synthetic_dataset; the patched
    plt.show() confirms the plot was rendered.
    """
    data, _, _, _ = generator.get_synthetic_data()
    # Only the train/test splits are used for plotting; discard the unused x/y.
    _, _, x_train, x_test, y_train, y_test = generator.setup_synthetic_data_test_train(data)
    plt.figure()
    plot_synthetic_dataset(x_train, x_test, y_train, y_test)
    plt.close()

    assert mock_show.called

@patch("matplotlib.pyplot.show")
def test_plot_synthetic_dataset_with_redundant_column(self, mock_show, generator):
    """Test plotting synthetic dataset with a redundant column (three features).

    Exercises plot_synthetic_dataset with three-feature input plus a fitted
    classifier; the patched plt.show() confirms the plot was rendered.
    """
    data, _, _, _ = generator.get_synthetic_data(add_redundant_column=True)
    # Only the train/test splits are used for plotting; discard the unused x/y.
    _, _, x_train, x_test, y_train, y_test = generator.setup_synthetic_data_test_train(data)
    classifier = LogisticRegression().fit(x_train, y_train)
    plt.figure()
    plot_synthetic_dataset(x_train, x_test, y_train, y_test, classifier=classifier)
    plt.close()

    assert mock_show.called

@patch("matplotlib.pyplot.show")
def test_plot_synthetic_dataset_with_transparent_bg(self, mock_show, generator):
    """Test plotting synthetic dataset with transparent background and custom background color.

    Passes transparent_bg=True and bg_color="black" through to
    plot_synthetic_dataset; the patched plt.show() confirms the plot was rendered.
    """
    data, _, _, _ = generator.get_synthetic_data()
    # Only the train/test splits are used for plotting; discard the unused x/y.
    _, _, x_train, x_test, y_train, y_test = generator.setup_synthetic_data_test_train(data)
    classifier = LogisticRegression().fit(x_train, y_train)
    plt.figure()
    plot_synthetic_dataset(x_train, x_test, y_train, y_test, classifier=classifier, transparent_bg=True, bg_color="black")
    plt.close()

    assert mock_show.called

0 comments on commit 164d466

Please sign in to comment.