Skip to content

Commit

Permalink
mc
Browse files Browse the repository at this point in the history
  • Loading branch information
jovoni committed Feb 6, 2024
1 parent 4a861b8 commit 9be395c
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 21 deletions.
30 changes: 15 additions & 15 deletions inst/pydevil/pydevil/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pydevil.model import model
from pydevil.guide import guide
from pydevil.utils import prepare_batch
from pydevil.utils_input import check_and_prepare_input_run_SVDE, unload_tensor
from pydevil.utils_input import check_and_prepare_input_run_SVDE, detach_tensor
from pydevil.utils_hessian import compute_hessians, compute_sandwiches

def run_SVDE(
Expand Down Expand Up @@ -115,20 +115,20 @@ def run_SVDE(
#lk = dist.NegativeBinomial(logits = eta - torch.log(overdispersion) ,
# total_count= torch.clamp(overdispersion, 1e-9,1e9)).log_prob(input_matrix).sum(dim = 0)

ret['input_matrix'] = unload_tensor(ret['input_matrix'])
ret['model_matrix'] = unload_tensor(ret['model_matrix'])
ret['group_matrix'] = unload_tensor(ret['group_matrix'])
ret['sf'] = unload_tensor(ret['sf'])
ret['offset_matrix'] = unload_tensor(ret['offset_matrix'])
ret['beta_estimate_matrix'] = unload_tensor(ret['beta_estimate_matrix'])
ret['dispersion_priors'] = unload_tensor(ret['dispersion_priors'])
ret['cluster'] = unload_tensor(ret['cluster'])
input_matrix = unload_tensor(input_matrix)
model_matrix = unload_tensor(model_matrix)
overdispersion = unload_tensor(overdispersion)
coeff = unload_tensor(coeff)
loc = unload_tensor(loc)
UMI = unload_tensor(input_data['sf'])
ret['input_matrix'] = detach_tensor(ret['input_matrix'])
ret['model_matrix'] = detach_tensor(ret['model_matrix'])
ret['group_matrix'] = detach_tensor(ret['group_matrix'])
ret['sf'] = detach_tensor(ret['sf'])
ret['offset_matrix'] = detach_tensor(ret['offset_matrix'])
ret['beta_estimate_matrix'] = detach_tensor(ret['beta_estimate_matrix'])
ret['dispersion_priors'] = detach_tensor(ret['dispersion_priors'])
ret['cluster'] = detach_tensor(ret['cluster'])
input_matrix = detach_tensor(input_matrix)
model_matrix = detach_tensor(model_matrix)
overdispersion = detach_tensor(overdispersion)
coeff = detach_tensor(coeff)
loc = detach_tensor(loc)
UMI = detach_tensor(input_data['sf'])

# if cuda and torch.cuda.is_available():
# input_matrix = input_matrix.cpu().detach().numpy()
Expand Down
6 changes: 3 additions & 3 deletions inst/pydevil/pydevil/utils_hessian.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
from tqdm import trange
from pydevil.utils_input import unload_tensor
from pydevil.utils_input import detach_tensor

def compute_hessian(obs, model_matrix, coeff, overdispersion, size_factors):
beta = torch.tensor(coeff)
Expand Down Expand Up @@ -34,7 +34,7 @@ def compute_hessians(input_matrix, model_matrix, coeff, overdispersion, size_fac

t.set_description('Variance estimation: {:.2f} '.format(gene_idx / n_genes))
t.refresh()
solved_hessian = unload_tensor(solved_hessian)
solved_hessian = detach_tensor(solved_hessian)

if full_cov:
loc[gene_idx, :, :] = solved_hessian
Expand Down Expand Up @@ -94,7 +94,7 @@ def compute_sandwiches(input_matrix, model_matrix, coeff, overdispersion, size_f
t.set_description('Clustered variance estimation: {:.2f} '.format(gene_idx / n_genes))
t.refresh()

s = unload_tensor(s)
s = detach_tensor(s)

loc[gene_idx, :, :] = s
del s
Expand Down
6 changes: 3 additions & 3 deletions inst/pydevil/pydevil/utils_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ def ensure_tensor(obj, cuda):
print("Error:", e)
return None

def unload_tensor(obj):
def detach_tensor(obj):
"""
Unload the tensor from the GPU.
"""
if isinstance(obj, torch.Tensor):
if obj.get_device() == 0:
return obj.cpu().detach().numpy()
return obj.cpu().detach()
else:
return obj.detach().numpy()
return obj.detach()
return obj

def validate_boolean(parameter, parameter_name):
Expand Down

0 comments on commit 9be395c

Please sign in to comment.