Skip to content

Commit

Permalink
Update to work with tumor-landscape
Browse files Browse the repository at this point in the history
  • Loading branch information
pedrofale committed Jul 11, 2024
1 parent ef9f89b commit d791dc4
Show file tree
Hide file tree
Showing 9 changed files with 314 additions and 107 deletions.
30 changes: 14 additions & 16 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,33 +1,31 @@
[tool.poetry]
name = "scatrex"
version = "0.2.1"
description = "Map single-cell transcriptomes to copy number evolutionary trees."
description = "Find gene expression patterns in the evolutionary history of a tumor."
license = "GNU GENERAL PUBLIC LICENSE 3"
authors = ["pedrofale <pedro.miguel.ferreira.pf@gmail.com>"]
readme = "README.md"
repository = "https://github.com/cbg-ethz/SCATrEx"

[tool.poetry.dependencies]
python = ">=3.7.1,<3.10"
pandas = "^1.3.2"
numpy = "^1.21.2"
networkx = "^2.6.2"
pygraphviz = "^1.7"
tqdm = "^4.57.0"
jax = "^0.3.20"
jaxlib = "^0.3.20"
scipy = "^1.7.3"
scikit-learn = "^0.23.2"
graphviz = "^0.14.1"
python = ">=3.9"
pandas = "^2.2.0"
numpy = "^1.26.3"
scanpy = "^1.9.8"
anndata = "^0.10.5"
scipy = "^1.12.0"
scikit-learn = "^1.4.0"
networkx = "^3.2.1"
graphviz = "^0.20.1"
gseapy = "^0.10.5"
pybiomart = "^0.2.0"
scanpy = "^1.7.0"
anndata = "^0.7.5"
tqdm = "^4.66.1"
jax = {extras = ["cpu"], version = "^0.4.20"}

[tool.poetry.dev-dependencies]
bump2version = "^1.0.1"
black = "^22.3.0"
pytest = "^6.2.4"
black = "^24.1.1"
pytest = "^8.0.0"

# [tool.poetry.extra-dependencies]
# pysankey2 =
Expand Down
4 changes: 3 additions & 1 deletion scatrex/models/cna/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,9 @@ def set_global_sample(self, sample):
def get_noise_sample(self, idx):
    """Return the per-sample noise contribution for the observations in `idx`.

    Combines the Monte Carlo samples of the observation weights and the
    factor weights with a batched matrix product:
    (s, n, k) x (s, k, g) -> (s, n, g), where s is the number of MC
    samples, n = len(idx) observations, k factors, g genes.
    """
    obs_weights = self.get_obs_weights_sample()[:, idx]
    factor_weights = self.get_factor_weights_sample()
    # NOTE: a stray `jax.vmap(obs_weights, factor_weights)` call was removed
    # here — vmap expects a callable, so the line was dead-at-best residue.
    return jnp.einsum('snk,skg->sng', obs_weights, factor_weights)

def get_direction_sample(self):
return self.samples[1]
Expand Down
149 changes: 119 additions & 30 deletions scatrex/models/trajectory/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,9 @@ def reset_variational_parameters(self):

# Sticks
self.variational_parameters["delta_1"] = 1.
self.variational_parameters["delta_2"] = 1.
self.variational_parameters["delta_2"] = (self.tssb.alpha_decay**self.depth) * self.tssb.dp_alpha
self.variational_parameters["sigma_1"] = 1.
self.variational_parameters["sigma_2"] = 1.
self.variational_parameters["sigma_2"] = self.tssb.dp_gamma

# Pivots
self.variational_parameters["q_rho"] = np.ones(len(self.tssb.children_root_nodes),)
Expand Down Expand Up @@ -224,24 +224,24 @@ def reset_variational_parameters(self):
# Kernel
radius = self.node_hyperparams['loc_mean']

if "angle" not in parent.variational_parameters["kernel"]:
if "direction" not in parent.variational_parameters["kernel"]:
mean_angle = jnp.array([parent.observed_parameters[1]])
parent_loc = jnp.array(parent.observed_parameters[0])
else:
mean_angle = parent.variational_parameters["kernel"]["angle"]["mean"]
parent_loc = parent.variational_parameters["kernel"]["loc"]["mean"]
mean_angle = parent.variational_parameters["kernel"]["direction"]["mean"]
parent_loc = parent.variational_parameters["kernel"]["state"]["mean"]

