forked from COMSYS/mcBERT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference.py
78 lines (66 loc) · 2.16 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from argparse import ArgumentParser
from glob import glob
import numpy as np
import seaborn as sns
import torch
from matplotlib import pyplot as plt
from mcBERT.utils.clustering_utils import apply_UMAP
from mcBERT.utils.patient_level_dataset import Patient_level_dataset
from mcBERT.utils.utils import get_scRNA_model, prepare_dataset, set_seeds
from omegaconf import OmegaConf
from sklearn.metrics.pairwise import cosine_similarity
from torch.utils.data import DataLoader
from tqdm import tqdm
# Fix the random seed (presumably seeds the Python/NumPy/torch RNGs; see
# mcBERT.utils.utils.set_seeds) so repeated runs are comparable.
set_seeds(42)

# Script to embed donors with mcBERT, plot a UMAP projection of the patient
# embeddings coloured by disease, and compute cosine similarity across the
# patients.
# Command-line interface: the path to the YAML inference config, loaded with
# OmegaConf. The argument is required so a missing path produces an argparse
# usage error instead of an obscure failure inside OmegaConf.load(None).
parser = ArgumentParser()
parser.add_argument(
    "--config",
    type=str,
    required=True,
    help="path to yaml config file for inference of donors",
)
args = parser.parse_args()
cfg = OmegaConf.load(args.config)
# Build the mcBERT model described by the config and load the trained weights.
# map_location="cpu" lets a checkpoint saved on any device (e.g. a GPU index
# not present on this machine) be deserialised; the model is moved to CUDA
# right afterwards, so the final placement is unchanged.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
model = get_scRNA_model(cfg)
model.load_state_dict(torch.load(cfg.model.model_ckpt, map_location="cpu"))
model.cuda()
model.eval()  # evaluation mode for inference
# Collect every file matching the configured glob and build the donor table.
files_all = glob(cfg.H5AD_FILES)
df = prepare_dataset(files_all, multiprocess=True)

n_patients = len(df["donor_id"].unique())
diseases = df["disease"].unique()
print(f"Using {n_patients} patients representing {diseases} diseases")

# Patient-level dataset in inference mode.
# NOTE(review): random_cell_stratification=0 presumably disables random cell
# subsampling - confirm in Patient_level_dataset.
dataset = Patient_level_dataset(
    df, cfg.HIGHLY_VAR_GENES_PATH, inference=True, random_cell_stratification=0
)

# Deterministic order (no shuffling) so embeddings come out in dataset order.
# The misspelled name is kept as-is because it is referenced later on.
loader_kwargs = dict(batch_size=8, shuffle=False)
dataloder = DataLoader(dataset, **loader_kwargs)
# Embed every donor with mcBERT. Per-batch outputs are collected and stacked
# at the end, so the embedding width comes from the model rather than a
# hard-coded 288, and no row of the result can be left uninitialised (as an
# unfilled slot of a pre-allocated np.empty buffer would be).
# NOTE(review): assumes the model returns a (batch, embedding_dim) tensor.
embedding_batches = []
with torch.no_grad():
    for batch in tqdm(dataloder):
        outputs = model(batch.to("cuda"))
        embedding_batches.append(np.array(outputs.cpu()))
if not embedding_batches:
    raise RuntimeError("Dataset is empty: no donors to embed.")
# np.empty used float64; keep that dtype for the similarity and UMAP steps.
patient_embeddings = np.concatenate(embedding_batches, axis=0).astype(np.float64)
# Pairwise cosine similarity between all patient embeddings. It is not used
# further in this script but is kept for downstream analysis.
cosine_sim = cosine_similarity(patient_embeddings)

# Project the embeddings to 2-D with UMAP and colour each patient by disease.
# NOTE(review): hue=df["disease"] assumes dataset item i corresponds to df
# row i - confirm in Patient_level_dataset.
X_umap_embeddings = apply_UMAP(patient_embeddings)
fig = plt.figure(figsize=(10, 10))
ax = sns.scatterplot(
    x=X_umap_embeddings[:, 0], y=X_umap_embeddings[:, 1], hue=df["disease"]
)
plt.title("UMAP plot of patients")
# plt.legend() would rebuild the legend from scratch and, in recent seaborn
# versions, drop the title seaborn assigned; move_legend relocates it intact.
sns.move_legend(ax, "upper left")
plt.xlabel("UMAP 1")
plt.ylabel("UMAP 2")
plt.tight_layout()
plt.savefig("umap_plot.png")