Refactor .runners: _NNRunnerBase, _RunnerBase, NNGSRunner, and SKMLPRunner (#11)

knakamura13 committed Sep 5, 2024
1 parent 70a6e42 commit e51ea60
Showing 4 changed files with 526 additions and 211 deletions.
177 changes: 75 additions & 102 deletions src/mlrose_ky/runners/_nn_runner_base.py
@@ -9,7 +9,7 @@
import pickle as pk
import time
from abc import ABC
from typing import Callable
from typing import Callable, Any

import numpy as np
import pandas as pd
@@ -20,8 +20,8 @@

class _NNRunnerBase(_RunnerBase, GridSearchMixin, ABC):
"""
A base class for running neural network experiments with grid search.
It extends functionality from _RunnerBase and GridSearchMixin.
This class provides methods for setting up and executing grid search over neural network hyperparameters,
handling cross-validation, saving results, and managing file operations.
@@ -40,25 +40,25 @@ class _NNRunnerBase(_RunnerBase, GridSearchMixin, ABC):
Name of the experiment.
seed : int
Random seed to ensure reproducibility.
iteration_list : list
iteration_list : list[int]
List of iteration counts to perform.
grid_search_parameters : dict
Hyperparameters for grid search.
cv : int
Number of cross-validation folds.
generate_curves : bool
Whether to generate learning curves during the experiment.
_output_directory : str or None
_output_directory : str | None
Directory where outputs will be saved.
verbose_grid_search : bool
Whether to output detailed grid search information.
override_ctrl_c_handler : bool
Whether to override the default CTRL+C handler.
n_jobs : int
Number of parallel jobs for grid search.
cv_results_df : pd.DataFrame or None
cv_results_df : pd.DataFrame | None
DataFrame to store cross-validation results.
best_params : dict or None
best_params : dict | None
Dictionary storing the best parameters found during grid search.
"""

@@ -72,7 +72,7 @@ def __init__(
y_test: np.ndarray,
experiment_name: str,
seed: int,
iteration_list: list,
iteration_list: list[int],
grid_search_parameters: dict,
grid_search_scorer_method: Callable,
cv: int = 5,
@@ -83,7 +83,7 @@
n_jobs: int = 1,
replay: bool = False,
**kwargs,
):
) -> None:
"""
Initializes the _NNRunnerBase class.
@@ -101,30 +101,29 @@ def __init__(
Name of the experiment.
seed : int
Random seed to ensure reproducibility.
iteration_list : list
iteration_list : list[int]
List of iteration counts to perform.
grid_search_parameters : dict
Hyperparameters for grid search.
grid_search_scorer_method : Callable
Scorer method for evaluating the grid search.
cv : int, optional
Number of cross-validation folds, by default 5.
generate_curves : bool, optional
Whether to generate learning curves, by default True.
output_directory : str or None, optional
Directory where outputs will be saved, by default None.
verbose_grid_search : bool, optional
Whether to output detailed grid search information, by default True.
override_ctrl_c_handler : bool, optional
Whether to override the default CTRL+C handler, by default True.
n_jobs : int, optional
Number of parallel jobs for grid search, by default 1.
replay : bool, optional
Whether to replay previous results, by default False.
cv : int, optional, default=5
Number of cross-validation folds.
generate_curves : bool, optional, default=True
Whether to generate learning curves.
output_directory : str | None, optional, default=None
Directory where outputs will be saved.
verbose_grid_search : bool, optional, default=True
Whether to output detailed grid search information.
override_ctrl_c_handler : bool, optional, default=True
Whether to override the default CTRL+C handler.
n_jobs : int, optional, default=1
Number of parallel jobs for grid search.
replay : bool, optional, default=False
Whether to replay previous results.
**kwargs :
Additional hyperparameters for grid search.
"""
# Initialize the _RunnerBase class with common parameters.
super().__init__(
problem=None,
experiment_name=experiment_name,
@@ -137,22 +136,19 @@ def __init__(
copy_zero_curve_fitness_from_first=True,
)

# Initialize GridSearchMixin with the grid search scorer method.
GridSearchMixin.__init__(self, scorer_method=grid_search_scorer_method)

self.classifier = None

# Build grid search parameters from provided arguments and additional keyword arguments.
self.grid_search_parameters = self.build_grid_search_parameters(grid_search_parameters=grid_search_parameters, **kwargs)
self.x_train = x_train
self.y_train = y_train
self.x_test = x_test
self.y_test = y_test
self.cv = cv
self.n_jobs = n_jobs
self.verbose_grid_search = verbose_grid_search
self.cv_results_df = None
self.best_params = None
self.classifier: Any = None
self.grid_search_parameters: dict = self.build_grid_search_parameters(grid_search_parameters, **kwargs)
self.x_train: np.ndarray = x_train
self.y_train: np.ndarray = y_train
self.x_test: np.ndarray = x_test
self.y_test: np.ndarray = y_test
self.cv: int = cv
self.n_jobs: int = n_jobs
self.verbose_grid_search: bool = verbose_grid_search
self.cv_results_df: pd.DataFrame | None = None
self.best_params: dict | None = None

def dynamic_runner_name(self) -> str:
"""
@@ -165,7 +161,7 @@ def dynamic_runner_name(self) -> str:
"""
return f"{self.__class__.__name__}_{self._experiment_name}"
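For illustration: a subclass named NNGSRunner created with the hypothetical experiment name "example_nn" would yield:

# runner.dynamic_runner_name()  ->  "NNGSRunner_example_nn"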

def run(self):
def run(self) -> tuple[pd.DataFrame | None, pd.DataFrame | None, pd.DataFrame | None, Any | None]:
"""
Executes the runner, performing grid search and handling the results.
@@ -176,14 +172,15 @@
"""
try:
self._setup()

logging.info(f"Running experiment: {self._experiment_name}")

# Replay mode allows reusing previous grid search results, avoiding re-execution
if self.replay_mode():
# Load previous grid search results if in replay mode
gsr_name = f"{super()._get_pickle_filename_root('grid_search_results')}.p"
with open(gsr_name, "rb") as pickle_file:
search_results = pk.load(pickle_file)
else:
# Perform grid search and measure run time
run_start = time.perf_counter()
search_results = self.perform_grid_search(
classifier=self.classifier,
@@ -197,22 +194,22 @@
run_end = time.perf_counter()
logging.info(f"Run time: {run_end - run_start}")

# Updates internal attributes with the best estimator found by grid search
self.__dict__.update(search_results.best_estimator_.runner.__dict__)

self.best_params = search_results.best_params_

# Save cross-validation results to disk
self.cv_results_df = self._make_cv_results_data_frame(search_results.cv_results_)
extra_data_frames = {"cv_results_df": self.cv_results_df}

self._create_and_save_run_data_frames(extra_data_frames=extra_data_frames, final_save=True)

try:
# Save the grid search results to disk
self._dump_pickle_to_disk(search_results, "grid_search_results", final_save=True)
except (OSError, IOError, pk.PickleError):
pass

# Perform predictions and score using the best estimator found in the search
try:
y_pred = search_results.best_estimator_.predict(self.x_test)
score = self.score(y_pred=y_pred, y_true=self.y_test)  # compare against the held-out test labels, not y_train
@@ -222,10 +219,9 @@

return self.run_stats_df, self.curves_df, self.cv_results_df, search_results
except KeyboardInterrupt:
# Handling graceful termination in case of manual interruption
return None, None, None, None
finally:
# Cleanup after the run is complete
self._tear_down()
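A minimal sketch of consuming run()'s four-part return value, assuming the runner constructed in the earlier sketch; the None check mirrors the KeyboardInterrupt path above.

# run() returns (run_stats_df, curves_df, cv_results_df, search_results);
# all four are None if the run was interrupted with CTRL+C.
run_stats_df, curves_df, cv_results_df, search_results = runner.run()
if search_results is not None:
    print(search_results.best_params_)  # best hyperparameters from the grid search
    print(cv_results_df.head())         # sanitized cross-validation results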

def _get_pickle_filename_root(self, name: str) -> str:
@@ -243,11 +239,12 @@ def _get_pickle_filename_root(self, name: str) -> str:
The root filename with a hash appended.
"""
filename_root = super()._get_pickle_filename_root(name)

# Create a unique hash based on argument values (excluding state-related keys)
arg_text = "".join([f"{k}_{self._sanitize_value(v)}_" for k, v in self._current_logged_algorithm_args.items() if "state" not in k])
arg_hash = f"__{hashlib.md5(arg_text.encode()).hexdigest()}".upper() if len(arg_text) > 0 else ""
filename_root += arg_hash

return filename_root
return filename_root + arg_hash
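A toy sketch of the hashing scheme above, with invented argument values and the _sanitize_value call elided for brevity:

import hashlib

# Hypothetical stand-in for self._current_logged_algorithm_args.
logged_args = {"algorithm": "rhc", "max_iters": 1024, "init_state": "..."}

# Keys containing "state" are excluded, matching the method above.
arg_text = "".join(f"{k}_{v}_" for k, v in logged_args.items() if "state" not in k)
arg_hash = f"__{hashlib.md5(arg_text.encode()).hexdigest()}".upper()
print(arg_hash)  # a stable 32-character uppercase MD5 suffix prefixed with "__"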

@staticmethod
def _check_match(df_reference: pd.DataFrame, df_to_check: pd.DataFrame) -> bool:
@@ -267,54 +264,39 @@ def _check_match(df_reference: pd.DataFrame, df_to_check: pd.DataFrame) -> bool:
True if a matching row is found, False otherwise.
"""
cols = [col for col in df_reference.columns]
found = False

for _, row in df_to_check.iterrows():
found = True
for col in cols:
if df_reference[col][0] != row[col]:
found = False
break
if found:
break
# Returns True as soon as a match is found
if all(df_reference[col][0] == row[col] for col in cols):
return True

return found
return False
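A toy illustration of the matching rule, with invented frames:

import pandas as pd

df_reference = pd.DataFrame([{"learning_rate": 0.01, "max_iters": 512}])
df_to_check = pd.DataFrame(
    [{"learning_rate": 0.1, "max_iters": 512}, {"learning_rate": 0.01, "max_iters": 512}]
)
# The second row agrees with the reference row on every reference column,
# so the check succeeds.
assert _NNRunnerBase._check_match(df_reference, df_to_check)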

def _tear_down(self, filename: str | None = None):
def _tear_down(self, filename: str | None = None) -> None:
"""
Finalizes the runner, ensuring that the proper files are saved or cleaned up.
Parameters
----------
filename : str or None, optional
Filename to clean up, by default None.
filename : str | None, optional, default=None
Filename to clean up.
"""
if self.best_params is None or self.replay_mode() or self._output_directory is None:
super()._tear_down()
return

filename_root = super()._get_pickle_filename_root("")
print(f"Filename root: {filename_root}")

path = os.path.join(*filename_root.split(os.sep)[:-1])
filename_part = filename_root.split(os.sep)[-1]
print(f"Path: {path}")
print(f"Filename part: {filename_part}")

if not os.path.isdir(path) and path[0] != os.sep:
path = f"{os.sep}{path}"

# The path has been normalized above; log it before scanning for matching files
print(f"Final path after adjustment: {path}")

filenames = [fn for fn in os.listdir(str(path)) if (filename_part in fn and fn.endswith(".p") and "_df_" in fn)]

print(f"Filenames found: {filenames}")

if not filenames:
raise FileNotFoundError(f"No matching filenames found in path: {path}")

# Create a DataFrame from the best parameters
df_best_params = pd.DataFrame([{k: self._sanitize_value(v) for k, v in self.best_params.items()}])

correct_files = []
@@ -325,32 +307,28 @@ def _tear_down(self, filename: str | None = None):
try:
df = pk.load(pickle_file)
found = self._check_match(df_best_params, df)
if not found:
incorrect_files.append(filename)
else:
correct_files.append(filename)
(correct_files if found else incorrect_files).append(filename)
except (EOFError, pk.PickleError):
pass

# Extract the md5 hashes from the filenames of correct and incorrect files
correct_md5s = list(set([p.split("_")[-1][:-2] for p in correct_files]))
incorrect_md5s = list(set([p.split("_")[-1][:-2] for p in incorrect_files]))

# Remove the suboptimal files based on the incorrect md5 hashes
all_incorrect_files = []
for incorrect_md5 in incorrect_md5s:
all_incorrect_files.extend([os.path.join(str(path), fn) for fn in os.listdir(str(path)) if incorrect_md5 in fn])
# Extracts md5 hashes from correct and incorrect files for renaming or deletion
correct_md5s = {p.split("_")[-1][:-2] for p in correct_files}
incorrect_md5s = {p.split("_")[-1][:-2] for p in incorrect_files}

# Remove files corresponding to incorrect md5 hashes
all_incorrect_files = [
os.path.join(str(path), fn) for incorrect_md5 in incorrect_md5s for fn in os.listdir(str(path)) if incorrect_md5 in fn
]
for _filename in all_incorrect_files:
os.remove(_filename)

# Rename the best files by removing the md5 from the filename
all_correct_files = []
for _correct_md5 in correct_md5s:
all_correct_files.extend(
[(os.path.join(str(path), fn), f"__{_correct_md5}") for fn in os.listdir(str(path)) if _correct_md5 in fn]
)

# Rename correct files by removing md5 hash from the filename
all_correct_files = [
(os.path.join(str(path), fn), f"__{_correct_md5}")
for _correct_md5 in correct_md5s
for fn in os.listdir(str(path))
if _correct_md5 in fn
]
for _filename, _correct_md5 in all_correct_files:
correct_filename = _filename.replace(_correct_md5, "")
if os.path.exists(correct_filename):
@@ -384,7 +362,6 @@ def _make_cv_results_data_frame(cv_results: dict) -> pd.DataFrame:
param_label = p.replace(param_prefix, "")
new_param_values[p].append(_NNRunnerBase._sanitize_value(v[param_label]))

# Replace original parameter values with sanitized values
cv_results.update(new_param_values)
df = pd.DataFrame(cv_results)
df.dropna(inplace=True)
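For orientation, a toy sliver of the scikit-learn cv_results_ mapping this method consumes; the keys follow GridSearchCV's documented layout, but the values are invented.

# Hypothetical GridSearchCV.cv_results_ fragment:
cv_results = {
    "mean_test_score": [0.81, 0.84],
    "param_learning_rate": [0.001, 0.01],  # a masked array in real sklearn output
    "params": [{"learning_rate": 0.001}, {"learning_rate": 0.01}],
}
# Each "param_*" column is rebuilt from the per-candidate dicts in "params",
# passed through _sanitize_value, and the result becomes a DataFrame.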
@@ -408,11 +385,7 @@ def build_grid_search_parameters(grid_search_parameters: dict, **kwargs) -> dict
dict
Combined grid search parameters.
"""
all_grid_search_parameters = {}
all_grid_search_parameters.update(grid_search_parameters)
all_grid_search_parameters.update(**kwargs)

return all_grid_search_parameters
return {**grid_search_parameters, **kwargs}
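The one-line merge gives **kwargs the last word when keys collide; a tiny sketch with invented values:

params = _NNRunnerBase.build_grid_search_parameters(
    {"learning_rate": [0.001, 0.01]}, max_iters=[512, 1024]
)
# -> {"learning_rate": [0.001, 0.01], "max_iters": [512, 1024]}
# A keyword argument repeating a dict key would overwrite it, since **kwargs
# is unpacked last.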

def _grid_search_score_intercept(
self, y_true: np.ndarray, y_pred: np.ndarray, sample_weight: np.ndarray | None = None, adjusted: bool = False
@@ -426,10 +399,10 @@
True labels.
y_pred : np.ndarray
Predicted labels.
sample_weight : np.ndarray or None, optional
Sample weights for scoring, by default None.
adjusted : bool, optional
Whether to adjust the score, by default False.
sample_weight : np.ndarray | None, optional, default=None
Sample weights for scoring.
adjusted : bool, optional, default=False
Whether to adjust the score.
Returns
-------