rng = np.random.default_rng(self.seed+2)
mean_angle = rng.vonmises(mean_angle, self.node_hyperparams['angle_concentration'] * self.depth)
mean_loc = parent_loc + jnp.array([np.cos(mean_angle[0])*radius, jnp.sin(mean_angle[0])*radius])
rng = np.random.default_rng(self.seed+3)
mean_loc = rng.normal(mean_loc, self.node_hyperparams['loc_variance'])
self.variational_parameters["kernel"] = {
'angle': {'mean': jnp.array(mean_angle), 'log_kappa': jnp.array([-1.])},
'loc': {'mean': jnp.array(mean_loc), 'log_std': jnp.array([-1., -1.])}
'direction': {'mean': jnp.array(mean_angle), 'log_kappa': jnp.array([-1.])},
'state': {'mean': jnp.array(mean_loc), 'log_std': jnp.array([-1., -1.])}
}
self.params = [self.variational_parameters["kernel"]["loc"]["mean"],
self.variational_parameters["kernel"]["angle"]["mean"]]
self.params = [self.variational_parameters["kernel"]["state"]["mean"],
self.variational_parameters["kernel"]["direction"]["mean"]]

def set_learned_parameters(self):
if self.parent() is None and self.tssb.parent() is None:
Expand All @@ -251,8 +251,8 @@ def set_learned_parameters(self):
elif self.parent() is None:
self.params = self.observed_parameters
else:
self.params = [self.variational_parameters["kernel"]["loc"]["mean"],
self.variational_parameters["kernel"]["angle"]["mean"]]
self.params = [self.variational_parameters["kernel"]["state"]["mean"],
self.variational_parameters["kernel"]["direction"]["mean"]]

