From cf5c9d4a577269c117173cffe2f818125898ff83 Mon Sep 17 00:00:00 2001 From: zachjweiner Date: Sun, 5 Mar 2023 16:17:07 -0800 Subject: [PATCH 1/2] infer named parameter layout from dict-type initial_state --- src/emcee/ensemble.py | 153 +++++++++++++++++++++++------------------- 1 file changed, 84 insertions(+), 69 deletions(-) diff --git a/src/emcee/ensemble.py b/src/emcee/ensemble.py index a83d0620..02ddac46 100644 --- a/src/emcee/ensemble.py +++ b/src/emcee/ensemble.py @@ -2,7 +2,7 @@ import warnings from itertools import count -from typing import Dict, List, Optional, Union +from typing import Dict, Iterable, List, Optional, Sequence, Union import numpy as np @@ -15,11 +15,42 @@ __all__ = ["EnsembleSampler", "walkers_independent"] -try: - from collections.abc import Iterable -except ImportError: - # for py2.7, will be an Exception in 3.8 - from collections import Iterable +ParameterNamesT = Union[ + Sequence[str], Dict[str, Union[slice, int, Sequence[int]]] +] + + +def infer_dict_mapping(state): + i0 = 0 + param_slice_shape = {} + for key, val in state.items(): + val = np.asarray(val) + i1 = i0 + val.size + slc = slice(i0, i1) if val.size > 1 else i0 + param_slice_shape[key] = slc, val.shape + i0 = i1 + + return param_slice_shape + + +def array_to_dict(ary, param_slice_shape): + return { + key: ary[:, slc].reshape((-1,)+shape) + for key, (slc, shape) in param_slice_shape.items() + } + + +def array_to_list_of_dicts(ary, param_slice_shape): + # reshape adds a small amount of overhead; don't do it unless necessary + return [{ + key: ary_i[slc].reshape(shape) if len(shape) > 1 else ary_i[slc] + for key, (slc, shape) in param_slice_shape.items() + } for ary_i in ary] + + +def collapse_and_hstack(values, nwalkers=None): + shape = (nwalkers, -1) if nwalkers is not None else -1 + return np.hstack([np.asarray(val).reshape(shape) for val in values]) class EnsembleSampler(object): @@ -62,7 +93,8 @@ class EnsembleSampler(object): to accept a list of position vectors instead of just one. Note that ``pool`` will be ignored if this is ``True``. (default: ``False``) - parameter_names (Optional[Union[List[str], Dict[str, List[int]]]]): + parameter_names (Union[Sequence[str], + Dict[str, Union[slice, int, Sequence[int]]]): names of individual parameters or groups of parameters. If specified, the ``log_prob_fn`` will recieve a dictionary of parameters, rather than a ``np.ndarray``. @@ -81,7 +113,7 @@ def __init__( backend=None, vectorize=False, blobs_dtype=None, - parameter_names: Optional[Union[Dict[str, int], List[str]]] = None, + parameter_names: Optional[ParameterNamesT] = None, # Deprecated... a=None, postargs=None, @@ -163,48 +195,39 @@ def __init__( # ``args`` and ``kwargs`` pickleable. self.log_prob_fn = _FunctionWrapper(log_prob_fn, args, kwargs) - # Save the parameter names - self.params_are_named: bool = parameter_names is not None - if self.params_are_named: - assert isinstance(parameter_names, (list, dict)) - - # Don't support vectorizing yet - msg = "named parameters with vectorization unsupported for now" - assert not self.vectorize, msg - - # Check for duplicate names - dupes = set() - uniq = [] - for name in parameter_names: - if name not in dupes: - uniq.append(name) - dupes.add(name) - msg = f"duplicate paramters: {dupes}" - assert len(uniq) == len(parameter_names), msg - - if isinstance(parameter_names, list): - # Check for all named - msg = "name all parameters or set `parameter_names` to `None`" - assert len(parameter_names) == ndim, msg - # Convert a list to a dict - parameter_names: Dict[str, int] = { - name: i for i, name in enumerate(parameter_names) + if parameter_names is not None: + if isinstance(parameter_names, Sequence): + if len(parameter_names) != ndim: + raise ValueError( + f"`parameter_names` does not specify {ndim} names") + parameter_names = dict(zip(parameter_names, range(ndim))) + + indices = np.arange(ndim) + + try: + index_map = { + key: indices[slc] + for key, slc in parameter_names.items() } + indexed = collapse_and_hstack(index_map.values()) + except IndexError as err: + msg = "`parameter_names` specifies out-of-bounds element(s)" + raise ValueError(msg) from err - # Check not too many names - msg = "too many names" - assert len(parameter_names) <= ndim, msg - - # Check all indices appear - values = [ - v if isinstance(v, list) else [v] - for v in parameter_names.values() - ] - values = [item for sublist in values for item in sublist] - values = set(values) - msg = f"not all values appear -- set should be 0 to {ndim-1}" - assert values == set(np.arange(ndim)), msg - self.parameter_names = parameter_names + if len(indexed) != ndim: + raise ValueError( + "`parameter_names` does not specify indices for" + f" {ndim} parameters" + ) + if set(indexed) != set(indices): + raise ValueError( + "`parameter_names` does not specify indices" + f" 0 through {ndim-1}" + ) + + self.param_slice_shape = infer_dict_mapping(index_map) + else: + self.param_slice_shape = None @property def random_state(self): @@ -266,7 +289,8 @@ def sample( """Advance the chain as a generator Args: - initial_state (State or ndarray[nwalkers, ndim]): The initial + initial_state (State or ndarray[nwalkers, ndim] or + dict[str, float | np.ndarray[nwalkers. ...]]): The initial :class:`State` or positions of the walkers in the parameter space. iterations (Optional[int or NoneType]): The number of steps to generate. @@ -302,6 +326,12 @@ def sample( if iterations is None and store: raise ValueError("'store' must be False when 'iterations' is None") # Interpret the input as a walker state and check the dimensions. + if isinstance(initial_state, dict): + _state = {key: val[0] for key, val in initial_state.items()} + self.param_slice_shape = infer_dict_mapping(_state) + initial_state = collapse_and_hstack( + initial_state.values(), self.nwalkers) + state = State(initial_state, copy=True) state_shape = np.shape(state.coords) if state_shape != (self.nwalkers, self.ndim): @@ -472,8 +502,11 @@ def compute_log_prob(self, coords): raise ValueError("At least one parameter value was NaN") # If the parmaeters are named, then switch to dictionaries - if self.params_are_named: - p = ndarray_to_list_of_dicts(p, self.parameter_names) + if self.param_slice_shape: + if self.vectorize: + p = array_to_dict(p, self.param_slice_shape) + else: + p = array_to_list_of_dicts(p, self.param_slice_shape) # Run the log-probability calculations (optionally in parallel). if self.vectorize: @@ -664,21 +697,3 @@ def _scaled_cond(a): return np.inf c = b / bsum return np.linalg.cond(c.astype(float)) - - -def ndarray_to_list_of_dicts( - x: np.ndarray, key_map: Dict[str, Union[int, List[int]]] -) -> List[Dict[str, Union[np.number, np.ndarray]]]: - """ - A helper function to convert a ``np.ndarray`` into a list - of dictionaries of parameters. Used when parameters are named. - - Args: - x (np.ndarray): parameter array of shape ``(N, n_dim)``, where - ``N`` is an integer - key_map (Dict[str, Union[int, List[int]]): - - Returns: - list of dictionaries of parameters - """ - return [{key: xi[val] for key, val in key_map.items()} for xi in x] From ab63bf45f79b14279e4aa9beee6e52cf77f7b13d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Mar 2023 00:20:47 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/emcee/ensemble.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/emcee/ensemble.py b/src/emcee/ensemble.py index 02ddac46..6dbb2e06 100644 --- a/src/emcee/ensemble.py +++ b/src/emcee/ensemble.py @@ -35,17 +35,20 @@ def infer_dict_mapping(state): def array_to_dict(ary, param_slice_shape): return { - key: ary[:, slc].reshape((-1,)+shape) + key: ary[:, slc].reshape((-1,) + shape) for key, (slc, shape) in param_slice_shape.items() } def array_to_list_of_dicts(ary, param_slice_shape): # reshape adds a small amount of overhead; don't do it unless necessary - return [{ - key: ary_i[slc].reshape(shape) if len(shape) > 1 else ary_i[slc] - for key, (slc, shape) in param_slice_shape.items() - } for ary_i in ary] + return [ + { + key: ary_i[slc].reshape(shape) if len(shape) > 1 else ary_i[slc] + for key, (slc, shape) in param_slice_shape.items() + } + for ary_i in ary + ] def collapse_and_hstack(values, nwalkers=None): @@ -199,15 +202,15 @@ def __init__( if isinstance(parameter_names, Sequence): if len(parameter_names) != ndim: raise ValueError( - f"`parameter_names` does not specify {ndim} names") + f"`parameter_names` does not specify {ndim} names" + ) parameter_names = dict(zip(parameter_names, range(ndim))) indices = np.arange(ndim) try: index_map = { - key: indices[slc] - for key, slc in parameter_names.items() + key: indices[slc] for key, slc in parameter_names.items() } indexed = collapse_and_hstack(index_map.values()) except IndexError as err: @@ -330,7 +333,8 @@ def sample( _state = {key: val[0] for key, val in initial_state.items()} self.param_slice_shape = infer_dict_mapping(_state) initial_state = collapse_and_hstack( - initial_state.values(), self.nwalkers) + initial_state.values(), self.nwalkers + ) state = State(initial_state, copy=True) state_shape = np.shape(state.coords)