diff --git a/pypesto/optimize/ess/ess.py b/pypesto/optimize/ess/ess.py
index 2a86e18d9..74d88f641 100644
--- a/pypesto/optimize/ess/ess.py
+++ b/pypesto/optimize/ess/ess.py
@@ -37,6 +37,8 @@ class ESSExitFlag(int, enum.Enum):
     MAX_EVAL = -2
     # Exited after exhausting wall-time budget
     MAX_TIME = -3
+    # Termination for a reason other than the exit criteria above
+    ERROR = -99
 
 
 class OptimizerFactory(Protocol):
diff --git a/pypesto/optimize/ess/sacess.py b/pypesto/optimize/ess/sacess.py
index 113310f25..a973c51a0 100644
--- a/pypesto/optimize/ess/sacess.py
+++ b/pypesto/optimize/ess/sacess.py
@@ -7,6 +7,7 @@
 import multiprocessing
 import os
 import time
+from contextlib import suppress
 from dataclasses import dataclass
 from math import ceil, sqrt
 from multiprocessing import get_context
@@ -20,6 +21,7 @@
 
 import pypesto
 
+from ... import MemoryHistory
 from ...startpoint import StartpointMethod
 from ...store.read_from_hdf5 import read_result
 from ...store.save_to_hdf5 import write_result
@@ -331,12 +333,18 @@ def minimize(
         n_eval_total = sum(
             worker_result.n_eval for worker_result in self.worker_results
         )
-        logger.info(
-            f"{self.__class__.__name__} stopped after {walltime:3g}s "
-            f"and {n_eval_total} objective evaluations "
-            f"with global best {result.optimize_result[0].fval}."
-        )
-
+        if len(result.optimize_result):
+            logger.info(
+                f"{self.__class__.__name__} stopped after {walltime:3g}s "
+                f"and {n_eval_total} objective evaluations "
+                f"with global best {result.optimize_result[0].fval}."
+            )
+        else:
+            logger.error(
+                f"{self.__class__.__name__} stopped after {walltime:3g}s "
+                f"and {n_eval_total} objective evaluations without producing "
+                "a result."
+            )
         return result
 
     def _create_result(self, problem: Problem) -> pypesto.Result:
@@ -345,25 +353,40 @@ def _create_result(self, problem: Problem) -> pypesto.Result:
         """Create result object.
 
         Creates an overall Result object from the results saved by the workers.
         """
         # gather results from workers and delete temporary result files
-        result = None
+        result = pypesto.Result()
+        retry_after_sleep = True
         for worker_idx in range(self.num_workers):
             tmp_result_filename = SacessWorker.get_temp_result_filename(
                 worker_idx, self._tmpdir
             )
+            tmp_result = None
             try:
                 tmp_result = read_result(
                     tmp_result_filename, problem=False, optimize=True
                 )
             except FileNotFoundError:
                 # wait and retry, maybe the file wasn't found due to some filesystem latency issues
-                time.sleep(5)
-                tmp_result = read_result(
-                    tmp_result_filename, problem=False, optimize=True
-                )
+                if retry_after_sleep:
+                    time.sleep(5)
+                    # waiting once is enough - don't wait for every worker
+                    retry_after_sleep = False
+
+                    try:
+                        tmp_result = read_result(
+                            tmp_result_filename, problem=False, optimize=True
+                        )
+                    except FileNotFoundError:
+                        logger.error(
+                            f"Worker {worker_idx} did not produce a result."
+                        )
+                        continue
+                else:
+                    logger.error(
+                        f"Worker {worker_idx} did not produce a result."
+                    )
+                    continue
 
-            if result is None:
-                result = tmp_result
-            else:
+            if tmp_result:
                 result.optimize_result.append(
                     tmp_result.optimize_result,
                     sort=False,
@@ -375,7 +398,8 @@ def _create_result(self, problem: Problem) -> pypesto.Result:
             filename = SacessWorker.get_temp_result_filename(
                 worker_idx, self._tmpdir
             )
-            os.remove(filename)
+            with suppress(FileNotFoundError):
+                os.remove(filename)
         # delete tmpdir if empty
         try:
             self._tmpdir.rmdir()
@@ -397,6 +421,7 @@ class SacessManager:
 
     Attributes
    ----------
+    _dim: Dimension of the optimization problem
     _num_workers: Number of workers
     _ess_options: ESS options for each worker
     _best_known_fx: Best objective value encountered so far
@@ -410,6 +435,7 @@ class SacessManager:
     _rejection_threshold: Threshold for relative objective improvements
         that incoming solutions have to pass to be accepted
     _lock: Lock for accessing shared state.
+    _terminate: Flag to signal termination of the SACESS run to workers
     _logger: A logger instance
     _options: Further optimizer hyperparameters.
     """
@@ -421,6 +447,7 @@ def __init__(
         dim: int,
         options: SacessOptions = None,
     ):
+        self._dim = dim
         self._options = options or SacessOptions()
         self._num_workers = len(ess_options)
         self._ess_options = [shmem_manager.dict(o) for o in ess_options]
@@ -440,6 +467,7 @@ def __init__(
         self._worker_scores = shmem_manager.Array(
             "d", range(self._num_workers)
         )
+        self._terminate = shmem_manager.Value("b", False)
         self._worker_comms = shmem_manager.Array("i", [0] * self._num_workers)
         self._lock = shmem_manager.RLock()
         self._logger = logging.getLogger()
@@ -550,6 +578,16 @@ def submit_solution(
                 )
                 self._rejections.value = 0
 
+    def abort(self):
+        """Abort the SACESS run."""
+        with self._lock:
+            self._terminate.value = True
+
+    def aborted(self) -> bool:
+        """Whether this run was aborted."""
+        with self._lock:
+            return self._terminate.value
+
 
 class SacessWorker:
     """A SACESS worker.
@@ -667,19 +705,42 @@ def run(
                 f"(best: {self._best_known_fx}, "
                 f"n_eval: {ess.evaluator.n_eval})."
             )
 
-        ess.history.finalize(exitflag=ess.exit_flag.name)
-        worker_result = SacessWorkerResult(
-            x=ess.x_best,
-            fx=ess.fx_best,
-            history=ess.history,
-            n_eval=ess.evaluator.n_eval,
-            n_iter=ess.n_iter,
-            exit_flag=ess.exit_flag,
-        )
+        self._finalize(ess)
+
+    def _finalize(self, ess: ESSOptimizer = None):
+        """Finalize the worker."""
+        # Whatever happens here, we need to put something on the queue before
+        # returning to avoid deadlocks.
+        worker_result = None
+        if ess is not None:
+            try:
+                ess.history.finalize(exitflag=ess.exit_flag.name)
+                ess._report_final()
+                worker_result = SacessWorkerResult(
+                    x=ess.x_best,
+                    fx=ess.fx_best,
+                    history=ess.history,
+                    n_eval=ess.evaluator.n_eval,
+                    n_iter=ess.n_iter,
+                    exit_flag=ess.exit_flag,
+                )
+            except Exception as e:
+                self._logger.exception(
+                    f"Worker {self._worker_idx} failed to finalize: {e}"
+                )
+        if worker_result is None:
+            # Create some dummy result
+            worker_result = SacessWorkerResult(
+                x=np.full(self._manager._dim, np.nan),
+                fx=np.nan,
+                history=MemoryHistory(),
+                n_eval=0,
+                n_iter=0,
+                exit_flag=ESSExitFlag.ERROR,
+            )
         self._manager._result_queue.put(worker_result)
+        self._logger.debug(f"Final configuration: {self._ess_kwargs}")
-        ess._report_final()
 
     def _setup_ess(self, startpoint_method: StartpointMethod) -> ESSOptimizer:
         """Run ESS."""
@@ -835,9 +896,22 @@ def _keep_going(self):
                 f"Max walltime ({self._max_walltime_s}s) exceeded."
             )
             return False
-
+        # other reason for termination (some worker failed, ...)
+        if self._manager.aborted():
+            self.exit_flag = ESSExitFlag.ERROR
+            self._logger.debug("Manager requested termination.")
+            return False
         return True
 
+    def abort(self):
+        """Send signal to abort."""
+        self._logger.error(f"Worker {self._worker_idx} aborting.")
+        # signal to manager
+        self._manager.abort()
+
+        self.exit_flag = ESSExitFlag.ERROR
+        self._finalize(None)
+
     @staticmethod
     def get_temp_result_filename(worker_idx: int, tmpdir: str | Path) -> str:
         return str(Path(tmpdir, f"sacess-{worker_idx:02d}_tmp.h5").absolute())
@@ -853,15 +927,24 @@ def _run_worker(
 
     Helper function as entrypoint for sacess worker processes.
     """
-    # different random seeds per process
-    np.random.seed((os.getpid() * int(time.time() * 1000)) % 2**32)
-
-    # Forward log messages to the logging process
-    h = logging.handlers.QueueHandler(log_process_queue)
-    worker._logger = logging.getLogger(multiprocessing.current_process().name)
-    worker._logger.addHandler(h)
+    try:
+        # different random seeds per process
+        np.random.seed((os.getpid() * int(time.time() * 1000)) % 2**32)
+
+        # Forward log messages to the logging process
+        h = logging.handlers.QueueHandler(log_process_queue)
+        worker._logger = logging.getLogger(
+            multiprocessing.current_process().name
+        )
+        worker._logger.addHandler(h)
 
-    return worker.run(problem=problem, startpoint_method=startpoint_method)
+        return worker.run(problem=problem, startpoint_method=startpoint_method)
+    except Exception as e:
+        with suppress(Exception):
+            worker._logger.exception(
+                f"Worker {worker._worker_idx} failed: {e}"
+            )
+        worker.abort()
 
 
 def get_default_ess_options(
diff --git a/test/optimize/test_optimize.py b/test/optimize/test_optimize.py
index 48ebdea55..3e6a5b4ed 100644
--- a/test/optimize/test_optimize.py
+++ b/test/optimize/test_optimize.py
@@ -18,6 +18,7 @@
 
 import pypesto
 import pypesto.optimize as optimize
+from pypesto import Objective
 from pypesto.optimize.ess import (
     ESSOptimizer,
     FunctionEvaluatorMP,
@@ -577,6 +578,40 @@ def test_ess_refset_repr():
     )
 
 
+class FunctionOrError:
+    """Callable that raises an error every nth invocation."""
+
+    def __init__(self, fun, error_frequency=100):
+        self.counter = 0
+        self.error_frequency = error_frequency
+        self.fun = fun
+
+    def __call__(self, *args, **kwargs):
+        self.counter += 1
+        if self.counter % self.error_frequency == 0:
+            raise RuntimeError("Intentional error.")
+        return self.fun(*args, **kwargs)
+
+
+def test_sacess_worker_error(capsys):
+    """Check that SacessOptimizer does not hang if an error occurs on a worker."""
+    objective = Objective(
+        fun=FunctionOrError(sp.optimize.rosen), grad=sp.optimize.rosen_der
+    )
+    problem = pypesto.Problem(
+        objective=objective, lb=0 * np.ones((1, 2)), ub=1 * np.ones((1, 2))
+    )
+    sacess = SacessOptimizer(
+        num_workers=2,
+        max_walltime_s=2,
+        sacess_loglevel=logging.DEBUG,
+        ess_loglevel=logging.DEBUG,
+    )
+    res = sacess.minimize(problem)
+    assert isinstance(res, pypesto.Result)
+    assert "Intentional error." in capsys.readouterr().err
+
+
 def test_scipy_integrated_grad():
     integrated = True
     obj = rosen_for_sensi(max_sensi_order=2, integrated=integrated)["obj"]
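
Illustration only (not part of the patch): a minimal sketch of how a caller could detect the new failure mode after this change. It assumes `ESSExitFlag` is importable from `pypesto.optimize.ess` (as `SacessOptimizer` is in the test above); `worker_results` and `SacessWorkerResult.exit_flag` are taken from the patched code.

    import logging
    import numpy as np
    from scipy.optimize import rosen, rosen_der

    import pypesto
    from pypesto.optimize.ess import ESSExitFlag, SacessOptimizer

    problem = pypesto.Problem(
        objective=pypesto.Objective(fun=rosen, grad=rosen_der),
        lb=np.zeros(2),
        ub=np.ones(2),
    )
    sacess = SacessOptimizer(num_workers=2, max_walltime_s=2)
    result = sacess.minimize(problem)
    # With this patch, a failing worker reports ESSExitFlag.ERROR (with a NaN
    # dummy result) instead of hanging the whole run.
    if any(wr.exit_flag == ESSExitFlag.ERROR for wr in sacess.worker_results):
        logging.warning("At least one SACESS worker terminated with an error.")
    if len(result.optimize_result):
        print("global best:", result.optimize_result[0].fval)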