Skip to content

Commit

Permalink
fix bugs in GP interpolation
Browse files Browse the repository at this point in the history
  • Loading branch information
YifanLu2000 committed Nov 25, 2024
1 parent 7ab5bee commit cbea5b3
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 34 deletions.
4 changes: 2 additions & 2 deletions spateo/alignment/methods/morpho_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -1508,8 +1508,8 @@ def _wrap_output(
"t": self.nx.to_numpy(self.t),
"optimal_R": self.nx.to_numpy(self.optimal_R),
"optimal_t": self.nx.to_numpy(self.optimal_t),
"init_R": self.nx.to_numpy(self.init_R) if self.nn_init else np.eye(self.Dim),
"init_t": self.nx.to_numpy(self.init_t) if self.nn_init else np.zeros(self.Dim),
"init_R": self.nx.to_numpy(self.init_R) if self.nn_init else np.eye(self.XAHat.shape[1]),
"init_t": self.nx.to_numpy(self.init_t) if self.nn_init else np.zeros(self.XAHat.shape[1]),
"beta": self.beta,
"Coff": self.nx.to_numpy(self.Coff),
"inducing_variables": self.nx.to_numpy(self.inducing_variables),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import torch
from tqdm import tqdm

from spateo.alignment.utils import _iteration

def gp_train(model, likelihood, train_loader, train_epochs, method, N, device):

def gp_train(model, likelihood, train_loader, train_epochs, method, N, device, keys, verbose=True):
if torch.cuda.is_available() and device != "cpu":
model = model.cuda()
likelihood = likelihood.cuda()
Expand All @@ -24,16 +26,15 @@ def gp_train(model, likelihood, train_loader, train_epochs, method, N, device):
lr=0.01,
)

epochs_iter = tqdm(range(train_epochs), desc="Epoch")
progress_name = f"Interpolation based on Gaussian Process Regression for {keys[0]}"
epochs_iter = _iteration(n=train_epochs, progress_name=progress_name, verbose=verbose)
for i in epochs_iter:
if method == "SVGP":
# Within each iteration, we will go over each minibatch of data
minibatch_iter = tqdm(train_loader, desc="Minibatch", leave=True)
for x_batch, y_batch in minibatch_iter:
for x_batch, y_batch in train_loader:
optimizer.zero_grad()
output = model(x_batch)
loss = -mll(output, y_batch)
minibatch_iter.set_postfix(loss=loss.item())
loss.backward()
optimizer.step()
else:
Expand Down
85 changes: 58 additions & 27 deletions spateo/tdr/interpolations/interpolation_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def __init__(
inducing_num: int = 512,
normalize_spatial: bool = True,
):

