Commit fffc69e
delete parallel stuff - moved to own branch
aradermacher committed Dec 18, 2024
1 parent 601aab1 commit fffc69e
Showing 2 changed files with 93 additions and 118 deletions.
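Context for the change: both solvers previously wrapped emcee's EnsembleSampler in a multiprocessing.Pool, following emcee's standard parallelization recipe. A minimal sketch of the removed pattern, reconstructed from the deleted lines below (simplified stand-in log-probability, not the branch's exact code):

    import os

    os.environ["OMP_NUM_THREADS"] = "1"  # keep NumPy/BLAS single-threaded per worker

    from multiprocessing import Pool

    import emcee
    import numpy as np


    def logprob(x):
        # must live at module level so multiprocessing can pickle it;
        # a closure inside run() needs the (also deleted) `global logprob` trick
        return -0.5 * np.sum(x**2)  # stand-in for logprior(x) + loglike(x)


    if __name__ == "__main__":
        n_walkers, ndim = 20, 2
        initial_positions = np.random.randn(n_walkers, ndim)
        with Pool(processes=4) as pool:
            sampler = emcee.EnsembleSampler(n_walkers, ndim, logprob, pool=pool)
            sampler.run_mcmc(initial_positions, 1000, progress=True)

The two deleted comments record the trade-off: Pool requires a picklable, module-level log-probability ("pickling problem"), while ThreadPool needs no pickling but brought "no time effect", since the CPU-bound log-probability is serialized by the GIL anyway.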
107 changes: 49 additions & 58 deletions probeye/inference/bias/solver.py
@@ -24,12 +24,9 @@
 from probeye.subroutines import stream_to_logger
 from probeye.subroutines import print_dict_in_rows
 
-from multiprocessing import Pool  # pickling problem
 
-# from multiprocessing.pool import ThreadPool as Pool  # no pickling needed but no time effect
 import os
 
-os.environ["OMP_NUM_THREADS"] = "1"
 # imports only needed for type hints
 if TYPE_CHECKING:  # pragma: no cover
     from probeye.definition.inverse_problem import InverseProblem
 
 
 class EmbeddedMCISolver(EmceeSolver):
@@ -200,7 +197,6 @@ def run(
         n_steps: int = 1000,
         n_initial_steps: int = 100,
         true_values: Optional[dict] = None,
-        n_processes: int = Pool()._processes,
         **kwargs,
     ) -> az.data.inference_data.InferenceData:
         """
@@ -217,8 +213,6 @@
             Number of steps for initial (burn-in) sampling.
         true_values
             True parameter values, if known.
-        n_processes
-            Number of processes to use for parallel sampling.
         kwargs
             Additional key-word arguments channeled to emcee.EnsembleSampler.
@@ -262,7 +256,6 @@ def run(
         # ............................................................................ #
         #                                 Pre-process                                  #
         # ............................................................................ #
-        global logprob
 
         def logprob(x):
             # Skip loglikelihood evaluation if logprior is equal
@@ -276,59 +269,57 @@ def logprob(x):
 
         logger.debug("Setting up EnsembleSampler")
 
-        with Pool(processes=n_processes) as pool:
-            logger.info(f"parallel sampling using multiprocessing with {pool}")
-            self.sampler = emcee.EnsembleSampler(
-                nwalkers=n_walkers,
-                ndim=self.problem.n_latent_prms_dim,
-                log_prob_fn=logprob,
-                pool=pool,
-                **kwargs,
-            )
+        self.sampler = emcee.EnsembleSampler(
+            nwalkers=n_walkers,
+            ndim=self.problem.n_latent_prms_dim,
+            log_prob_fn=logprob,
+            **kwargs,
+        )
 
-            if self.seed is not None:
-                random.seed(self.seed)
-                self.sampler.random_state = np.random.mtrand.RandomState(self.seed)
+        if self.seed is not None:
+            random.seed(self.seed)
+            self.sampler.random_state = np.random.mtrand.RandomState(self.seed)
 
-            # ............................................................................ #
-            #       Initial sampling, burn-in: used to avoid a poor starting point         #
-            # ............................................................................ #
+        # ............................................................................ #
+        #        Initial sampling, burn-in: used to avoid a poor starting point        #
+        # ............................................................................ #
 
-            logger.debug("Starting sampling (initial + main)")
-            start = time.time()
-            state = self.sampler.run_mcmc(
-                initial_state=sampling_initial_positions,
-                nsteps=n_initial_steps,
-                progress=self.show_progress,
-            )
-            self.sampler.reset()
+        logger.debug("Starting sampling (initial + main)")
+        start = time.time()
+        state = self.sampler.run_mcmc(
+            initial_state=sampling_initial_positions,
+            nsteps=n_initial_steps,
+            progress=self.show_progress,
+        )
+        self.sampler.reset()
 
-            # ............................................................................ #
-            #                          Sampling of the posterior                           #
-            # ............................................................................ #
-            self.sampler.run_mcmc(
-                initial_state=state, nsteps=n_steps, progress=self.show_progress
-            )
-            end = time.time()
-            runtime_str = pretty_time_delta(end - start)
-            logger.info(
-                f"Sampling of the posterior distribution completed: {n_steps} steps and "
-                f"{n_walkers} walkers."
+        # ............................................................................ #
+        #                          Sampling of the posterior                           #
+        # ............................................................................ #
+        self.sampler.run_mcmc(
+            initial_state=state, nsteps=n_steps, progress=self.show_progress
+        )
+        end = time.time()
+        runtime_str = pretty_time_delta(end - start)
+        logger.info(
+            f"Sampling of the posterior distribution completed: {n_steps} steps and "
+            f"{n_walkers} walkers."
         )
-            logger.info(f"Total run-time (including initial sampling): {runtime_str}.")
-            logger.info("")
-            logger.info("Summary of sampling results (emcee)")
-            posterior_samples = self.sampler.get_chain(flat=True)
-            with contextlib.redirect_stdout(stream_to_logger("INFO")):  # type: ignore
-                self.summary = self.emcee_summary(
-                    posterior_samples, true_values=true_values
-                )
+        logger.info(f"Total run-time (including initial sampling): {runtime_str}.")
+        logger.info("")
+        logger.info("Summary of sampling results (emcee)")
+        posterior_samples = self.sampler.get_chain(flat=True)
+        with contextlib.redirect_stdout(stream_to_logger("INFO")):  # type: ignore
+            self.summary = self.emcee_summary(
+                posterior_samples, true_values=true_values
+            )
-            logger.info("")  # empty line for visual buffer
-            self.raw_results = self.sampler
+        logger.info("")  # empty line for visual buffer
+        self.raw_results = self.sampler
 
-            # translate the results to a common data structure and return it
-            self.var_names = self.problem.get_theta_names(tex=True, components=True)
-            inference_data = az.from_emcee(self.sampler, var_names=self.var_names)
+        # translate the results to a common data structure and return it
+        self.var_names = self.problem.get_theta_names(tex=True, components=True)
+        inference_data = az.from_emcee(self.sampler, var_names=self.var_names)
 
         return inference_data

     def restart_run(self, state, n_steps):
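Both diffs truncate logprob after its first comment. Judging from that comment and the usual emcee pattern, the wrapper presumably continues along these lines (an assumption, the remainder is not part of this diff; logprior and loglike are helper methods inferred from, not confirmed by, the visible code):

        def logprob(x):
            # Skip loglikelihood evaluation if logprior is equal
            # to negative infinity (x lies outside the prior's support)
            lp = self.logprior(x)
            if lp == -np.inf:
                return -np.inf
            # otherwise return the un-normalized log-posterior
            return lp + self.loglike(x)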
104 changes: 44 additions & 60 deletions probeye/inference/emcee/solver.py
@@ -19,14 +19,6 @@
 from probeye.subroutines import print_dict_in_rows
 from probeye.subroutines import extract_true_values
 
-from multiprocessing import Pool  # pickling problem
-
-# from multiprocessing.pool import ThreadPool as Pool  # no pickling needed but no time effect
-import os
-
-os.environ["OMP_NUM_THREADS"] = "1"
-
-
 # imports only needed for type hints
 if TYPE_CHECKING:  # pragma: no cover
     from probeye.definition.inverse_problem import InverseProblem
@@ -157,7 +149,6 @@ def run(
         n_steps: int = 1000,
         n_initial_steps: int = 100,
         true_values: Optional[dict] = None,
-        n_processes: int = Pool()._processes,
         **kwargs,
     ) -> az.data.inference_data.InferenceData:
         """
@@ -174,8 +165,6 @@
             Number of steps for initial (burn-in) sampling.
         true_values
             True parameter values, if known.
-        n_processes
-            Number of processes to use for parallel sampling.
         kwargs
             Additional key-word arguments channeled to emcee.EnsembleSampler.
@@ -220,8 +209,6 @@ def run(
         #                                 Pre-process                                  #
         # ............................................................................ #
 
-        global logprob
-
         def logprob(x):
             # Skip loglikelihood evaluation if logprior is equal
             # to negative infinity
@@ -234,59 +221,56 @@ def logprob(x):
 
         logger.debug("Setting up EnsembleSampler")
 
-        with Pool(processes=n_processes) as pool:
-            logger.info(f"parallel sampling using multiprocessing with {pool}")
-            sampler = emcee.EnsembleSampler(
-                nwalkers=n_walkers,
-                ndim=self.problem.n_latent_prms_dim,
-                log_prob_fn=logprob,
-                pool=pool,
-                **kwargs,
-            )
+        sampler = emcee.EnsembleSampler(
+            nwalkers=n_walkers,
+            ndim=self.problem.n_latent_prms_dim,
+            log_prob_fn=logprob,
+            **kwargs,
+        )
 
-            if self.seed is not None:
-                random.seed(self.seed)
-                sampler.random_state = np.random.mtrand.RandomState(self.seed)
+        if self.seed is not None:
+            random.seed(self.seed)
+            sampler.random_state = np.random.mtrand.RandomState(self.seed)
 
-            # ............................................................................ #
-            #       Initial sampling, burn-in: used to avoid a poor starting point         #
-            # ............................................................................ #
+        # ............................................................................ #
+        #        Initial sampling, burn-in: used to avoid a poor starting point        #
+        # ............................................................................ #
 
-            logger.debug("Starting sampling (initial + main)")
-            start = time.time()
-            state = sampler.run_mcmc(
-                initial_state=sampling_initial_positions,
-                nsteps=n_initial_steps,
-                progress=self.show_progress,
-            )
-            sampler.reset()
+        logger.debug("Starting sampling (initial + main)")
+        start = time.time()
+        state = sampler.run_mcmc(
+            initial_state=sampling_initial_positions,
+            nsteps=n_initial_steps,
+            progress=self.show_progress,
+        )
+        sampler.reset()
 
-            # ............................................................................ #
-            #                          Sampling of the posterior                           #
-            # ............................................................................ #
-            sampler.run_mcmc(
-                initial_state=state, nsteps=n_steps, progress=self.show_progress
-            )
-            end = time.time()
+        # ............................................................................ #
+        #                          Sampling of the posterior                           #
+        # ............................................................................ #
+        sampler.run_mcmc(
+            initial_state=state, nsteps=n_steps, progress=self.show_progress
+        )
+        end = time.time()
 
-            runtime_str = pretty_time_delta(end - start)
-            logger.info(
-                f"Sampling of the posterior distribution completed: {n_steps} steps and "
-                f"{n_walkers} walkers."
+        runtime_str = pretty_time_delta(end - start)
+        logger.info(
+            f"Sampling of the posterior distribution completed: {n_steps} steps and "
+            f"{n_walkers} walkers."
         )
-            logger.info(f"Total run-time (including initial sampling): {runtime_str}.")
-            logger.info("")
-            logger.info("Summary of sampling results (emcee)")
-            posterior_samples = sampler.get_chain(flat=True)
-            with contextlib.redirect_stdout(stream_to_logger("INFO")):  # type: ignore
-                self.summary = self.emcee_summary(
-                    posterior_samples, true_values=true_values
-                )
+        logger.info(f"Total run-time (including initial sampling): {runtime_str}.")
+        logger.info("")
+        logger.info("Summary of sampling results (emcee)")
+        posterior_samples = sampler.get_chain(flat=True)
+        with contextlib.redirect_stdout(stream_to_logger("INFO")):  # type: ignore
+            self.summary = self.emcee_summary(
+                posterior_samples, true_values=true_values
+            )
-            logger.info("")  # empty line for visual buffer
-            self.raw_results = sampler
+        logger.info("")  # empty line for visual buffer
+        self.raw_results = sampler
 
-            # translate the results to a common data structure and return it
-            var_names = self.problem.get_theta_names(tex=True, components=True)
-            inference_data = az.from_emcee(sampler, var_names=var_names)
+        # translate the results to a common data structure and return it
+        var_names = self.problem.get_theta_names(tex=True, components=True)
+        inference_data = az.from_emcee(sampler, var_names=var_names)
 
         return inference_data
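With the parallel path removed, run() always samples serially and no longer accepts n_processes. A minimal usage sketch under the new signature (the InverseProblem setup and the exact constructor arguments are assumptions here, they are not part of this commit):

    import arviz as az

    from probeye.inference.emcee.solver import EmceeSolver

    # `problem` is a fully defined InverseProblem (definition not shown)
    solver = EmceeSolver(problem, show_progress=True, seed=42)

    inference_data = solver.run(
        n_walkers=20,         # ensemble walkers
        n_steps=1000,         # posterior steps per walker
        n_initial_steps=100,  # burn-in, discarded via sampler.reset()
    )
    print(az.summary(inference_data))  # posterior statistics per parameter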