Skip to content

Commit

Permalink
Register the overloads added by CustomDist in worker processes (#7241)
Browse files — browse the repository at this point in the history
  • Loading branch information
EliasRas authored Dec 3, 2024
1 parent 7c369c8 commit 3f3aeb9
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 1 deletion.
41 changes: 40 additions & 1 deletion pymc/smc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@

from pymc.backends.arviz import dict_to_dataset, to_inference_data
from pymc.backends.base import MultiTrace
from pymc.distributions.custom import CustomDistRV, CustomSymbolicDistRV
from pymc.distributions.distribution import _support_point
from pymc.logprob.abstract import _icdf, _logcdf, _logprob
from pymc.model import Model, modelcontext
from pymc.sampling.parallel import _cpu_count
from pymc.smc.kernels import IMH
Expand Down Expand Up @@ -346,11 +349,18 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
# main process and our worker functions
_progress = manager.dict()

# check if model contains CustomDistributions defined without dist argument
custom_methods = _find_custom_dist_dispatch_methods(params[3])

# "manually" (de)serialize params before/after multiprocessing
params = tuple(cloudpickle.dumps(p) for p in params)
kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()}

with ProcessPoolExecutor(max_workers=cores) as executor:
with ProcessPoolExecutor(
max_workers=cores,
initializer=_register_custom_methods,
initargs=(custom_methods,),
) as executor:
for c in range(chains): # iterate over the jobs we need to run
# set visible false so we don't have a lot of bars all at once:
task_id = progress.add_task(f"Chain {c}", status="Stage: 0 Beta: 0")
Expand Down Expand Up @@ -383,3 +393,32 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
)

return tuple(cloudpickle.loads(r.result()) for r in done)


def _find_custom_dist_dispatch_methods(model):
    """Collect the dispatch implementations registered for ``CustomDist`` RVs.

    Scans ``model.basic_RVs`` for ops of type ``CustomDistRV`` /
    ``CustomSymbolicDistRV`` and returns a dict mapping the cloudpickled op
    class to a tuple of cloudpickled ``(logprob, logcdf, icdf, support_point)``
    implementations.  A slot holds a pickled ``None`` when no implementation
    is registered for that class.
    """
    methods = {}
    # The four singledispatch registries whose per-class overloads must be
    # re-created in worker processes (they are populated at CustomDist
    # definition time in the parent process only).
    dispatchers = (_logprob, _logcdf, _icdf, _support_point)
    for rv in model.basic_RVs:
        op = rv.owner.op
        if not isinstance(op, CustomDistRV | CustomSymbolicDistRV):
            continue
        op_class = type(op)
        methods[cloudpickle.dumps(op_class)] = tuple(
            cloudpickle.dumps(dispatcher.registry.get(op_class, None))
            for dispatcher in dispatchers
        )

    return methods


def _register_custom_methods(custom_methods):
    """Register ``CustomDist`` dispatch overloads in a worker process.

    Runs as a ``ProcessPoolExecutor`` initializer.  ``custom_methods`` is the
    mapping produced by ``_find_custom_dist_dispatch_methods``: cloudpickled
    op class -> tuple of cloudpickled ``(logprob, logcdf, icdf, support_point)``
    implementations, where a slot may hold a pickled ``None``.

    Note: the missing-method check must happen *after* unpickling —
    ``cloudpickle.dumps(None)`` is a non-None bytes object, so testing the
    pickled payload for ``None`` would always succeed and register ``None``
    as the overload, shadowing the dispatcher's fallback.
    """
    for cls, (logprob, logcdf, icdf, support_point) in custom_methods.items():
        cls = cloudpickle.loads(cls)
        for dispatcher, pickled_method in zip(
            (_logprob, _logcdf, _icdf, _support_point),
            (logprob, logcdf, icdf, support_point),
        ):
            method = cloudpickle.loads(pickled_method)
            # Only register real implementations; a pickled None means the
            # class had no overload registered in the parent process.
            if method is not None:
                dispatcher.register(cls, method)
15 changes: 15 additions & 0 deletions tests/smc/test_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,21 @@ def test_unobserved_categorical(self):

assert np.all(np.median(trace["mu"], axis=0) == [1, 2])

def test_parallel_custom(self):
    # Regression check: overloads added by CustomDist must be re-registered
    # inside multiprocessing workers when sampling with cores > 1.
    def custom_logp(value, mu):
        return -((value - mu) ** 2)

    def custom_dist(mu, size=None):
        return pm.Normal.dist(mu, 1, size=size)

    def custom_random(mu, rng=None, size=None):
        return rng.normal(loc=mu, scale=1, size=size)

    with pm.Model():
        mu = pm.CustomDist("mu", 0, logp=custom_logp, dist=custom_dist)
        pm.CustomDist(
            "y", mu, logp=custom_logp, class_name="", random=custom_random, observed=[1, 2]
        )
        pm.sample_smc(draws=6, cores=2)

def test_marginal_likelihood(self):
"""
Verifies that the log marginal likelihood function
Expand Down

0 comments on commit 3f3aeb9

Please sign in to comment.