Skip to content

Commit

Permalink
m
Browse files Browse the repository at this point in the history
  • Loading branch information
jovoni committed Dec 6, 2023
1 parent d2edf8d commit 9e11ae3
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 0 deletions.
26 changes: 26 additions & 0 deletions inst/pydevil/.virtual_documents/notebook_test/test.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ res = pydevil.run_SVDE(
cell_names = obs_names,
size_factors = False,
group_matrix = group_matrix,
<<<<<<< HEAD
variance="Hessian",
=======
variance="VI_Estimate",
>>>>>>> f622db1 (m)
jit_compile=True,
optimizer_name = "ClippedAdam",
lr = 0.5,
Expand All @@ -57,7 +61,29 @@ res = pydevil.run_SVDE(
)


<<<<<<< HEAD
res['params']['random_effects'][1,] - res['params']['random_effects'][1,].mean()
=======
gene_idx = 0


coeff = res['params']['beta']


obs, model_matrix, c, size_factors = X[:,gene_idx], covariates, coeff[:,gene_idx], res['params']['size_factors']


b, m = pydevil.compute_bread_and_meat(torch.tensor(obs), torch.tensor(model_matrix), torch.tensor(c), torch.tensor(res['params']['theta'][gene_idx]), torch.tensor(size_factors))


pydevil.compute_hessian(torch.tensor(obs), torch.tensor(model_matrix), torch.tensor(c), 1 / torch.tensor(res['params']['theta'][gene_idx]))


b


m
>>>>>>> f622db1 (m)


pydevil.test_posterior_null(res, [0,1,-1,0])
Expand Down
4 changes: 4 additions & 0 deletions inst/pydevil/pydevil/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from pydevil.interface import run_HMC, run_SVDE
from pydevil.tests import *
from pydevil.utils import *
<<<<<<< HEAD
=======
from pydevil.utils_hessian import compute_bread_and_meat, compute_hessian
>>>>>>> f622db1 (m)
22 changes: 22 additions & 0 deletions inst/pydevil/pydevil/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
from pydevil.scheduler import myOneCycleLR
from pydevil.guide import guide
from pydevil.utils import prepare_batch, compute_disperion_prior, init_beta, compute_offset_matrix, estimate_size_factors
<<<<<<< HEAD
from pydevil.utils_hessian import compute_hessians
=======
from pydevil.utils_hessian import compute_hessians, compute_sandwiches
>>>>>>> f622db1 (m)

from sklearn.metrics.pairwise import rbf_kernel

Expand Down Expand Up @@ -129,7 +133,12 @@ def run_SVDE(
else:
loc = pyro.param("beta_loc")
else:
<<<<<<< HEAD
loc = compute_hessians(input_matrix=input_matrix, model_matrix=model_matrix, coeff=coeff, overdispersion=1 / overdispersion, full_cov=full_cov)
=======
#loc = compute_hessians(input_matrix=input_matrix, model_matrix=model_matrix, coeff=coeff, overdispersion=1 / overdispersion, full_cov=full_cov)
loc = compute_sandwiches(input_matrix=input_matrix, model_matrix=model_matrix, coeff=coeff, overdispersion=overdispersion, size_factors=UMI, full_cov=full_cov)
>>>>>>> f622db1 (m)

eta = torch.exp(torch.matmul(model_matrix, coeff) + torch.unsqueeze(torch.log(UMI), 1) )
lk = dist.NegativeBinomial(logits = eta - torch.log(overdispersion) ,
Expand All @@ -142,13 +151,21 @@ def run_SVDE(
coeff = coeff.cpu().detach().numpy()
loc = loc.cpu().detach().numpy()
lk = lk.cpu().detach().numpy()
<<<<<<< HEAD
=======
UMI = UMI.cpu().detach().numpy()
>>>>>>> f622db1 (m)
else:
input_matrix = input_matrix.detach().numpy()
overdispersion = overdispersion.detach().numpy()
eta = eta.detach().numpy()
coeff = coeff.detach().numpy()
loc = loc.detach().numpy()
lk = lk.detach().numpy()
<<<<<<< HEAD
=======
UMI = UMI.detach().numpy()
>>>>>>> f622db1 (m)

variance = eta + eta**2 / overdispersion
# variance = eta + eta**2 * overdispersion
Expand All @@ -160,7 +177,12 @@ def run_SVDE(
"lk" : lk,
"beta" : coeff,
"eta" : eta,
<<<<<<< HEAD
"variance" : loc
=======
"variance" : loc,
"size_factors" : UMI
>>>>>>> f622db1 (m)
},
"residuals" : (input_matrix - eta) / np.sqrt(variance),
"hyperparams" : {
Expand Down
39 changes: 39 additions & 0 deletions inst/pydevil/pydevil/utils_hessian.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,45 @@ def compute_hessians(input_matrix, model_matrix, coeff, overdispersion, full_cov

return loc

def compute_sandwiches(input_matrix, model_matrix, coeff, overdispersion, size_factors, full_cov = True):
n_samples, n_genes = input_matrix.shape
n_coefficients = model_matrix.shape[1]
loc_shape = (n_genes, n_coefficients, n_coefficients) if full_cov else (n_genes, n_coefficients)
loc = torch.zeros(loc_shape)

t = trange(n_genes)
for gene_idx in t:
sandwich = compute_bread_and_meat(obs=input_matrix[:,gene_idx], model_matrix=model_matrix, coeff=coeff[:,gene_idx], overdispersion=overdispersion[gene_idx], size_factors=size_factors)

t.set_description('Variance estimation: {:.2f} '.format(gene_idx / n_genes))
t.refresh()

if full_cov:
loc[gene_idx, :, :] = sandwich
else:
loc[gene_idx, :] = torch.diag(sandwich)

return loc


def compute_bread_and_meat(obs, model_matrix, coeff, overdispersion, size_factors):
beta = torch.tensor(coeff)
model_matrix = torch.tensor(model_matrix)
alpha = overdispersion
design_v = model_matrix.t() # Transpose for vectorized operations
yi = obs.unsqueeze(1) # Add a new axis for broadcasting
k = size_factors * torch.exp(torch.matmul(design_v.t(), beta))
gamma_sq = (1 + alpha * k) ** 2

xij = torch.einsum('ik,jk->ijk', design_v, design_v).permute(2,0,1)

bread = torch.sum((yi * alpha + 1).view(-1,1,1) * xij * k.view(-1,1,1) / gamma_sq.view(-1,1,1), dim = 0)
bread = torch.inverse(bread)

meat = torch.sum(((xij * yi.view(-1,1,1)) + (yi + 1 / alpha).view(-1,1,1) * (xij * k.view(-1,1,1) * alpha.view(-1,1,1)) / (gamma_sq.view(-1,1,1)))**2, dim = 0)

return torch.matmul(torch.matmul(bread, meat), bread)

# Example usage:
# input_matrix, model_matrix, coeff, overdispersion = ...
# result = compute_hessians(input_matrix, model_matrix, coeff, overdispersion, full_cov=True)
Binary file added tests/.DS_Store
Binary file not shown.

0 comments on commit 9e11ae3

Please sign in to comment.