Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve exception-handling in SacessOptimizer #1517

Merged
merged 2 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pypesto/optimize/ess/ess.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class ESSExitFlag(int, enum.Enum):
MAX_EVAL = -2
# Exited after exhausting wall-time budget
MAX_TIME = -3
# Termination for some reason other than the exit criteria
ERROR = -99


class OptimizerFactory(Protocol):
Expand Down
158 changes: 120 additions & 38 deletions pypesto/optimize/ess/sacess.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import multiprocessing
import os
import time
from contextlib import suppress
from dataclasses import dataclass
from math import ceil, sqrt
from multiprocessing import get_context
Expand All @@ -20,6 +21,7 @@

import pypesto

from ... import MemoryHistory
from ...startpoint import StartpointMethod
from ...store.read_from_hdf5 import read_result
from ...store.save_to_hdf5 import write_result
Expand Down Expand Up @@ -331,12 +333,18 @@ def minimize(
n_eval_total = sum(
worker_result.n_eval for worker_result in self.worker_results
)
logger.info(
f"{self.__class__.__name__} stopped after {walltime:3g}s "
f"and {n_eval_total} objective evaluations "
f"with global best {result.optimize_result[0].fval}."
)

if len(result.optimize_result):
logger.info(
f"{self.__class__.__name__} stopped after {walltime:3g}s "
f"and {n_eval_total} objective evaluations "
f"with global best {result.optimize_result[0].fval}."
)
else:
logger.error(
f"{self.__class__.__name__} stopped after {walltime:3g}s "
f"and {n_eval_total} objective evaluations without producing "
"a result."
)
return result

def _create_result(self, problem: Problem) -> pypesto.Result:
Expand All @@ -345,25 +353,40 @@ def _create_result(self, problem: Problem) -> pypesto.Result:
Creates an overall Result object from the results saved by the workers.
"""
# gather results from workers and delete temporary result files
result = None
result = pypesto.Result()
retry_after_sleep = True
for worker_idx in range(self.num_workers):
tmp_result_filename = SacessWorker.get_temp_result_filename(
worker_idx, self._tmpdir
)
tmp_result = None
try:
tmp_result = read_result(
tmp_result_filename, problem=False, optimize=True
)
except FileNotFoundError:
# wait and retry, maybe the file wasn't found due to some filesystem latency issues
time.sleep(5)
tmp_result = read_result(
tmp_result_filename, problem=False, optimize=True
)
if retry_after_sleep:
time.sleep(5)
# waiting once is enough - don't wait for every worker
retry_after_sleep = False

try:
tmp_result = read_result(
tmp_result_filename, problem=False, optimize=True
)
except FileNotFoundError:
logger.error(
f"Worker {worker_idx} did not produce a result."
)
continue
else:
logger.error(
f"Worker {worker_idx} did not produce a result."
)
continue

if result is None:
result = tmp_result
else:
if tmp_result:
result.optimize_result.append(
tmp_result.optimize_result,
sort=False,
Expand All @@ -375,7 +398,8 @@ def _create_result(self, problem: Problem) -> pypesto.Result:
filename = SacessWorker.get_temp_result_filename(
worker_idx, self._tmpdir
)
os.remove(filename)
with suppress(FileNotFoundError):
os.remove(filename)
# delete tmpdir if empty
try:
self._tmpdir.rmdir()
Expand All @@ -397,6 +421,7 @@ class SacessManager:

Attributes
----------
_dim: Dimension of the optimization problem
_num_workers: Number of workers
_ess_options: ESS options for each worker
_best_known_fx: Best objective value encountered so far
Expand All @@ -410,6 +435,7 @@ class SacessManager:
_rejection_threshold: Threshold for relative objective improvements that
incoming solutions have to pass to be accepted
_lock: Lock for accessing shared state.
_terminate: Flag to signal termination of the SACESS run to workers
_logger: A logger instance
_options: Further optimizer hyperparameters.
"""
Expand All @@ -421,6 +447,7 @@ def __init__(
dim: int,
options: SacessOptions = None,
):
self._dim = dim
self._options = options or SacessOptions()
self._num_workers = len(ess_options)
self._ess_options = [shmem_manager.dict(o) for o in ess_options]
Expand All @@ -440,6 +467,7 @@ def __init__(
self._worker_scores = shmem_manager.Array(
"d", range(self._num_workers)
)
self._terminate = shmem_manager.Value("b", False)
self._worker_comms = shmem_manager.Array("i", [0] * self._num_workers)
self._lock = shmem_manager.RLock()
self._logger = logging.getLogger()
Expand Down Expand Up @@ -550,6 +578,16 @@ def submit_solution(
)
self._rejections.value = 0

def abort(self):
    """Abort the SACESS run.

    Sets the shared termination flag under the manager lock; workers
    observe it via :meth:`aborted` and stop.
    """
    with self._lock:
        self._terminate.value = True

def aborted(self) -> bool:
    """Whether this run was aborted.

    Returns
    -------
    ``True`` if :meth:`abort` has been called, ``False`` otherwise.
    """
    # read under the lock to synchronize with concurrent `abort()` calls
    with self._lock:
        return self._terminate.value


class SacessWorker:
"""A SACESS worker.
Expand Down Expand Up @@ -641,7 +679,7 @@ def run(
ess = self._setup_ess(startpoint_method)

# run ESS until exit criteria are met, but start at least one iteration
while self._keep_going() or ess.n_iter == 0:
while self._keep_going(ess) or ess.n_iter == 0:
# perform one ESS iteration
ess._do_iteration()

Expand All @@ -667,19 +705,42 @@ def run(
f"(best: {self._best_known_fx}, "
f"n_eval: {ess.evaluator.n_eval})."
)

ess.history.finalize(exitflag=ess.exit_flag.name)
worker_result = SacessWorkerResult(
x=ess.x_best,
fx=ess.fx_best,
history=ess.history,
n_eval=ess.evaluator.n_eval,
n_iter=ess.n_iter,
exit_flag=ess.exit_flag,
)
self._finalize(ess)

def _finalize(self, ess: ESSOptimizer = None):
    """Finalize the worker.

    Collects the final result from the given ESS instance, if any, and
    puts a :class:`SacessWorkerResult` on the manager's result queue.
    Exactly one item is put on the queue in every case — even if ``ess``
    is ``None`` or finalization raises — since the consumer blocks on
    the queue and would otherwise deadlock.
    """
    # Whatever happens here, we need to put something to the queue before
    # returning to avoid deadlocks.
    worker_result = None
    if ess is not None:
        try:
            ess.history.finalize(exitflag=ess.exit_flag.name)
            ess._report_final()
            worker_result = SacessWorkerResult(
                x=ess.x_best,
                fx=ess.fx_best,
                history=ess.history,
                n_eval=ess.evaluator.n_eval,
                n_iter=ess.n_iter,
                exit_flag=ess.exit_flag,
            )
        except Exception as e:
            # log, but fall through to the dummy result below so the
            # queue still receives an item
            self._logger.exception(
                f"Worker {self._worker_idx} failed to finalize: {e}"
            )
    if worker_result is None:
        # Create some dummy result
        worker_result = SacessWorkerResult(
            x=np.full(self._manager._dim, np.nan),
            fx=np.nan,
            history=MemoryHistory(),
            n_eval=0,
            n_iter=0,
            exit_flag=ESSExitFlag.ERROR,
        )
    self._manager._result_queue.put(worker_result)

self._logger.debug(f"Final configuration: {self._ess_kwargs}")
ess._report_final()

def _setup_ess(self, startpoint_method: StartpointMethod) -> ESSOptimizer:
"""Run ESS."""
Expand Down Expand Up @@ -821,7 +882,7 @@ def replace_solution(refset: RefSet, x: np.ndarray, fx: float):
fx=fx,
)

def _keep_going(self):
def _keep_going(self, ess: ESSOptimizer) -> bool:
"""Check exit criteria.

Returns
Expand All @@ -830,14 +891,26 @@ def _keep_going(self):
"""
# elapsed time
if time.time() - self._start_time >= self._max_walltime_s:
self.exit_flag = ESSExitFlag.MAX_TIME
ess.exit_flag = ESSExitFlag.MAX_TIME
self._logger.debug(
f"Max walltime ({self._max_walltime_s}s) exceeded."
)
return False

# other reasons for termination (some worker failed, ...)
if self._manager.aborted():
ess.exit_flag = ESSExitFlag.ERROR
self._logger.debug("Manager requested termination.")
return False
return True

def abort(self):
    """Send signal to abort.

    Flags the shared manager state so all workers terminate, then
    finalizes this worker without an ESS result.
    """
    self._logger.error(f"Worker {self._worker_idx} aborting.")
    # signal to manager
    self._manager.abort()

    # puts a (dummy) result on the queue so the consumer is not blocked
    self._finalize(None)

@staticmethod
def get_temp_result_filename(worker_idx: int, tmpdir: str | Path) -> str:
return str(Path(tmpdir, f"sacess-{worker_idx:02d}_tmp.h5").absolute())
Expand All @@ -853,15 +926,24 @@ def _run_worker(

Helper function as entrypoint for sacess worker processes.
"""
# different random seeds per process
np.random.seed((os.getpid() * int(time.time() * 1000)) % 2**32)

# Forward log messages to the logging process
h = logging.handlers.QueueHandler(log_process_queue)
worker._logger = logging.getLogger(multiprocessing.current_process().name)
worker._logger.addHandler(h)
try:
# different random seeds per process
np.random.seed((os.getpid() * int(time.time() * 1000)) % 2**32)

# Forward log messages to the logging process
h = logging.handlers.QueueHandler(log_process_queue)
worker._logger = logging.getLogger(
multiprocessing.current_process().name
)
worker._logger.addHandler(h)

return worker.run(problem=problem, startpoint_method=startpoint_method)
return worker.run(problem=problem, startpoint_method=startpoint_method)
except Exception as e:
with suppress(Exception):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need to suppress the exception here since we are already in an except? Or do you want to make sure it is definitely aborting?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the logging fails for whatever reason, we don't care, but we need to ensure that the line below that block gets executed.

worker._logger.exception(
f"Worker {worker._worker_idx} failed: {e}"
)
worker.abort()


def get_default_ess_options(
Expand Down
35 changes: 35 additions & 0 deletions test/optimize/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import pypesto
import pypesto.optimize as optimize
from pypesto import Objective
from pypesto.optimize.ess import (
ESSOptimizer,
FunctionEvaluatorMP,
Expand Down Expand Up @@ -577,6 +578,40 @@ def test_ess_refset_repr():
)


class FunctionOrError:
    """Callable that raises an error every nth invocation.

    Wraps ``fun`` and forwards calls to it, except that every
    ``error_frequency``-th call raises a ``RuntimeError`` instead.
    """

    def __init__(self, fun, error_frequency=100):
        # number of invocations so far
        self.counter = 0
        # every `error_frequency`-th call fails
        self.error_frequency = error_frequency
        # the wrapped callable
        self.fun = fun

    def __call__(self, *args, **kwargs):
        self.counter += 1
        if self.counter % self.error_frequency:
            return self.fun(*args, **kwargs)
        raise RuntimeError("Intentional error.")


def test_sacess_worker_error(capsys):
    """Check that SacessOptimizer does not hang if an error occurs on a worker."""
    # Rosenbrock objective whose function evaluations fail intermittently.
    objective = Objective(
        fun=FunctionOrError(sp.optimize.rosen), grad=sp.optimize.rosen_der
    )
    problem = pypesto.Problem(
        objective=objective, lb=np.zeros((1, 2)), ub=np.ones((1, 2))
    )
    optimizer = SacessOptimizer(
        num_workers=2,
        max_walltime_s=2,
        sacess_loglevel=logging.DEBUG,
        ess_loglevel=logging.DEBUG,
    )
    res = optimizer.minimize(problem)
    # a proper Result must be returned despite the worker errors ...
    assert isinstance(res, pypesto.Result)
    # ... and the errors must have been logged
    assert "Intentional error." in capsys.readouterr().err


def test_scipy_integrated_grad():
integrated = True
obj = rosen_for_sensi(max_sensi_order=2, integrated=integrated)["obj"]
Expand Down