Skip to content

Commit

Permalink
Add init function
Browse files Browse the repository at this point in the history
  • Loading branch information
pedrofale committed Apr 23, 2024
1 parent 6526ca9 commit e55d918
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions scatrex/models/cna/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,32 @@ def reset_variational_parameters(self):
self.params = [self.variational_parameters["kernel"]["state"]["mean"],
jnp.exp(self.variational_parameters["kernel"]["direction"]["log_alpha"]-self.variational_parameters["kernel"]["direction"]["log_beta"])]

def reset_variational_kernel(self, log_std=-4, init_concentration=10):
parent = self.parent()
if parent is None and self.tssb.parent() is None:
pass
else:
parent_param = jnp.zeros((self.n_genes,))
if parent is not None:
parent_param = parent.params[0]

rng = np.random.default_rng(self.seed+2)
sampled_direction = rng.gamma(self.node_hyperparams['direction_shape'],
jnp.exp(-self.node_hyperparams['inheritance_strength'] * jnp.abs(parent_param)))
rng = np.random.default_rng(self.seed+3)
if np.all(parent_param == 0):
sampled_state = rng.normal(parent_param*0.1, np.exp(log_std)) # is root node, so avoid messing with main node attachments
else:
sampled_state = jnp.clip(rng.normal(parent_param*0.1, sampled_direction), a_min=-1, a_max=1) # to explore (without numerical explosions)

self.variational_parameters["kernel"] = {
'direction': {'log_alpha': jnp.log(init_concentration*jnp.ones((self.n_genes,))), 'log_beta': jnp.log(init_concentration/self.node_hyperparams['direction_shape'] * jnp.ones((self.n_genes,)))},
'state': {'mean': jnp.array(sampled_state), 'log_std': jnp.array(rng.normal(log_std, 0.01, size=self.n_genes))}
}
self.params = [self.variational_parameters["kernel"]["state"]["mean"],
jnp.exp(self.variational_parameters["kernel"]["direction"]["log_alpha"]-self.variational_parameters["kernel"]["direction"]["log_beta"])]


def reset_variational_noise_factors(self):
rng = np.random.default_rng(self.seed)
n_factors = self.node_hyperparams['n_factors']
Expand Down

0 comments on commit e55d918

Please sign in to comment.