def reset_sufficient_statistics(self, num_batches=1):
self.suff_stats = {
Expand Down Expand Up @@ -543,14 +543,14 @@ def compute_kernel_entropy(self):
return self.compute_root_entropy()

# Angle
angle_logpdf = tfd.VonMises(np.exp(self.variational_parameters['kernel']['angle']['mean']),
jnp.exp(self.variational_parameters['kernel']['angle']['log_kappa'])
angle_logpdf = tfd.VonMises(np.exp(self.variational_parameters['kernel']['direction']['mean']),
jnp.exp(self.variational_parameters['kernel']['direction']['log_kappa'])
).entropy()
angle_logpdf = jnp.sum(angle_logpdf)

# Location
loc_logpdf = tfd.Normal(self.variational_parameters['kernel']['loc']['mean'],
jnp.exp(self.variational_parameters['kernel']['loc']['log_std'])
loc_logpdf = tfd.Normal(self.variational_parameters['kernel']['state']['mean'],
jnp.exp(self.variational_parameters['kernel']['state']['log_std'])
).entropy()
loc_logpdf = jnp.sum(loc_logpdf) # Sum across features

Expand Down Expand Up @@ -595,25 +595,29 @@ def compute_globals_entropy_grad(self):

def state_sample_and_grad(self, key, n_samples):
    """Sample the state kernel and return its reparameterized value-and-grad.

    Draws `n_samples` Monte Carlo samples of the node state from
    q(state) = Normal(mu, exp(log_std)), one PRNG subkey per sample.
    Returns the advanced key and the sampler's value/gradient output.
    """
    # Diff residue reading the retired 'loc' key was removed; the kernel
    # parameters are keyed by 'state' in this model.
    mu = self.variational_parameters['kernel']['state']['mean']
    log_std = self.variational_parameters['kernel']['state']['log_std']
    key, *sub_keys = jax.random.split(key, n_samples + 1)
    return key, mc_sample_loc_val_and_grad(jnp.array(sub_keys), mu, log_std)

def direction_sample_and_grad(self, key, n_samples):
    """Sample the direction kernel and return its value-and-grad.

    Draws `n_samples` Monte Carlo samples of the node direction from
    q(direction) = VonMises(mu, exp(log_kappa)), one PRNG subkey per
    sample. Returns the advanced key and the sampler's output.
    """
    # Diff residue reading the retired 'angle' key was removed; the kernel
    # parameters are keyed by 'direction' in this model.
    mu = self.variational_parameters['kernel']['direction']['mean']
    log_kappa = self.variational_parameters['kernel']['direction']['log_kappa']
    key, *sub_keys = jax.random.split(key, n_samples + 1)
    return key, mc_sample_angle_val_and_grad(jnp.array(sub_keys), mu, log_kappa)

def state_sample_and_grad(self, key, n_samples):
    """Sample the state kernel and return its reparameterized value-and-grad.

    NOTE(review): this definition duplicates an identical
    `state_sample_and_grad` defined earlier in the same class; only the
    later definition takes effect. One of the two should be removed.
    """
    mu = self.variational_parameters['kernel']['state']['mean']
    log_std = self.variational_parameters['kernel']['state']['log_std']
    key, *sub_keys = jax.random.split(key, n_samples + 1)
    return key, mc_sample_loc_val_and_grad(jnp.array(sub_keys), mu, log_std)

def compute_direction_prior_grad(self, alpha, parent_alpha, parent_loc):
    """Gradient of logp(alpha|parent_alpha,parent_loc).

    Thin dispatcher: the prior gradient for the direction variable is
    exactly the gradient taken with respect to the direction itself.
    """
    grad = self.compute_direction_prior_grad_wrt_direction(
        alpha, parent_alpha, parent_loc
    )
    return grad

def compute_direction_prior_grad_wrt_direction(self, alpha, parent_alpha, parent_loc):
"""Gradient of logp(alpha|parent_alpha) wrt this alpha"""
concentration = self.get_prior_angle_concentration()
Expand All @@ -623,6 +627,14 @@ def compute_direction_prior_grad_wrt_state(self, alpha, parent_alpha, parent_loc
"""Gradient of logp(alpha|parent_alpha) wrt this alpha"""
return 0.

def compute_direction_prior_child_grad_wrt_state(self, child_direction, direction, state):
    """Gradient of logp(child_direction|direction) wrt this node's state.

    The child-direction prior does not depend on the parent's state, so
    the gradient is identically zero.
    """
    return 0.

def compute_direction_prior_child_grad_wrt_direction(self, child_direction, direction, state):
    """Gradient of logp(child_direction|direction) wrt this node's direction.

    `state` is accepted for interface uniformity with the *_wrt_state
    variant but is unused: the child-direction prior involves only the
    two directions.
    """
    result = self.compute_direction_prior_child_grad(child_direction, direction)
    return result

def compute_direction_prior_child_grad(self, child_alpha, alpha):
"""Gradient of logp(child_alpha|alpha) wrt this alpha"""
concentration = self.get_prior_angle_concentration(depth=self.depth+1)
Expand Down Expand Up @@ -654,14 +666,14 @@ def compute_state_prior_grad_wrt_direction(self, psi, parent_psi, alpha):

def compute_direction_entropy_grad(self):
    """Gradient of logq(direction) wrt the direction variational parameters.

    Reads the variational mean and log-concentration of the direction
    kernel and returns only the gradient part of the val-and-grad pair.
    """
    # Diff residue reading the retired 'angle' key was removed; the kernel
    # parameters are keyed by 'direction' in this model.
    mu = self.variational_parameters['kernel']['direction']['mean']
    log_kappa = self.variational_parameters['kernel']['direction']['log_kappa']
    return angle_logq_val_and_grad(mu, log_kappa)[1]

def compute_state_entropy_grad(self):
    """Gradient of logq(psi) wrt the state variational parameters.

    Reads the variational mean and log-std of the state kernel and
    returns only the gradient part of the val-and-grad pair.
    """
    # Diff residue reading the retired 'loc' key was removed; the kernel
    # parameters are keyed by 'state' in this model.
    mu = self.variational_parameters['kernel']['state']['mean']
    log_std = self.variational_parameters['kernel']['state']['log_std']
    return loc_logq_val_and_grad(mu, log_std)[1]

def compute_ll_state_grad(self, x, weights, psi):
Expand Down Expand Up @@ -697,20 +709,20 @@ def compute_ll_globals_grad(self, x, idx, weights):
def update_direction_params(self, direction_params_grad, direction_sample_grad, direction_params_entropy_grad, step_size=0.001):
    """One SGD ascent step on the direction kernel's variational parameters.

    Args:
        direction_params_grad: pair of per-sample gradients wrt (mean, log_kappa).
        direction_sample_grad: per-sample reparameterization gradients.
        direction_params_entropy_grad: pair of entropy gradients wrt (mean, log_kappa).
        step_size: learning rate.

    Note: diff residue applying the same update under the retired 'angle'
    key (which no longer exists, causing a KeyError) was removed.
    """
    # Monte Carlo estimate of the ELBO gradient wrt the variational mean.
    mc_grad = jnp.mean(direction_params_grad[0] * direction_sample_grad, axis=0)
    mean_grad = mc_grad + direction_params_entropy_grad[0]
    self.variational_parameters['kernel']['direction']['mean'] += mean_grad * step_size

    # Monte Carlo estimate of the gradient wrt the log-concentration.
    mc_grad = jnp.mean(direction_params_grad[1] * direction_sample_grad, axis=0)
    log_kappa_grad = mc_grad + direction_params_entropy_grad[1]
    self.variational_parameters['kernel']['direction']['log_kappa'] += log_kappa_grad * step_size

def update_state_params(self, state_params_grad, state_sample_grad, state_params_entropy_grad, step_size=0.001):
    """One SGD ascent step on the state kernel's variational parameters.

    Args:
        state_params_grad: pair of per-sample gradients wrt (mean, log_std).
        state_sample_grad: per-sample reparameterization gradients.
        state_params_entropy_grad: pair of entropy gradients wrt (mean, log_std).
        step_size: learning rate.

    Note: diff residue applying the same update under the retired 'loc'
    key (which no longer exists, causing a KeyError) was removed.
    """
    # Monte Carlo estimate of the ELBO gradient wrt the variational mean.
    mc_grad = jnp.mean(state_params_grad[0] * state_sample_grad, axis=0)
    mean_grad = mc_grad + state_params_entropy_grad[0]
    self.variational_parameters['kernel']['state']['mean'] += mean_grad * step_size

    # Monte Carlo estimate of the gradient wrt the log-std.
    mc_grad = jnp.mean(state_params_grad[1] * state_sample_grad, axis=0)
    log_std_grad = mc_grad + state_params_entropy_grad[1]
    self.variational_parameters['kernel']['state']['log_std'] += log_std_grad * step_size

def update_local_params(self, idx, local_params_grad, local_sample_grad, local_params_entropy_grad, ent_anneal=1., step_size=0.001):
mc_grad = jnp.mean(local_params_grad[0] * local_sample_grad, axis=0)
Expand Down Expand Up @@ -769,4 +781,81 @@ def update_global_params_adaptive(self, global_params_grad, global_sample_grad,
self.variational_parameters['global']['factor_weights']['log_std'] += step_size * mhat / (jnp.sqrt(vhat) + eps)

states = (state1, state2)
return states
return states

def initialize_state_states(self):
    """Return zero-initialized Adam moment pairs for the state parameters.

    Produces ((m, v), (m, v)): one (first-moment, second-moment) pair for
    the variational mean and one for the log-std, each of length n_genes.
    """
    def fresh_pair():
        # One zeroed (m, v) accumulator pair per parameter vector.
        return (jnp.zeros((self.n_genes,)), jnp.zeros((self.n_genes,)))

    return (fresh_pair(), fresh_pair())

def update_state_adaptive(self, state_params_grad, state_sample_grad, state_params_entropy_grad, i, b1=0.9,
                        b2=0.999, eps=1e-8, step_size=0.001):
    """One Adam ascent step on the state kernel's variational parameters.

    Args:
        state_params_grad: pair of per-sample gradients wrt (mean, log_std).
        state_sample_grad: per-sample reparameterization gradients.
        state_params_entropy_grad: pair of entropy gradients wrt (mean, log_std).
        i: iteration index (used for Adam bias correction).
        b1, b2, eps, step_size: standard Adam hyperparameters.

    Updates `self.variational_parameters['kernel']['state']` in place and
    refreshes the Adam moment accumulators in `self.state_states`.
    """
    def adam_step(state, mc_term, entropy_term):
        # Standard Adam moment update with bias correction; returns the
        # new (m, v) pair and the parameter increment.
        param_grad = jnp.mean(mc_term * state_sample_grad, axis=0) + entropy_term
        m, v = state
        m = (1 - b1) * param_grad + b1 * m  # First moment estimate.
        v = (1 - b2) * jnp.square(param_grad) + b2 * v  # Second moment estimate.
        mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1))  # Bias correction.
        vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
        return (m, v), step_size * mhat / (jnp.sqrt(vhat) + eps)

    state1, delta = adam_step(self.state_states[0], state_params_grad[0], state_params_entropy_grad[0])
    self.variational_parameters['kernel']['state']['mean'] += delta

    state2, delta = adam_step(self.state_states[1], state_params_grad[1], state_params_entropy_grad[1])
    self.variational_parameters['kernel']['state']['log_std'] += delta

    self.state_states = (state1, state2)

def initialize_direction_states(self):
    """Return zero-initialized Adam moment pairs for the direction parameters.

    Produces ((m, v), (m, v)): one (first-moment, second-moment) pair for
    the variational mean and one for the log-concentration, each length 1.
    """
    def fresh_pair():
        # One zeroed (m, v) accumulator pair per scalar parameter.
        return (jnp.zeros((1,)), jnp.zeros((1,)))

    return (fresh_pair(), fresh_pair())

def update_direction_adaptive(self, direction_params_grad, direction_sample_grad, direction_params_entropy_grad, i, b1=0.9,
                        b2=0.999, eps=1e-8, step_size=0.001):
    """One Adam ascent step on the direction kernel's variational parameters.

    Args:
        direction_params_grad: pair of per-sample gradients wrt (mean, log_kappa).
        direction_sample_grad: per-sample reparameterization gradients.
        direction_params_entropy_grad: pair of entropy gradients wrt (mean, log_kappa).
        i: iteration index (used for Adam bias correction).
        b1, b2, eps, step_size: standard Adam hyperparameters.

    Updates `self.variational_parameters['kernel']['direction']` in place
    and refreshes the Adam moment accumulators in `self.direction_states`.
    """
    def adam_step(state, mc_term, entropy_term):
        # Standard Adam moment update with bias correction; returns the
        # new (m, v) pair and the parameter increment.
        param_grad = jnp.mean(mc_term * direction_sample_grad, axis=0) + entropy_term
        m, v = state
        m = (1 - b1) * param_grad + b1 * m  # First moment estimate.
        v = (1 - b2) * jnp.square(param_grad) + b2 * v  # Second moment estimate.
        mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1))  # Bias correction.
        vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
        return (m, v), step_size * mhat / (jnp.sqrt(vhat) + eps)

    state1, delta = adam_step(self.direction_states[0], direction_params_grad[0], direction_params_entropy_grad[0])
    self.variational_parameters['kernel']['direction']['mean'] += delta

    state2, delta = adam_step(self.direction_states[1], direction_params_grad[1], direction_params_entropy_grad[1])
    self.variational_parameters['kernel']['direction']['log_kappa'] += delta

    self.direction_states = (state1, state2)
8 changes: 8 additions & 0 deletions scatrex/ntssb/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,4 +233,12 @@ def get_top_obs(self, q=70, idx=None):
return top_obs

def reset_variational_state(self, **kwargs):
    """No-op hook; model-specific subclasses override to clear their state."""
    return None

def reset_opt(self):
    """Reset the per-parameter accumulators used by adaptive optimization."""
    # Fresh zeroed moment states for both kernel variables; the two
    # assignments are independent of each other.
    self.state_states = self.initialize_state_states()
    self.direction_states = self.initialize_direction_states()

def init_new_node_kernel(self, **kwargs):
    """No-op hook; model-specific subclasses override to set up a kernel."""
    return None
Loading

0 comments on commit d791dc4

Please sign in to comment.