Skip to content

Commit

Permalink
fixed multiprocessing for Windows; fixed rsteca#44; incorporated changes from pull request rsteca#45
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander McFarlane committed Jun 8, 2018
1 parent 6468860 commit 290bab7
Showing 1 changed file with 22 additions and 16 deletions.
38 changes: 22 additions & 16 deletions evolutionary_search/cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sklearn.metrics.scorer import check_scoring
from sklearn.utils.validation import _num_samples, indexable


def enum(**enums):
    """Build an ad-hoc enumeration class.

    Each keyword argument becomes a class attribute on the returned
    class named 'Enum', e.g. ``enum(A=1, B=2).A == 1``.
    """
    members = dict(enums)
    return type('Enum', (), members)

Expand Down Expand Up @@ -278,6 +279,8 @@ class EvolutionaryAlgorithmSearchCV(BaseSearchCV):
With the statistics of the evolution.
"""
best_score_ = None
best_params_ = None
def __init__(self, estimator, params, scoring=None, cv=4,
refit=True, verbose=False, population_size=50,
gene_mutation_prob=0.1, gene_crossover_prob=0.5,
Expand All @@ -298,12 +301,12 @@ def __init__(self, estimator, params, scoring=None, cv=4,
self.gene_type = gene_type
self.all_history_, self.all_logbooks_ = [], []
self._cv_results = None
self.best_score_ = None
self.best_params_ = None
self.score_cache = {}
self.n_jobs = n_jobs
creator.create("FitnessMax", base.Fitness, weights=(1.0,))
creator.create("Individual", list, est=clone(self.estimator), fitness=creator.FitnessMax)
if "FitnessMax" not in creator.__dict__:
creator.create("FitnessMax", base.Fitness, weights=(1.0,))
if "Individual" not in creator.__dict__:
creator.create("Individual", list, est=clone(self.estimator), fitness=creator.FitnessMax)

@property
def possible_params(self):
Expand Down Expand Up @@ -358,9 +361,11 @@ def fit(self, X, y=None):
self.best_estimator_ = clone(self.estimator)
self.best_estimator_.set_params(**self.best_mem_params_)
if self.fit_params is not None:

self.best_estimator_.fit(X, y, **self.fit_params)
else:
self.best_estimator_.fit(X, y)
return self

def _fit(self, X, y, parameter_dict):
self._cv_results = None # To indicate to the property the need to update
Expand Down Expand Up @@ -394,16 +399,16 @@ def _fit(self, X, y, parameter_dict):
# wrapper so that pools are not recursively created when the module is reloaded in each map
if isinstance(self.n_jobs, int):
if self.n_jobs > 1 or self.n_jobs < 0:
from multiprocessing import Pool # Only imports if needed
if os.name == 'nt': # Checks if we are on Windows
warnings.warn(("Windows requires Pools to be declared from within "
"an \'if __name__==\"__main__\":\' structure. In this "
"case, n_jobs will accept map functions as well to "
"facilitate custom parallelism. Please check to see "
"that all code is working as expected."))
pool = Pool(self.n_jobs)
toolbox.register("map", pool.map)

if __name__ == '__main__':
from multiprocessing import Pool # Only imports if needed
if os.name == 'nt': # Checks if we are on Windows
warnings.warn(("Windows requires Pools to be declared from within "
"an \'if __name__==\"__main__\":\' structure. In this "
"case, n_jobs will accept map functions as well to "
"facilitate custom parallelism. Please check to see "
"that all code is working as expected."))
pool = Pool(self.n_jobs)
toolbox.register("map", pool.map)
# If it's not an int, we are going to pass it as the map directly
else:
try:
Expand Down Expand Up @@ -463,8 +468,9 @@ def _fit(self, X, y, parameter_dict):

# Close your pools if you made them
if isinstance(self.n_jobs, int) and (self.n_jobs > 1 or self.n_jobs < 0):
pool.close()
pool.join()
if __name__ == '__main__':
pool.close()
pool.join()

self.best_score_ = current_best_score_
self.best_params_ = current_best_params_

0 comments on commit 290bab7

Please sign in to comment.