Skip to content

Commit

Permalink
Better root log_std inits
Browse files Browse the repository at this point in the history
  • Loading branch information
pedrofale committed Apr 23, 2024
1 parent 786ef8c commit 6526ca9
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 2 deletions.
7 changes: 7 additions & 0 deletions scatrex/ntssb/ntssb.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,13 @@ def descend(root):
descend(child)
descend(self.root)

def reset_variational_kernels(self, **kwargs):
def descend(root):
root['node'].reset_variational_kernels(**kwargs)
for child in root['children']:
descend(child)
descend(self.root)

def sample_variational_distributions(self, **kwargs):
def descend(root):
root['node'].sample_variational_distributions(**kwargs)
Expand Down
7 changes: 7 additions & 0 deletions scatrex/ntssb/tssb.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,13 @@ def descend(root):
descend(child)
descend(self.root)

def reset_variational_kernels(self, **kwargs):
def descend(root):
root['node'].reset_variational_kernel(**kwargs)
for child in root['children']:
descend(child)
descend(self.root)

def set_weights(self, node_weights_dict):
def descend(root):
root['weight'] = node_weights_dict[root['label']]
Expand Down
7 changes: 5 additions & 2 deletions scatrex/scatrex.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def learn_scales(self, n_epochs=100, mc_samples=10, step_size=0.01):
root.variational_parameters['local']['cell_scales']['log_beta'] = jnp.log(cell_scales_beta_init)

# Initialize MC samples
self.ntssb.reset_variational_kernels(log_std=-4)
self.ntssb.sample_variational_distributions(n_samples=mc_samples)
self.ntssb.update_sufficient_statistics()

Expand All @@ -204,6 +205,7 @@ def learn_roots_and_noise(self, n_iters=10, n_epochs=100, n_merges=10, n_swaps=1
direction_shape = self.ntssb.root['node'].root['node'].node_hyperparams['direction_shape']
self.ntssb.set_node_hyperparams(n_factors=0)
self.ntssb.root['node'].root['node'].reset_variational_noise_factors()
self.ntssb.reset_variational_kernels(log_std=0.)
self.ntssb.sample_variational_distributions(n_samples=mc_samples)
self.ntssb.update_sufficient_statistics()
self.ntssb.learn_roots(n_epochs, memoized=memoized, mc_samples=mc_samples, step_size=step_size, return_trace=False)
Expand Down Expand Up @@ -377,7 +379,7 @@ def update_anndata(self, adata):
adata.layers["scatrex_mean"] = mean_mat

def learn(self, adata, observed_tree=None, counts_layer='counts', allow_subtrees=True, allow_root_subtrees=False, root_cells=None,
batch_size=None, seed=42,
batch_size=None, seed=42, weights_concentration=1e6,
n_epochs=100, mc_samples=10, step_size=0.01, n_iters=10, n_merges=10, n_swaps=10, memoized=True, dp_alpha=.1, dp_gamma=.1):
"""
Complete NTSSB learning procedure.
Expand All @@ -388,7 +390,8 @@ def learn(self, adata, observed_tree=None, counts_layer='counts', allow_subtrees
# Setup NTSSB
self.ntssb = NTSSB(self.observed_tree,
node_hyperparams=self.model_args,
seed=seed,)
seed=seed,
weights_concentration=weights_concentration)
self.ntssb.add_data(np.array(adata.layers[counts_layer]))
self.ntssb.make_batches(batch_size, seed)
self.ntssb.reset_variational_parameters()
Expand Down

0 comments on commit 6526ca9

Please sign in to comment.