diff --git a/abismal/surrogate_posterior/structure_factor/wilson.py b/abismal/surrogate_posterior/structure_factor/wilson.py index 29b9141..2da4dfd 100644 --- a/abismal/surrogate_posterior/structure_factor/wilson.py +++ b/abismal/surrogate_posterior/structure_factor/wilson.py @@ -97,7 +97,7 @@ class MultiWilsonPrior(tfk.layers.Layer): ``` """ - def __init__(self, rac, parents, correlations, reindexing_ops=None, sigma=1.): + def __init__(self, rac, parents, correlations, reindexing_ops=None, sigma=1., **kwargs): """ Parameters ---------- @@ -109,16 +109,27 @@ def __init__(self, rac, parents, correlations, reindexing_ops=None, sigma=1.): correlations : list An iterable of prior correlation coefficients between asus. Use 0.0 for root nodes. - reindexing_opts : list (optional) + reindexing_ops : list (optional) Optionally provide a list of reindexing operator strings, one per asu. sigma : float or tensor (optional) Optionally provide an average intensity value for the prior. If this is a tensor, it should have the combinded length of all the asus in the rac. """ - super().__init__() - self.rac = rac + super().__init__(**kwargs) + #Store these for config purposes + self._rac = rac.get_config() + self._parents = parents + self._correlations = correlations + self._reindexing_ops = reindexing_ops + self._sigma = sigma + + self.centric = rac.centric + self.epsilon = rac.epsilon + parent_ids = [] + is_root = [] + for asu_id, rasu in enumerate(rac): pa = parents[asu_id] if pa == asu_id or pa < 0: @@ -127,6 +138,7 @@ def __init__(self, rac, parents, correlations, reindexing_ops=None, sigma=1.): r = 0. else: r = correlations[asu_id] + op = 'x,y,z' if reindexing_ops is not None: op = reindexing_ops[asu_id] @@ -138,13 +150,20 @@ def __init__(self, rac, parents, correlations, reindexing_ops=None, sigma=1.): if pa is None: parent_id = -tf.ones(rasu.asu_size, rasu.miller_id.dtype) + is_root.append( + tf.ones(rasu.asu_size, dtype='bool') + ) else: parent_id = rac._miller_ids( pa * tf.ones_like(hkl[:,:1]), hkl, ) + is_root.append( + tf.zeros(rasu.asu_size, dtype='bool') + ) parent_ids.append(parent_id) + self.is_root = tf.concat(is_root, axis=0) self.parent_ids = tf.concat(parent_ids, axis=0) self.sigma = sigma @@ -158,13 +177,32 @@ def __init__(self, rac, parents, correlations, reindexing_ops=None, sigma=1.): tf.sqrt(0.5 * rac.epsilon * sigma * (1. - tf.square(self.r))), ) self.has_parent = self.parent_ids >= 0 - self.parent_ids = tf.where(self.has_parent, self.parent_ids, tf.range(self.rac.asu_size)) + self.parent_ids = tf.where(self.has_parent, self.parent_ids, tf.range(rac.asu_size, dtype=tf.int32)) + self.built = True #This is always true + + self.p_centric = centric_wilson(self.epsilon, sigma) + self.p_acentric = acentric_wilson(self.epsilon, sigma) + + @property + def rac(self): + return ReciprocalASUCollection.from_config(self._rac) + + def get_config(self): + config = super().get_config() + config.update({ + 'rac' : self.rac, + 'parents' : self._parents, + 'correlations' : self._correlations, + 'reindexing_ops' : self._reindexing_ops, + 'sigma' : self.sigma, + }) + return config def mean(self): """ This is only for initialization of the surrogate! """ - return WilsonPrior(self.rac.centric, self.rac.epsilon, self.sigma).mean() + return WilsonPrior(self.rac, self.sigma).mean() def log_prob(self, z): scale = self.scale #This is precomputed @@ -174,11 +212,23 @@ def log_prob(self, z): loc = tf.where(self.has_parent, loc, 0.) ll = tf.where( - self.rac.centric, + self.centric, FoldedNormal(loc, scale).log_prob(z), Rice(loc, scale).log_prob(z), ) + wilson_p = tf.where( + self.centric, + self.p_centric.log_prob(z), + self.p_acentric.log_prob(z), + ) + + ll = tf.where( + self.is_root, + wilson_p, + ll, + ) return ll +