Skip to content

Commit

Permalink
Merge branch 'develop' into visualize_obs_mapping_ax_flatten
Browse files Browse the repository at this point in the history
  • Loading branch information
Doresic authored Dec 2, 2024
2 parents 72d512d + 010e93a commit c1ed00c
Show file tree
Hide file tree
Showing 19 changed files with 439 additions and 178 deletions.
1 change: 1 addition & 0 deletions .github/CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
/pypesto/select/ @dilpath
/pypesto/startpoint/ @PaulJonasJost
/pypesto/store/ @PaulJonasJost
/pypesto/visualize/ @stephanmg

# Tests
/test/base/ @PaulJonasJost @vwiela
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ jobs:

- name: Run tests
timeout-minutes: 35
run: tox -e petab
run: tox -e petab && tox e -e petab -- pip uninstall -y amici
env:
CC: clang
CXX: clang++
Expand Down
32 changes: 32 additions & 0 deletions .github/workflows/publish_dockerhub.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Builds the pyPESTO Docker image from docker/Dockerfile and pushes it to
# Docker Hub. Runs on pushes to `main` that touch the Docker context or this
# workflow file, and can also be started manually.
name: Build and Push Docker Image

on:
  push:
    branches:
      - main
    paths:
      - 'docker/**'
      # Must match this workflow's own filename; otherwise changes to the
      # workflow itself never trigger a rebuild.
      - '.github/workflows/publish_dockerhub.yml'
  workflow_dispatch:

jobs:
  build-and-push:
    runs-on: ubuntu-latest

    steps:
      - name: Check out the repository
        uses: actions/checkout@v4

      - name: Log in to Docker Hub
        uses: docker/login-action@v2
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      # NOTE(review): Docker repository names are expected to be lowercase;
      # `ICB_DCM/pypesto` will likely be rejected by `docker build -t`.
      # Confirm the actual Docker Hub namespace before changing it.
      - name: Build and tag the Docker image
        run: |
          docker build -t ICB_DCM/pypesto:latest -f docker/Dockerfile .

      - name: Push the Docker image to Docker Hub
        run: |
          docker push ICB_DCM/pypesto:latest
161 changes: 92 additions & 69 deletions doc/using_pypesto.bib

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pypesto/optimize/ess/ess.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class ESSExitFlag(int, enum.Enum):
MAX_EVAL = -2
# Exited after exhausting wall-time budget
MAX_TIME = -3
# Termination for a reason other than the exit criteria (e.g., an error)
ERROR = -99


class OptimizerFactory(Protocol):
Expand Down Expand Up @@ -401,7 +403,6 @@ def _create_result(self) -> pypesto.Result:
for i, optimizer_result in enumerate(self.local_solutions):
i_result += 1
optimizer_result.id = f"Local solution {i}"
optimizer_result.optimizer = str(self.local_optimizer)
result.optimize_result.append(optimizer_result)

if self._result_includes_refset:
Expand Down
158 changes: 120 additions & 38 deletions pypesto/optimize/ess/sacess.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import multiprocessing
import os
import time
from contextlib import suppress
from dataclasses import dataclass
from math import ceil, sqrt
from multiprocessing import get_context
Expand All @@ -20,6 +21,7 @@

import pypesto

