From 6526ca98fec74ccc8113b1cd7daad0d6beb56b10 Mon Sep 17 00:00:00 2001 From: pedrofale Date: Tue, 23 Apr 2024 23:10:23 +0200 Subject: [PATCH] Better root log_std inits --- scatrex/ntssb/ntssb.py | 7 +++++++ scatrex/ntssb/tssb.py | 7 +++++++ scatrex/scatrex.py | 7 +++++-- 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/scatrex/ntssb/ntssb.py b/scatrex/ntssb/ntssb.py index 8efcfea..b744dc2 100644 --- a/scatrex/ntssb/ntssb.py +++ b/scatrex/ntssb/ntssb.py @@ -258,6 +258,13 @@ def descend(root): descend(child) descend(self.root) + def reset_variational_kernels(self, **kwargs): + def descend(root): + root['node'].reset_variational_kernels(**kwargs) + for child in root['children']: + descend(child) + descend(self.root) + def sample_variational_distributions(self, **kwargs): def descend(root): root['node'].sample_variational_distributions(**kwargs) diff --git a/scatrex/ntssb/tssb.py b/scatrex/ntssb/tssb.py index 46faf40..f46f378 100644 --- a/scatrex/ntssb/tssb.py +++ b/scatrex/ntssb/tssb.py @@ -212,6 +212,13 @@ def descend(root): descend(child) descend(self.root) + def reset_variational_kernels(self, **kwargs): + def descend(root): + root['node'].reset_variational_kernel(**kwargs) + for child in root['children']: + descend(child) + descend(self.root) + def set_weights(self, node_weights_dict): def descend(root): root['weight'] = node_weights_dict[root['label']] diff --git a/scatrex/scatrex.py b/scatrex/scatrex.py index 1c48e55..881cfa2 100644 --- a/scatrex/scatrex.py +++ b/scatrex/scatrex.py @@ -181,6 +181,7 @@ def learn_scales(self, n_epochs=100, mc_samples=10, step_size=0.01): root.variational_parameters['local']['cell_scales']['log_beta'] = jnp.log(cell_scales_beta_init) # Initialize MC samples + self.ntssb.reset_variational_kernels(log_std=-4) self.ntssb.sample_variational_distributions(n_samples=mc_samples) self.ntssb.update_sufficient_statistics() @@ -204,6 +205,7 @@ def learn_roots_and_noise(self, n_iters=10, n_epochs=100, n_merges=10, n_swaps=1 direction_shape = self.ntssb.root['node'].root['node'].node_hyperparams['direction_shape'] self.ntssb.set_node_hyperparams(n_factors=0) self.ntssb.root['node'].root['node'].reset_variational_noise_factors() + self.ntssb.reset_variational_kernels(log_std=0.) self.ntssb.sample_variational_distributions(n_samples=mc_samples) self.ntssb.update_sufficient_statistics() self.ntssb.learn_roots(n_epochs, memoized=memoized, mc_samples=mc_samples, step_size=step_size, return_trace=False) @@ -377,7 +379,7 @@ def update_anndata(self, adata): adata.layers["scatrex_mean"] = mean_mat def learn(self, adata, observed_tree=None, counts_layer='counts', allow_subtrees=True, allow_root_subtrees=False, root_cells=None, - batch_size=None, seed=42, + batch_size=None, seed=42, weights_concentration=1e6, n_epochs=100, mc_samples=10, step_size=0.01, n_iters=10, n_merges=10, n_swaps=10, memoized=True, dp_alpha=.1, dp_gamma=.1): """ Complete NTSSB learning procedure. @@ -388,7 +390,8 @@ def learn(self, adata, observed_tree=None, counts_layer='counts', allow_subtrees # Setup NTSSB self.ntssb = NTSSB(self.observed_tree, node_hyperparams=self.model_args, - seed=seed,) + seed=seed, + weights_concentration=weights_concentration) self.ntssb.add_data(np.array(adata.layers[counts_layer])) self.ntssb.make_batches(batch_size, seed) self.ntssb.reset_variational_parameters()