self.keys = keys
# Source data
source_adata = source_adata.copy()
source_adata.X = source_adata.X if layer == "X" else source_adata.layers[layer]
Expand Down Expand Up @@ -70,7 +72,6 @@ def __init__(
self.train_y = self.train_y.squeeze()

self.nx = ot.backend.get_backend(self.train_x, self.train_y)

self.normalize_spatial = normalize_spatial
if self.normalize_spatial:
self.train_x = self.normalize_coords(self.train_x)
Expand All @@ -96,15 +97,14 @@ def __init__(

self.PCA_reduction = False
self.info_keys = {"obs_keys": obs_keys, "var_keys": var_keys}
print(self.info_keys)

# Target data
self.target_points = torch.from_numpy(target_points).float()
self.target_points = self.target_points.cpu() if self.device == "cpu" else self.target_points.cuda()

def normalize_coords(self, data: Union[np.ndarray, torch.Tensor], given_normalize: bool = False):
if not given_normalize:
self.mean_data = _unsqueeze(self.nx)(self.nx.mean(data, axis=0), 0)
self.mean_data = self.nx.mean(data, axis=0)[None, :]
data = data - self.mean_data
if not given_normalize:
self.variance = self.nx.sqrt(self.nx.sum(data**2) / data.shape[0])
Expand All @@ -114,12 +114,16 @@ def normalize_coords(self, data: Union[np.ndarray, torch.Tensor], given_normaliz
def inference(
self,
training_iter: int = 50,
verbose: bool = True,
):
self.likelihood = GaussianLikelihood()
if self.method == "SVGP":
self.GPR_model = Approx_GPModel(inducing_points=self.inducing_points)
elif self.method == "ExactGP":
self.GPR_model = Exact_GPModel(self.train_x, self.train_y, self.likelihood)
self.GPR_models = [
Exact_GPModel(self.train_x, self.train_y[:, i], self.likelihoods[i])
for i in range(self.train_y.shape[1])
]
# if to convert to GPU
if self.device != "cpu":
self.GPR_model = self.GPR_model.cuda()
Expand All @@ -134,6 +138,8 @@ def inference(
method=self.method,
N=self.N,
device=self.device,
verbose=verbose,
keys=self.keys,
)

self.GPR_model.eval()
Expand Down Expand Up @@ -181,6 +187,7 @@ def gp_interpolation(
batch_size: int = 1024,
shuffle: bool = True,
inducing_num: int = 512,
verbose: bool = True,
) -> AnnData:
"""
Learn a continuous mapping from space to gene expression pattern with the Gaussian Process method.
Expand All @@ -197,36 +204,60 @@ def gp_interpolation(
Returns:
interp_adata: an anndata object that has interpolated expression.
"""
assert keys != None, "`keys` cannot be None."
keys = [keys] if isinstance(keys, str) else keys
obs_keys = [key for key in keys if key in source_adata.obs.keys()]
var_keys = [key for key in keys if key in source_adata.var_names.tolist()]
info_keys = {"obs_keys": obs_keys, "var_keys": var_keys}
print(info_keys)
obs_data = []
var_data = []
if len(obs_keys) != 0:
for key in obs_keys:
GPR = Imputation_GPR(
source_adata=source_adata,
target_points=target_points,
keys=[key],
spatial_key=spatial_key,
layer=layer,
device=device,
method=method,
batch_size=batch_size,
shuffle=shuffle,
inducing_num=inducing_num,
)
GPR.inference(training_iter=training_iter, verbose=verbose)

# Inference
GPR = Imputation_GPR(
source_adata=source_adata,
target_points=target_points,
keys=keys,
spatial_key=spatial_key,
layer=layer,
device=device,
method=method,
batch_size=batch_size,
shuffle=shuffle,
inducing_num=inducing_num,
)
GPR.inference(training_iter=training_iter)
# Interpolation
target_info_data = GPR.interpolate(use_chunk=True)
obs_data.append(target_info_data[:, None])
if len(var_keys) != 0:
for key in var_keys:
GPR = Imputation_GPR(
source_adata=source_adata,
target_points=target_points,
keys=[key],
spatial_key=spatial_key,
layer=layer,
device=device,
method=method,
batch_size=batch_size,
shuffle=shuffle,
inducing_num=inducing_num,
)
GPR.inference(training_iter=training_iter, verbose=verbose)

# Interpolation
target_info_data = GPR.interpolate(use_chunk=True)
var_data.append(target_info_data[:, None])

# Interpolation
target_info_data = GPR.interpolate(use_chunk=True)
target_info_data = target_info_data[:, None]
# Output interpolated anndata
lm.main_info("Creating an adata object with the interpolated expression...")

obs_keys = GPR.info_keys["obs_keys"]
if len(obs_keys) != 0:
obs_data = target_info_data[:, : len(obs_keys)]
obs_data = np.concatenate(obs_data, axis=1)
obs_data = pd.DataFrame(obs_data, columns=obs_keys)

var_keys = GPR.info_keys["var_keys"]
if len(var_keys) != 0:
X = target_info_data[:, len(obs_keys) :]
X = np.concatenate(var_data, axis=1)
var_data = pd.DataFrame(index=var_keys)

interp_adata = AnnData(
Expand Down

0 comments on commit cbea5b3

Please sign in to comment.