Skip to content

Commit

Permalink
tests: 100% coverage for _nn_runner_base.py
Browse files Browse the repository at this point in the history
  • Loading branch information
knakamura13 committed Oct 8, 2024
1 parent d48eb43 commit 5fb8e4c
Showing 1 changed file with 75 additions and 0 deletions.
75 changes: 75 additions & 0 deletions tests/test_runners/test_nn_runner_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,3 +728,78 @@ def test_nn_runner_base_grid_search_score_intercept(self):
score = runner._grid_search_score_intercept(y_true, y_pred)

assert np.isnan(score)

def test_nn_runner_base_teardown_rename_with_existing_file(self):
"""Test _NNRunnerBase _tear_down method when correct_filename already exists"""
with patch("os.makedirs"):
runner = _NNRunnerBase(
x_train=np.random.rand(100, 10),
y_train=np.random.randint(2, size=100),
x_test=np.random.rand(20, 10),
y_test=np.random.randint(2, size=20),
experiment_name="test_experiment",
seed=SEED,
iteration_list=[1, 2, 3],
grid_search_parameters={"param1": [0.1, 0.2], "param2": [1, 2]},
grid_search_scorer_method=skmt.accuracy_score,
output_directory="test_output",
)

runner.runner_name = MagicMock(return_value="TestRunner")
runner.best_params = {"param1": 0.1, "param2": 1}
runner.replay_mode = MagicMock(return_value=False)
runner._check_match = MagicMock(return_value=True) # All files are correct

# Mock the necessary functions and data
filename_root = "test_output/test_experiment/testrunner__test_experiment"
path = "test_output/test_experiment"
filename_part = "testrunner__test_experiment"

# Prepare filenames with correct md5 hash
correct_md5 = "ABCDEF123456"
correct_filename_with_md5 = f"{filename_part}_df_{correct_md5}.p"
correct_filename = correct_filename_with_md5.replace(f"__{correct_md5}", "")
filenames = [correct_filename_with_md5]

# Define correct dataframe
correct_df = pd.DataFrame([{"param1": "0.1", "param2": "1"}])

# Helper function to return different data based on filename
def mock_pickle_load(file):
if correct_md5 in file.name:
return correct_df
return pd.DataFrame()

# Mock open to set the filename attribute
mock_file_correct = mock_open(read_data=pk.dumps(correct_df)).return_value
mock_file_correct.name = os.path.join(path, correct_filename_with_md5)

# Side effect for open to return different mock files
def open_side_effect(file, mode="rb"):
if correct_md5 in file:
return mock_file_correct
return mock_open().return_value

with (
patch("mlrose_ky.runners._runner_base._RunnerBase._get_pickle_filename_root", return_value=filename_root),
patch("os.path.isdir", return_value=True),
patch("os.listdir", return_value=filenames),
patch("os.rename") as mock_rename,
patch("os.path.exists", return_value=True), # Mock os.path.exists to return True
patch("builtins.open", side_effect=open_side_effect),
patch("pickle.load", side_effect=mock_pickle_load),
patch.object(runner, "_check_match", return_value=True),
):

runner._tear_down()

# Check that os.rename was called twice:
# First to backup existing file, then to rename the new correct file
correct_file_path = os.path.join(path, correct_filename)
correct_file_with_md5_path = os.path.join(path, correct_filename_with_md5)
backup_file_path = f"{correct_file_path}.bak"

expected_calls = [call(correct_file_path, backup_file_path), call(correct_file_with_md5_path, correct_file_path)]

mock_rename.assert_has_calls(expected_calls)
assert mock_rename.call_count == 2

0 comments on commit 5fb8e4c

Please sign in to comment.