diff --git a/scatrex/models/cna/node.py b/scatrex/models/cna/node.py index 14c947f..134ed86 100644 --- a/scatrex/models/cna/node.py +++ b/scatrex/models/cna/node.py @@ -310,6 +310,32 @@ def reset_variational_parameters(self): self.params = [self.variational_parameters["kernel"]["state"]["mean"], jnp.exp(self.variational_parameters["kernel"]["direction"]["log_alpha"]-self.variational_parameters["kernel"]["direction"]["log_beta"])] + def reset_variational_kernel(self, log_std=-4, init_concentration=10): + parent = self.parent() + if parent is None and self.tssb.parent() is None: + pass + else: + parent_param = jnp.zeros((self.n_genes,)) + if parent is not None: + parent_param = parent.params[0] + + rng = np.random.default_rng(self.seed+2) + sampled_direction = rng.gamma(self.node_hyperparams['direction_shape'], + jnp.exp(-self.node_hyperparams['inheritance_strength'] * jnp.abs(parent_param))) + rng = np.random.default_rng(self.seed+3) + if np.all(parent_param == 0): + sampled_state = rng.normal(parent_param*0.1, np.exp(log_std)) # is root node, so avoid messing with main node attachments + else: + sampled_state = jnp.clip(rng.normal(parent_param*0.1, sampled_direction), a_min=-1, a_max=1) # to explore (without numerical explosions) + + self.variational_parameters["kernel"] = { + 'direction': {'log_alpha': jnp.log(init_concentration*jnp.ones((self.n_genes,))), 'log_beta': jnp.log(init_concentration/self.node_hyperparams['direction_shape'] * jnp.ones((self.n_genes,)))}, + 'state': {'mean': jnp.array(sampled_state), 'log_std': jnp.array(rng.normal(log_std, 0.01, size=self.n_genes))} + } + self.params = [self.variational_parameters["kernel"]["state"]["mean"], + jnp.exp(self.variational_parameters["kernel"]["direction"]["log_alpha"]-self.variational_parameters["kernel"]["direction"]["log_beta"])] + + def reset_variational_noise_factors(self): rng = np.random.default_rng(self.seed) n_factors = self.node_hyperparams['n_factors']