-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmetric.py
120 lines (105 loc) · 4.39 KB
/
metric.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#!/usr/bin/env python
r"""
Compute cell integration metrics
"""
import argparse
import functools
import pathlib
import anndata
import numpy as np
import pandas as pd
import yaml
import scglue
import scglue.metrics
def parse_args() -> argparse.Namespace:
    r"""
    Build the command line interface and parse ``sys.argv``.

    Returns
    -------
    argparse.Namespace
        Parsed arguments with attributes ``datasets``, ``latents``,
        ``cell_type``, ``domain``, ``paired`` and ``output``.
    """
    ap = argparse.ArgumentParser(
        description="Compute integration metrics for paired samples"
    )
    # One .h5ad file per modality/domain, matched one-to-one with --latents.
    ap.add_argument(
        "-d", "--datasets", dest="datasets", type=pathlib.Path, required=True,
        nargs="+", help="Path to datasets (.h5ad)"
    )
    ap.add_argument(
        "-l", "--latents", dest="latents", type=pathlib.Path, required=True,
        nargs="+", help="Path to latent embeddings (.csv)"
    )
    ap.add_argument(
        "--cell-type", dest="cell_type", type=str, default="cell_type",
        help="Column name in obs specifying cell types"
    )
    ap.add_argument(
        "--domain", dest="domain", type=str, default="domain",
        help="Column name in obs specifying domain"
    )
    # Flag only: presence sets paired=True, enabling FOSCTTM computation.
    ap.add_argument(
        "-p", "--paired", dest="paired", default=False, action="store_true",
        help="Whether the latent embeddings are paired"
    )
    ap.add_argument(
        "-o", "--output", dest="output", type=pathlib.Path, required=True,
        help="Path to output file (.yaml)"
    )
    return ap.parse_args()
def main(args: argparse.Namespace) -> None:
    r"""
    Read datasets and latent embeddings, compute integration metrics, and
    write them as a YAML mapping to ``args.output``.

    Parameters
    ----------
    args
        Parsed command line arguments (see :func:`parse_args`)

    Raises
    ------
    RuntimeError
        If the number of datasets and latents differ, or if paired mode is
        requested with a dataset count other than two.
    """
    if len(args.datasets) != len(args.latents):
        raise RuntimeError("Datasets and latents should have the same number of entries!")
    print("[1/3] Reading data...")
    datasets = [anndata.read_h5ad(item) for item in args.datasets]
    cell_types = [dataset.obs[args.cell_type].to_numpy() for dataset in datasets]
    domains = [dataset.obs[args.domain].to_numpy() for dataset in datasets]
    latents = [pd.read_csv(item, header=None, index_col=0).to_numpy() for item in args.latents]
    print("[2/3] Computing metrics...")
    # Per-dataset mask of cells whose embedding is fully finite; cells with
    # any NaN component are dropped.  (Vectorized replacement for the
    # original np.apply_along_axis lambda — identical result, one C pass.)
    masks = [~np.isnan(latent).any(axis=1) for latent in latents]
    for i, mask in enumerate(masks):
        rm_pct = 100 * (1 - mask.sum() / mask.size)
        if rm_pct:
            print(f"Ignoring {rm_pct:.1f}% cells in dataset {i} due to missing values!")
    combined_cell_type = np.concatenate([cell_type[mask] for cell_type, mask in zip(cell_types, masks)], axis=-1)
    combined_domain = np.concatenate([domain[mask] for domain, mask in zip(domains, masks)])
    combined_latent = np.concatenate([latent[mask] for latent, mask in zip(latents, masks)])
    metrics = {
        "mean_average_precision":
            scglue.metrics.mean_average_precision(combined_latent, combined_cell_type),
        "normalized_mutual_info":
            scglue.metrics.normalized_mutual_info(combined_latent, combined_cell_type),
        "avg_silhouette_width":
            scglue.metrics.avg_silhouette_width(combined_latent, combined_cell_type),
        "graph_connectivity":
            scglue.metrics.graph_connectivity(combined_latent, combined_cell_type),
        "seurat_alignment_score":
            scglue.metrics.seurat_alignment_score(combined_latent, combined_domain, random_state=0),
        "avg_silhouette_width_batch":
            scglue.metrics.avg_silhouette_width_batch(combined_latent, combined_domain, combined_cell_type),
        # TODO: neighbor_conservation would need per-dataset expression
        # matrices (dataset.X) — currently disabled.
    }
    if args.paired:
        # FOSCTTM is only defined for exactly two cell-aligned datasets.
        if len(datasets) != 2:
            raise RuntimeError("Expect exactly two datasets in paired mode!")
        # A paired cell must be NaN-free in *both* embeddings.
        mask = functools.reduce(np.logical_and, masks)
        rm_pct = 100 * (1 - mask.sum() / mask.size)
        if rm_pct:
            print(f"Ignoring {rm_pct:.1f}% cells in all datasets due to missing values!")
        metrics["foscttm"] = np.concatenate(
            scglue.metrics.foscttm(*[latent[mask] for latent in latents])
        ).mean().item()
    else:
        metrics["foscttm"] = None
    print("[3/3] Saving results...")
    args.output.parent.mkdir(parents=True, exist_ok=True)
    with args.output.open("w") as f:
        yaml.dump(metrics, f)
if __name__ == "__main__":
main(parse_args())