Skip to content

Commit

Permalink
Merge pull request #8 from rs-station/dw_fix
Browse files Browse the repository at this point in the history
fix multi wilson model
  • Loading branch information
kmdalton authored Aug 19, 2024
2 parents 3bfb92c + 6f96bc8 commit d8dbaaf
Showing 1 changed file with 57 additions and 7 deletions.
64 changes: 57 additions & 7 deletions abismal/surrogate_posterior/structure_factor/wilson.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class MultiWilsonPrior(tfk.layers.Layer):
```
"""
def __init__(self, rac, parents, correlations, reindexing_ops=None, sigma=1.):
def __init__(self, rac, parents, correlations, reindexing_ops=None, sigma=1., **kwargs):
"""
Parameters
----------
Expand All @@ -109,16 +109,27 @@ def __init__(self, rac, parents, correlations, reindexing_ops=None, sigma=1.):
correlations : list
An iterable of prior correlation coefficients between asus. Use
0.0 for root nodes.
reindexing_opts : list (optional)
reindexing_ops : list (optional)
Optionally provide a list of reindexing operator strings, one per asu.
sigma : float or tensor (optional)
Optionally provide an average intensity value for the prior.
If this is a tensor, it should have the combinded length of all
the asus in the rac.
"""
super().__init__()
self.rac = rac
super().__init__(**kwargs)
#Store these for config purposes
self._rac = rac.get_config()
self._parents = parents
self._correlations = correlations
self._reindexing_ops = reindexing_ops
self._sigma = sigma

self.centric = rac.centric
self.epsilon = rac.epsilon

parent_ids = []
is_root = []

for asu_id, rasu in enumerate(rac):
pa = parents[asu_id]
if pa == asu_id or pa < 0:
Expand All @@ -127,6 +138,7 @@ def __init__(self, rac, parents, correlations, reindexing_ops=None, sigma=1.):
r = 0.
else:
r = correlations[asu_id]

op = 'x,y,z'
if reindexing_ops is not None:
op = reindexing_ops[asu_id]
Expand All @@ -138,13 +150,20 @@ def __init__(self, rac, parents, correlations, reindexing_ops=None, sigma=1.):

if pa is None:
parent_id = -tf.ones(rasu.asu_size, rasu.miller_id.dtype)
is_root.append(
tf.ones(rasu.asu_size, dtype='bool')
)
else:
parent_id = rac._miller_ids(
pa * tf.ones_like(hkl[:,:1]),
hkl,
)
is_root.append(
tf.zeros(rasu.asu_size, dtype='bool')
)
parent_ids.append(parent_id)

self.is_root = tf.concat(is_root, axis=0)
self.parent_ids = tf.concat(parent_ids, axis=0)
self.sigma = sigma

Expand All @@ -158,13 +177,32 @@ def __init__(self, rac, parents, correlations, reindexing_ops=None, sigma=1.):
tf.sqrt(0.5 * rac.epsilon * sigma * (1. - tf.square(self.r))),
)
self.has_parent = self.parent_ids >= 0
self.parent_ids = tf.where(self.has_parent, self.parent_ids, tf.range(self.rac.asu_size))
self.parent_ids = tf.where(self.has_parent, self.parent_ids, tf.range(rac.asu_size, dtype=tf.int32))
self.built = True #This is always true

self.p_centric = centric_wilson(self.epsilon, sigma)
self.p_acentric = acentric_wilson(self.epsilon, sigma)

@property
def rac(self):
return ReciprocalASUCollection.from_config(self._rac)

def get_config(self):
config = super().get_config()
config.update({
'rac' : self.rac,
'parents' : self._parents,
'correlations' : self._correlations,
'reindexing_ops' : self._reindexing_ops,
'sigma' : self.sigma,
})
return config

def mean(self):
"""
This is only for initialization of the surrogate!
"""
return WilsonPrior(self.rac.centric, self.rac.epsilon, self.sigma).mean()
return WilsonPrior(self.rac, self.sigma).mean()

def log_prob(self, z):
scale = self.scale #This is precomputed
Expand All @@ -174,11 +212,23 @@ def log_prob(self, z):
loc = tf.where(self.has_parent, loc, 0.)

ll = tf.where(
self.rac.centric,
self.centric,
FoldedNormal(loc, scale).log_prob(z),
Rice(loc, scale).log_prob(z),
)
wilson_p = tf.where(
self.centric,
self.p_centric.log_prob(z),
self.p_acentric.log_prob(z),
)

ll = tf.where(
self.is_root,
wilson_p,
ll,
)

return ll



0 comments on commit d8dbaaf

Please sign in to comment.