Improve exception-handling in SacessOptimizer
Previously, uncaught exceptions in `SacessOptimizer` worker processes resulted in deadlocks in `SacessOptimizer.minimize`.

These changes prevent, in most cases, deadlocks and other errors caused by missing results from individual workers.
Furthermore, `minimize()` now terminates relatively soon after an error occurs on a worker, rather than only after some other exit criterion is met.

Closes #1512.
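
From a user's perspective, `SacessOptimizer.minimize()` now returns a `pypesto.Result` shortly after a worker error instead of blocking indefinitely. The following rough sketch, modeled on the test added in this commit, illustrates that behavior; the import paths and the `worker_results` / `exit_flag` attributes follow the diff below but are otherwise assumptions, and the intentionally failing objective exists only for demonstration:

```python
import logging

import numpy as np
import scipy as sp

import pypesto
from pypesto.optimize.ess import ESSExitFlag, SacessOptimizer


class SometimesFailingRosen:
    """Rosenbrock function that raises on every 100th call (illustration only)."""

    def __init__(self, error_frequency: int = 100):
        self.counter = 0
        self.error_frequency = error_frequency

    def __call__(self, x):
        self.counter += 1
        if self.counter % self.error_frequency == 0:
            raise RuntimeError("Intentional error.")
        return sp.optimize.rosen(x)


if __name__ == "__main__":
    problem = pypesto.Problem(
        objective=pypesto.Objective(
            fun=SometimesFailingRosen(), grad=sp.optimize.rosen_der
        ),
        lb=0 * np.ones((1, 2)),
        ub=1 * np.ones((1, 2)),
    )
    sacess = SacessOptimizer(
        num_workers=2, max_walltime_s=2, sacess_loglevel=logging.DEBUG
    )

    # With this commit, minimize() returns shortly after a worker error
    # instead of blocking until some other exit criterion is met.
    result = sacess.minimize(problem)

    # Per-worker outcomes, including the new ERROR exit flag:
    if any(wr.exit_flag == ESSExitFlag.ERROR for wr in sacess.worker_results):
        print("At least one worker terminated due to an error.")
```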
dweindl committed Nov 25, 2024
1 parent 4611586 commit 4b79df7
Showing 3 changed files with 155 additions and 35 deletions.
2 changes: 2 additions & 0 deletions pypesto/optimize/ess/ess.py
@@ -37,6 +37,8 @@ class ESSExitFlag(int, enum.Enum):
MAX_EVAL = -2
# Exited after exhausting wall-time budget
MAX_TIME = -3
# Termination for a reason other than the exit criteria above
ERROR = -99


class OptimizerFactory(Protocol):
153 changes: 118 additions & 35 deletions pypesto/optimize/ess/sacess.py
@@ -7,6 +7,7 @@
import multiprocessing
import os
import time
from contextlib import suppress
from dataclasses import dataclass
from math import ceil, sqrt
from multiprocessing import get_context
@@ -20,6 +21,7 @@

import pypesto

from ... import MemoryHistory
from ...startpoint import StartpointMethod
from ...store.read_from_hdf5 import read_result
from ...store.save_to_hdf5 import write_result
@@ -331,12 +333,18 @@ def minimize(
n_eval_total = sum(
worker_result.n_eval for worker_result in self.worker_results
)
logger.info(
f"{self.__class__.__name__} stopped after {walltime:3g}s "
f"and {n_eval_total} objective evaluations "
f"with global best {result.optimize_result[0].fval}."
)

if len(result.optimize_result):
logger.info(
f"{self.__class__.__name__} stopped after {walltime:3g}s "
f"and {n_eval_total} objective evaluations "
f"with global best {result.optimize_result[0].fval}."
)
else:
logger.error(
f"{self.__class__.__name__} stopped after {walltime:3g}s "
f"and {n_eval_total} objective evaluations without producing "
"a result."
)
return result

def _create_result(self, problem: Problem) -> pypesto.Result:
@@ -345,25 +353,40 @@ def _create_result(self, problem: Problem) -> pypesto.Result:
Creates an overall Result object from the results saved by the workers.
"""
# gather results from workers and delete temporary result files
result = None
result = pypesto.Result()
retry_after_sleep = True
for worker_idx in range(self.num_workers):
tmp_result_filename = SacessWorker.get_temp_result_filename(
worker_idx, self._tmpdir
)
tmp_result = None
try:
tmp_result = read_result(
tmp_result_filename, problem=False, optimize=True
)
except FileNotFoundError:
# wait and retry, maybe the file wasn't found due to some filesystem latency issues
time.sleep(5)
tmp_result = read_result(
tmp_result_filename, problem=False, optimize=True
)
if retry_after_sleep:
time.sleep(5)
# waiting once is enough - don't wait for every worker
retry_after_sleep = False

try:
tmp_result = read_result(
tmp_result_filename, problem=False, optimize=True
)
except FileNotFoundError:
logger.error(
f"Worker {worker_idx} did not produce a result."
)
continue
else:
logger.error(
f"Worker {worker_idx} did not produce a result."
)
continue

if result is None:
result = tmp_result
else:
if tmp_result:
result.optimize_result.append(
tmp_result.optimize_result,
sort=False,
@@ -375,7 +398,8 @@ def _create_result(self, problem: Problem) -> pypesto.Result:
filename = SacessWorker.get_temp_result_filename(
worker_idx, self._tmpdir
)
os.remove(filename)
with suppress(FileNotFoundError):
os.remove(filename)
# delete tmpdir if empty
try:
self._tmpdir.rmdir()
@@ -397,6 +421,7 @@ class SacessManager:
Attributes
----------
_dim: Dimension of the optimization problem
_num_workers: Number of workers
_ess_options: ESS options for each worker
_best_known_fx: Best objective value encountered so far
@@ -410,6 +435,7 @@ class SacessManager:
_rejection_threshold: Threshold for relative objective improvements that
incoming solutions have to pass to be accepted
_lock: Lock for accessing shared state.
_terminate: Flag to signal termination of the SACESS run to workers
_logger: A logger instance
_options: Further optimizer hyperparameters.
"""
@@ -421,6 +447,7 @@ def __init__(
dim: int,
options: SacessOptions = None,
):
self._dim = dim
self._options = options or SacessOptions()
self._num_workers = len(ess_options)
self._ess_options = [shmem_manager.dict(o) for o in ess_options]
@@ -440,6 +467,7 @@ def __init__(
self._worker_scores = shmem_manager.Array(
"d", range(self._num_workers)
)
self._terminate = shmem_manager.Value("b", False)
self._worker_comms = shmem_manager.Array("i", [0] * self._num_workers)
self._lock = shmem_manager.RLock()
self._logger = logging.getLogger()
@@ -550,6 +578,16 @@ def submit_solution(
)
self._rejections.value = 0

def abort(self):
"""Abort the SACESS run."""
with self._lock:
self._terminate.value = True

def aborted(self) -> bool:
"""Whether this run was aborted."""
with self._lock:
return self._terminate.value


class SacessWorker:
"""A SACESS worker.
@@ -667,19 +705,42 @@ def run(
f"(best: {self._best_known_fx}, "
f"n_eval: {ess.evaluator.n_eval})."
)

ess.history.finalize(exitflag=ess.exit_flag.name)
worker_result = SacessWorkerResult(
x=ess.x_best,
fx=ess.fx_best,
history=ess.history,
n_eval=ess.evaluator.n_eval,
n_iter=ess.n_iter,
exit_flag=ess.exit_flag,
)
self._finalize(ess)

def _finalize(self, ess: ESSOptimizer = None):
"""Finalize the worker."""
# Whatever happens here, we need to put something to the queue before
# returning to avoid deadlocks.
worker_result = None
if ess is not None:
try:
ess.history.finalize(exitflag=ess.exit_flag.name)
ess._report_final()
worker_result = SacessWorkerResult(
x=ess.x_best,
fx=ess.fx_best,
history=ess.history,
n_eval=ess.evaluator.n_eval,
n_iter=ess.n_iter,
exit_flag=ess.exit_flag,
)
except Exception as e:
self._logger.exception(
f"Worker {self._worker_idx} failed to finalize: {e}"
)
if worker_result is None:
# Create some dummy result
worker_result = SacessWorkerResult(
x=np.full(self._manager._dim, np.nan),
fx=np.nan,
history=MemoryHistory(),
n_eval=0,
n_iter=0,
exit_flag=ESSExitFlag.ERROR,
)
self._manager._result_queue.put(worker_result)

self._logger.debug(f"Final configuration: {self._ess_kwargs}")
ess._report_final()

def _setup_ess(self, startpoint_method: StartpointMethod) -> ESSOptimizer:
"""Run ESS."""
@@ -835,9 +896,22 @@ def _keep_going(self):
f"Max walltime ({self._max_walltime_s}s) exceeded."
)
return False

# some other reason for termination (e.g., a worker failed)
if self._manager.aborted():
self.exit_flag = ESSExitFlag.ERROR
self._logger.debug("Manager requested termination.")
return False
return True

def abort(self):
"""Send signal to abort."""
self._logger.error(f"Worker {self._worker_idx} aborting.")
# signal to manager
self._manager.abort()

self.exit_flag = ESSExitFlag.ERROR
self._finalize(None)

@staticmethod
def get_temp_result_filename(worker_idx: int, tmpdir: str | Path) -> str:
return str(Path(tmpdir, f"sacess-{worker_idx:02d}_tmp.h5").absolute())
@@ -853,15 +927,24 @@ def _run_worker(
Helper function as entrypoint for sacess worker processes.
"""
# different random seeds per process
np.random.seed((os.getpid() * int(time.time() * 1000)) % 2**32)

# Forward log messages to the logging process
h = logging.handlers.QueueHandler(log_process_queue)
worker._logger = logging.getLogger(multiprocessing.current_process().name)
worker._logger.addHandler(h)
try:
# different random seeds per process
np.random.seed((os.getpid() * int(time.time() * 1000)) % 2**32)

# Forward log messages to the logging process
h = logging.handlers.QueueHandler(log_process_queue)
worker._logger = logging.getLogger(
multiprocessing.current_process().name
)
worker._logger.addHandler(h)

return worker.run(problem=problem, startpoint_method=startpoint_method)
return worker.run(problem=problem, startpoint_method=startpoint_method)
except Exception as e:
with suppress(Exception):
worker._logger.exception(
f"Worker {worker._worker_idx} failed: {e}"
)
worker.abort()


def get_default_ess_options(
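The core deadlock-avoidance idea in the `sacess.py` changes above is that every worker always puts some result on the manager's queue, even after a failure, so that `minimize()` can collect exactly one result per worker. A minimal standalone sketch of that pattern, with illustrative names rather than pyPESTO's actual API:

```python
import multiprocessing


def worker(task_id: int, queue) -> None:
    """Compute something; always report back, even on failure."""
    result = None
    try:
        if task_id == 1:
            raise RuntimeError("Intentional error.")
        result = task_id**2
    except Exception:
        # In sacess.py, this is where the error is logged and the other
        # workers are asked to abort via the shared termination flag.
        pass
    finally:
        # Always put *something* on the queue so the parent never blocks.
        queue.put((task_id, result))


if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    queue = ctx.Queue()
    procs = [ctx.Process(target=worker, args=(i, queue)) for i in range(3)]
    for p in procs:
        p.start()
    # Without the `finally` above, this loop would hang if a worker failed.
    results = [queue.get() for _ in procs]
    for p in procs:
        p.join()
    print(results)
```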
35 changes: 35 additions & 0 deletions test/optimize/test_optimize.py
@@ -18,6 +18,7 @@

import pypesto
import pypesto.optimize as optimize
from pypesto import Objective
from pypesto.optimize.ess import (
ESSOptimizer,
FunctionEvaluatorMP,
@@ -577,6 +578,40 @@ def test_ess_refset_repr():
)


class FunctionOrError:
"""Callable that raises an error every nth invocation."""

def __init__(self, fun, error_frequency=100):
self.counter = 0
self.error_frequency = error_frequency
self.fun = fun

def __call__(self, *args, **kwargs):
self.counter += 1
if self.counter % self.error_frequency == 0:
raise RuntimeError("Intentional error.")
return self.fun(*args, **kwargs)


def test_sacess_worker_error(capsys):
"""Check that SacessOptimizer does not hang if an error occurs on a worker."""
objective = Objective(
fun=FunctionOrError(sp.optimize.rosen), grad=sp.optimize.rosen_der
)
problem = pypesto.Problem(
objective=objective, lb=0 * np.ones((1, 2)), ub=1 * np.ones((1, 2))
)
sacess = SacessOptimizer(
num_workers=2,
max_walltime_s=2,
sacess_loglevel=logging.DEBUG,
ess_loglevel=logging.DEBUG,
)
res = sacess.minimize(problem)
assert isinstance(res, pypesto.Result)
assert "Intentional error." in capsys.readouterr().err


def test_scipy_integrated_grad():
integrated = True
obj = rosen_for_sensi(max_sensi_order=2, integrated=integrated)["obj"]
