Skip to content

Commit

Permalink
EUSS vs ESS
Browse files Browse the repository at this point in the history
  • Loading branch information
minaskar committed May 22, 2024
1 parent ce046a0 commit 6ce7727
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 21 deletions.
3 changes: 2 additions & 1 deletion pocomc/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def preconditioned_pcn(state_dict: dict,
logp2_val = logp2_val_new
else:
cnt += 1
if cnt >= n_steps * ((2.38 / n_dim**0.5) / sigma)**2.0:
if cnt >= n_steps * ((2.38 / n_dim**0.5) / sigma)**2.0 * (0.234 / np.mean(alpha)):
break

if i >= n_max:
Expand Down Expand Up @@ -396,6 +396,7 @@ def pcn(state_dict: dict,

# Adapt scale parameter using diminishing adaptation
sigma = np.abs(np.minimum(sigma + 1 / (i + 1)**0.75 * (np.mean(alpha) - 0.234), np.minimum(2.38 / n_dim**0.5, 0.99)))
#sigma = sigma + 1 / (i + 1)**0.75 * (np.mean(alpha) - 0.234)

# Update progress bar if available
if progress_bar is not None:
Expand Down
26 changes: 23 additions & 3 deletions pocomc/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def __init__(self,
flow=None,
train_config: dict = None,
precondition: bool = True,
dynamic: bool = False,
n_prior: int = None,
sample: str = None,
n_steps: int = None,
Expand Down Expand Up @@ -226,6 +227,8 @@ def __init__(self,
# Other
self.preconditioned = precondition

self.dynamic = dynamic

if sample is None:
self.sample = 'pcn'
elif sample in ['pcn']:
Expand All @@ -242,9 +245,9 @@ def __init__(self,

# Prior samples to draw
if n_prior is None:
self.n_prior = int(2 * (self.n_ess//self.n_active) * self.n_active)
self.n_prior = int(2 * np.maximum(self.n_ess//self.n_active, 1) * self.n_active)
else:
self.n_prior = int((n_prior/self.n_active) * self.n_active)
self.n_prior = int(np.maximum(n_prior/self.n_active, 1) * self.n_active)
self.prior_samples = None

self.logz = None
Expand Down Expand Up @@ -566,12 +569,20 @@ def _reweight(self, current_particles):
beta_max = 1.0
beta_min = np.copy(beta_prev)

def get_weights_and_ess(beta):
def get_weights_and_ess_(beta):
logw, _ = self.particles.compute_logw_and_logz(beta)
weights = np.exp(logw - np.max(logw))
weights /= np.sum(weights)
ess_est = 1.0 / np.sum(weights**2.0)
return weights, ess_est

def get_weights_and_ess(beta):
logw, _ = self.particles.compute_logw_and_logz(beta)
weights = np.exp(logw - np.max(logw))
weights /= np.sum(weights)
#ess_est = 1.0 / np.sum(weights**2.0)
expected_unique_all = np.sum(1-(1-weights)**len(weights))
return weights, expected_unique_all

weights_prev, ess_est_prev = get_weights_and_ess(beta_prev)
weights_max, ess_est_max = get_weights_and_ess(beta_max)
Expand Down Expand Up @@ -606,6 +617,15 @@ def get_weights_and_ess(beta):
weights = np.exp(logw - np.max(logw))
weights /= np.sum(weights)

#expected_unique = np.sum(1-(1-weights)**self.n_active)
#expected_unique_all = np.sum(1-(1-weights)**len(weights))
#print("Expected unique particles: ", expected_unique)
#print("Expected unique particles (all): ", expected_unique_all)
if self.dynamic:
n_unique_active = np.sum(1-(1-weights)**self.n_active)
if n_unique_active < self.n_active * 0.75:
self.n_ess = int(self.n_active/n_unique_active * self.n_ess)

idx, weights = trim_weights(np.arange(len(weights)), weights, ess=0.99, bins=1000)
current_particles["u"] = self.particles.get("u", index=None, flat=True)[idx]
current_particles["x"] = self.particles.get("x", index=None, flat=True)[idx]
Expand Down
17 changes: 0 additions & 17 deletions pocomc/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,6 @@

SQRTEPS = math.sqrt(float(np.finfo(np.float64).eps))

def mean_minimum_distance(N : int = None,
D : int = None):
r"""
Compute the average minimum (1st neighbor) distance between N samples from a D-dimensional uniform distribution.
Parameters
----------
N : ``int``
Number of samples
D : ``int``
Number of dimensions.
Returns
-------
distance : float
Mean minimum distance.
"""
return np.exp(gammaln(D/2 + 1)/D + gammaln(1+1/D) + gammaln(N) - gammaln(N + 1/D) - 0.5 * np.log(np.pi))

def trim_weights(samples, weights, ess=0.99, bins=1000):
"""
Expand Down

0 comments on commit 6ce7727

Please sign in to comment.