from ... import MemoryHistory
from ...startpoint import StartpointMethod
from ...store.read_from_hdf5 import read_result
from ...store.save_to_hdf5 import write_result
Expand Down Expand Up @@ -331,12 +333,18 @@ def minimize(
n_eval_total = sum(
worker_result.n_eval for worker_result in self.worker_results
)
logger.info(
f"{self.__class__.__name__} stopped after {walltime:3g}s "
f"and {n_eval_total} objective evaluations "
f"with global best {result.optimize_result[0].fval}."
)

if len(result.optimize_result):
logger.info(
f"{self.__class__.__name__} stopped after {walltime:3g}s "
f"and {n_eval_total} objective evaluations "
f"with global best {result.optimize_result[0].fval}."
)
else:
logger.error(
f"{self.__class__.__name__} stopped after {walltime:3g}s "
f"and {n_eval_total} objective evaluations without producing "
"a result."
)
return result

def _create_result(self, problem: Problem) -> pypesto.Result:
Expand All @@ -345,25 +353,40 @@ def _create_result(self, problem: Problem) -> pypesto.Result:
Creates an overall Result object from the results saved by the workers.
"""
# gather results from workers and delete temporary result files
result = None
result = pypesto.Result()
retry_after_sleep = True
for worker_idx in range(self.num_workers):
tmp_result_filename = SacessWorker.get_temp_result_filename(
worker_idx, self._tmpdir
)
tmp_result = None
try:
tmp_result = read_result(
tmp_result_filename, problem=False, optimize=True
)
except FileNotFoundError:
# wait and retry, maybe the file wasn't found due to some filesystem latency issues
time.sleep(5)
tmp_result = read_result(
tmp_result_filename, problem=False, optimize=True
)
if retry_after_sleep:
time.sleep(5)
# waiting once is enough - don't wait for every worker
retry_after_sleep = False

try:
tmp_result = read_result(
tmp_result_filename, problem=False, optimize=True
)
except FileNotFoundError:
logger.error(
f"Worker {worker_idx} did not produce a result."
)
continue
else:
logger.error(
f"Worker {worker_idx} did not produce a result."
)
continue

if result is None:
result = tmp_result
else:
if tmp_result:
result.optimize_result.append(
tmp_result.optimize_result,
sort=False,
Expand All @@ -375,7 +398,8 @@ def _create_result(self, problem: Problem) -> pypesto.Result:
filename = SacessWorker.get_temp_result_filename(
worker_idx, self._tmpdir
)
os.remove(filename)
with suppress(FileNotFoundError):
os.remove(filename)
# delete tmpdir if empty
try:
self._tmpdir.rmdir()
Expand All @@ -397,6 +421,7 @@ class SacessManager:
Attributes
----------
_dim: Dimension of the optimization problem
_num_workers: Number of workers
_ess_options: ESS options for each worker
_best_known_fx: Best objective value encountered so far
Expand All @@ -410,6 +435,7 @@ class SacessManager:
_rejection_threshold: Threshold for relative objective improvements that
incoming solutions have to pass to be accepted
_lock: Lock for accessing shared state.
_terminate: Flag to signal termination of the SACESS run to workers
_logger: A logger instance
_options: Further optimizer hyperparameters.
"""
Expand All @@ -421,6 +447,7 @@ def __init__(
dim: int,
options: SacessOptions = None,
):
self._dim = dim
self._options = options or SacessOptions()
self._num_workers = len(ess_options)
self._ess_options = [shmem_manager.dict(o) for o in ess_options]
Expand All @@ -440,6 +467,7 @@ def __init__(
self._worker_scores = shmem_manager.Array(
"d", range(self._num_workers)
)
self._terminate = shmem_manager.Value("b", False)
self._worker_comms = shmem_manager.Array("i", [0] * self._num_workers)
self._lock = shmem_manager.RLock()
self._logger = logging.getLogger()
Expand Down Expand Up @@ -550,6 +578,16 @@ def submit_solution(
)
self._rejections.value = 0

def abort(self):
"""Abort the SACESS run.

Sets the shared ``_terminate`` flag while holding ``_lock``. Workers
poll this flag through :meth:`aborted` when checking their exit
criteria.
"""
with self._lock:
self._terminate.value = True

def aborted(self) -> bool:
"""Whether this run was aborted.

Reads the shared ``_terminate`` flag while holding ``_lock``.

Returns
-------
The current value of the termination flag.
"""
with self._lock:
return self._terminate.value


class SacessWorker:
"""A SACESS worker.
Expand Down Expand Up @@ -641,7 +679,7 @@ def run(
ess = self._setup_ess(startpoint_method)

# run ESS until exit criteria are met, but start at least one iteration
while self._keep_going() or ess.n_iter == 0:
while self._keep_going(ess) or ess.n_iter == 0:
# perform one ESS iteration
ess._do_iteration()

Expand All @@ -667,19 +705,42 @@ def run(
f"(best: {self._best_known_fx}, "
f"n_eval: {ess.evaluator.n_eval})."
)

ess.history.finalize(exitflag=ess.exit_flag.name)
worker_result = SacessWorkerResult(
x=ess.x_best,
fx=ess.fx_best,
history=ess.history,
n_eval=ess.evaluator.n_eval,
n_iter=ess.n_iter,
exit_flag=ess.exit_flag,
)
self._finalize(ess)

def _finalize(self, ess: ESSOptimizer | None = None):
    """Finalize the worker and hand its result to the manager.

    Exactly one :class:`SacessWorkerResult` is put on the manager's result
    queue, no matter what happens during finalization. If no ESS instance
    is available, or finalizing it fails, a placeholder result with
    ``exit_flag=ESSExitFlag.ERROR`` is submitted instead.

    Parameters
    ----------
    ess:
        The ESS optimizer run by this worker, or ``None`` if the worker is
        aborting without a usable optimizer.
    """
    # Whatever happens here, we need to put something to the queue before
    # returning to avoid deadlocks.
    worker_result = None
    if ess is not None:
        try:
            ess.history.finalize(exitflag=ess.exit_flag.name)
            ess._report_final()
            worker_result = SacessWorkerResult(
                x=ess.x_best,
                fx=ess.fx_best,
                history=ess.history,
                n_eval=ess.evaluator.n_eval,
                n_iter=ess.n_iter,
                exit_flag=ess.exit_flag,
            )
        except Exception as e:
            self._logger.exception(
                f"Worker {self._worker_idx} failed to finalize: {e}"
            )
    if worker_result is None:
        # Create some dummy result
        worker_result = SacessWorkerResult(
            x=np.full(self._manager._dim, np.nan),
            fx=np.nan,
            history=MemoryHistory(),
            n_eval=0,
            n_iter=0,
            exit_flag=ESSExitFlag.ERROR,
        )
    self._manager._result_queue.put(worker_result)

    self._logger.debug(f"Final configuration: {self._ess_kwargs}")
    # The final report is already emitted in the guarded section above.
    # It must not be repeated here: `ess` may be None on the abort path.

def _setup_ess(self, startpoint_method: StartpointMethod) -> ESSOptimizer:
"""Run ESS."""
Expand Down Expand Up @@ -821,7 +882,7 @@ def replace_solution(refset: RefSet, x: np.ndarray, fx: float):
fx=fx,
)

def _keep_going(self):
def _keep_going(self, ess: ESSOptimizer) -> bool:
"""Check exit criteria.
Returns
Expand All @@ -830,14 +891,26 @@ def _keep_going(self):
"""
# elapsed time
if time.time() - self._start_time >= self._max_walltime_s:
self.exit_flag = ESSExitFlag.MAX_TIME
ess.exit_flag = ESSExitFlag.MAX_TIME
self._logger.debug(
f"Max walltime ({self._max_walltime_s}s) exceeded."
)
return False

# other reasons for termination (some worker failed, ...)
if self._manager.aborted():
ess.exit_flag = ESSExitFlag.ERROR
self._logger.debug("Manager requested termination.")
return False
return True

def abort(self):
"""Send signal to abort.

Logs the abort at error level, sets the manager's termination flag,
and then finalizes this worker without an ESS instance.
"""
self._logger.error(f"Worker {self._worker_idx} aborting.")
# signal to manager
self._manager.abort()

# No ESS instance is passed; `_finalize` is expected to still put a
# (placeholder) result on the manager's result queue.
self._finalize(None)

@staticmethod
def get_temp_result_filename(worker_idx: int, tmpdir: str | Path) -> str:
return str(Path(tmpdir, f"sacess-{worker_idx:02d}_tmp.h5").absolute())
Expand All @@ -853,15 +926,24 @@ def _run_worker(
Helper function as entrypoint for sacess worker processes.
"""
# different random seeds per process
np.random.seed((os.getpid() * int(time.time() * 1000)) % 2**32)

# Forward log messages to the logging process
h = logging.handlers.QueueHandler(log_process_queue)
worker._logger = logging.getLogger(multiprocessing.current_process().name)
worker._logger.addHandler(h)
try:
# different random seeds per process
np.random.seed((os.getpid() * int(time.time() * 1000)) % 2**32)

# Forward log messages to the logging process
h = logging.handlers.QueueHandler(log_process_queue)
worker._logger = logging.getLogger(
multiprocessing.current_process().name
)
worker._logger.addHandler(h)

return worker.run(problem=problem, startpoint_method=startpoint_method)
return worker.run(problem=problem, startpoint_method=startpoint_method)
except Exception as e:
with suppress(Exception):
worker._logger.exception(
f"Worker {worker._worker_idx} failed: {e}"
)
worker.abort()


def get_default_ess_options(
Expand Down
Loading

0 comments on commit c1ed00c

Please sign in to comment.