Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new metrics #17

Open
wants to merge 70 commits into
base: processing
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
70 commits
Select commit Hold shift + click to select a range
ed6cdad
retrain classifiers
chelseabright96 Apr 16, 2024
30e6156
Merge branch 'processing' into processing_minify_scanvi
chelseabright96 Apr 16, 2024
5394378
add ratio and add all query obs to result
chelseabright96 Apr 19, 2024
fc3d74d
Merge branch 'processing' of github.com:theislab/archmap_data into pr…
chelseabright96 Apr 19, 2024
baaf856
change webhook
chelseabright96 Apr 20, 2024
942e213
fix webhook call
chelseabright96 Apr 20, 2024
d112a72
change to inner join
chelseabright96 Apr 20, 2024
993eb3d
add webhook
chelseabright96 Apr 20, 2024
611f71c
move call to backend
chelseabright96 Apr 21, 2024
2da924e
fix error
chelseabright96 Apr 21, 2024
c00750b
add presence score and other metrics
chelseabright96 May 9, 2024
9c9f027
make max epochs standard
chelseabright96 May 9, 2024
1c94cb0
revert renaming
chelseabright96 May 9, 2024
e89a764
uncomment notify backend
chelseabright96 May 11, 2024
9484259
uncomment notify backend
chelseabright96 May 11, 2024
ed8b986
notify backend metrics
chelseabright96 May 12, 2024
13aa209
Merge branch 'processing_minify_scanvi' of github.com:theislab/archma…
chelseabright96 May 12, 2024
d2723c5
add webhook for metrics
chelseabright96 May 14, 2024
04bd466
add webhook_metrics parameter
chelseabright96 May 14, 2024
b5bdbc5
round metrics
chelseabright96 May 14, 2024
a260341
Merge branch 'processing_minify_scanvi' of github.com:theislab/archma…
chelseabright96 May 14, 2024
8934be3
fix spoli classifier error
chelseabright96 May 16, 2024
8f91621
add stress score
chelseabright96 May 18, 2024
470b04a
change stress score
chelseabright96 May 21, 2024
62a0684
fix cluster pres score
chelseabright96 May 23, 2024
0439ab6
add distributed training to scvi models and use speed improvement bra…
chelseabright96 Jun 1, 2024
4b7ba0b
change number of devices
chelseabright96 Jun 2, 2024
597012e
change to absolute path
chelseabright96 Jun 2, 2024
1d10036
Merge branch 'distributed_training' of github.com:theislab/archmap_da…
chelseabright96 Jun 2, 2024
5a654cf
change number of devices
chelseabright96 Jun 2, 2024
fdc909d
add X back for cxg
chelseabright96 Jun 3, 2024
88f403d
remove threading
chelseabright96 Jun 4, 2024
da89ded
update get keys function
chelseabright96 Jun 4, 2024
832ecaa
Merge branch 'distributed_training' into processing
chelseabright96 Jun 4, 2024
fd0a356
remove hardcoding of get_keys
chelseabright96 Jun 4, 2024
db96cec
fix unknown key issue
chelseabright96 Jun 4, 2024
5f5715c
change lr fetal brain
chelseabright96 Jun 6, 2024
4d6d209
uncomment
chelseabright96 Jun 6, 2024
0fe1fcb
increase lr for fetal brain
chelseabright96 Jun 9, 2024
de6d6d4
speed up metric calculations
chelseabright96 Jun 17, 2024
156803c
change max_epochs for fetal brain
chelseabright96 Jun 17, 2024
ef15501
fix output error
chelseabright96 Jun 18, 2024
f6669ae
move notify backend metrics
chelseabright96 Jun 18, 2024
4451c67
fix bug percent unknown
chelseabright96 Jun 19, 2024
7800d39
increase fetal brain epochs
chelseabright96 Jun 20, 2024
aeab4e4
add cell_type_key_list
chelseabright96 Jun 20, 2024
c0b40f7
fix bug
chelseabright96 Jun 20, 2024
f57f30d
fix bug
chelseabright96 Jun 20, 2024
d8a4027
add cell type key
chelseabright96 Jun 21, 2024
b50940f
fix duplicate bug
chelseabright96 Jun 21, 2024
82e8e31
add scanvi pred to ref
chelseabright96 Jun 22, 2024
3e7b529
add classifiers for cell type list
chelseabright96 Jun 22, 2024
8f969ff
add multi label classification and pretrain uncert
chelseabright96 Jun 25, 2024
79e6027
print model id
chelseabright96 Jun 25, 2024
afdd270
fix bug
chelseabright96 Jun 25, 2024
f3ab38d
fix bug
chelseabright96 Jun 25, 2024
a223478
uncomment
chelseabright96 Jun 25, 2024
0effefc
change batch key
chelseabright96 Jun 27, 2024
3c71ab6
fix batch column
chelseabright96 Jun 27, 2024
d519585
fix batch key issue
chelseabright96 Jun 27, 2024
ce1bece
add hnoca extended name
chelseabright96 Jun 27, 2024
90106d3
Merge branch 'pretrain' into processing
chelseabright96 Jun 28, 2024
4f783d9
change cell type key hnoca extended
chelseabright96 Jun 28, 2024
80206d0
fix cell type key bug
chelseabright96 Jul 1, 2024
205e6e0
fix bug
chelseabright96 Jul 1, 2024
f42d572
fix batch error
chelseabright96 Jul 1, 2024
d0b469f
fix hnoca issue
chelseabright96 Jul 1, 2024
64ef413
downsample pbmc for download
chelseabright96 Jul 4, 2024
f870cf9
add finer cell types hlca
chelseabright96 Jul 15, 2024
ee3cea2
add ann_level_5
chelseabright96 Jul 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mapping/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ RUN apt-get update && \
#pip install git+https://github.com/theislab/pertpy

RUN apt-get install -y git && \
pip install git+https://github.com/theislab/scarches.git
# pip install git+https://github.com/theislab/scarches.git
pip install git+https://github.com/theislab/scarches.git@speed_improvement_merge

ENV APP_HOME /app
WORKDIR $APP_HOME
Expand Down
64 changes: 50 additions & 14 deletions mapping/classifiers/classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from sklearn.metrics import roc_auc_score

from sklearn.preprocessing import LabelEncoder
from scarches_api.utils.metrics import percentage_unknown

from sklearn.model_selection import train_test_split

Expand All @@ -50,7 +51,7 @@ def __init__(self, classifier_xgb=False, classifier_knn=False, classifier_native
query: adata to save labels to
query_latent: adata to read .X from for label prediction
'''
def predict_labels(self, query=scanpy.AnnData(), query_latent=scanpy.AnnData(), classifier_path="/path/to/classifier", encoding_path="/path/to/encoding"):
def predict_labels(self, query=scanpy.AnnData(), query_latent=scanpy.AnnData(), classifier_path="/path/to/classifier", encoding_path="/path/to/encoding", cell_type_key=None):
le = LabelEncoder()

if self.__classifier_xgb:
Expand All @@ -60,7 +61,8 @@ def predict_labels(self, query=scanpy.AnnData(), query_latent=scanpy.AnnData(),
xgb_model = XGBClassifier()
xgb_model.load_model(classifier_path)

query.obs["prediction_xgb"] = le.inverse_transform(xgb_model.predict(query_latent.X))
query.obs[f"{cell_type_key}_prediction_xgb"] = le.inverse_transform(xgb_model.predict(query_latent.X))
prediction_label = f"{cell_type_key}_prediction_xgb"

if self.__classifier_knn:
with open(encoding_path, "rb") as file:
Expand All @@ -69,13 +71,26 @@ def predict_labels(self, query=scanpy.AnnData(), query_latent=scanpy.AnnData(),
with open(classifier_path, "rb") as file:
knn_model = pickle.load(file)

query.obs["prediction_knn"] = le.inverse_transform(knn_model.predict(query_latent.X))
query.obs[f"{cell_type_key}_prediction_knn"] = le.inverse_transform(knn_model.predict(query_latent.X))
prediction_label = f"{cell_type_key}_prediction_knn"

if self.__classifier_native is not None:
if self.__model_class == sca.models.SCANVI.__class__:
query.obs["prediction_scanvi"] = self.__classifier_native.predict(query)
if self.__model_class == sca.models.scPoli.__class__:
query.obs["prediction_scpoli"] = self.__classifier_native.classify(query, scale_uncertainties=True)
if "SCANVI" in str(self.__model_class):
query.obs[f"{cell_type_key}_prediction_scanvi"] = self.__classifier_native.predict(query)
prediction_label = f"{cell_type_key}_prediction_scanvi"
else:
output=self.__classifier_native.classify(query, scale_uncertainties=True)
query.obs[f"{cell_type_key}_prediction_scpoli"] = list(output.values())[0]["preds"]
query.obs[f"{cell_type_key}_uncertainty_scpoli"] = list(output.values())[0]["uncert"]
prediction_label = f"{cell_type_key}_prediction_scpoli"

# calculate the percentage of unknown cell types (cell types with uncertainty higher than 0.5)
percent_unknown = percentage_unknown(query, cell_type_key, prediction_label)
percent_unknown = 0.0
print(percent_unknown)

return round(percent_unknown, 2)


'''
Parameters
Expand All @@ -85,24 +100,28 @@ def predict_labels(self, query=scanpy.AnnData(), query_latent=scanpy.AnnData(),
label_key: Cell type label
classifier_directory: Output directory for classifier and evaluation files
'''
def create_classifier(self, adata, latent_rep=False, model_path="", label_key="CellType", classifier_directory="path/to/classifier_output"):
def create_classifier(self, adata, latent_rep=False, model_path="", label_key="CellType", classifier_directory="path/to/classifier_output", validate_on_query=False):
if not os.path.exists(classifier_directory):
os.makedirs(classifier_directory, exist_ok=True)


train_data = Classifiers.__get_train_data(
self,
adata=adata,
latent_rep=latent_rep,
model_path=model_path

)

X_train, X_test, y_train, y_test = Classifiers.__split_train_data(
self,
train_data=train_data,
input_adata=adata,
label_key=label_key,
classifier_directory=classifier_directory
classifier_directory=classifier_directory,
validate_on_query=validate_on_query
)


xgbc, knnc = Classifiers.__train_classifier(
self,
Expand All @@ -126,7 +145,7 @@ def create_classifier(self, adata, latent_rep=False, model_path="", label_key="C
self,
reports,
classifier_directory=classifier_directory
)
)

Classifiers.__save_eval_metrics_csv(
self,
Expand All @@ -152,7 +171,10 @@ def __get_train_data(self, adata, latent_rep=True, model_path=None):
else:
raise Exception("Choose model type 'scVI' or 'scANVI'")

latent_rep = scanpy.AnnData(model.get_latent_representation(), adata.obs)
if "latent_rep" in adata.obsm:
latent_rep = scanpy.AnnData(adata.obsm["latent_rep"], adata.obs)
else:
latent_rep = scanpy.AnnData(model.get_latent_representation(), adata.obs)

train_data = pd.DataFrame(
data = latent_rep.X,
Expand All @@ -161,22 +183,36 @@ def __get_train_data(self, adata, latent_rep=True, model_path=None):

return train_data



'''
Parameters
----------
input_adata: adata to read the labels from
'''
def __split_train_data(self, train_data, input_adata, label_key, classifier_directory):
def __split_train_data(self, train_data, input_adata, label_key, classifier_directory, validate_on_query=False):
train_data['cell_type'] = input_adata.obs[label_key]
# train_data['type'] = input_adata.obs["type"]


#Enable if at least one class has only 1 sample -> Error in stratification for validation set
train_data = train_data.groupby('cell_type').filter(lambda x: len(x) > 1)

le = LabelEncoder()
le.fit(train_data["cell_type"])
train_data['cell_type'] = le.transform(train_data["cell_type"])
train_data['cell_type'] = le.transform(train_data["cell_type"])

if validate_on_query:
X_train = train_data.drop(columns='cell_type')
X_test = train_data[train_data["type"]=="query"]
X_train = train_data[train_data["type"]=="reference"]
y_train = X_train['cell_type']
y_test = X_test['cell_type']
X_train = X_train.drop(columns='cell_type')
X_test = X_test.drop(columns='cell_type')

X_train, X_test, y_train, y_test = train_test_split(train_data.drop(columns='cell_type'), train_data['cell_type'], test_size=0.2, random_state=42, stratify=train_data['cell_type'])
else:
X_train, X_test, y_train, y_test = train_test_split(train_data.drop(columns='cell_type'), train_data['cell_type'], test_size=0.2, random_state=42, stratify=train_data['cell_type'])

#Save label encoder
with open(classifier_directory + "/classifier_encoding.pickle", "wb") as file:
Expand Down
76 changes: 76 additions & 0 deletions mapping/classifiers/classify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import scanpy as sc
import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity
from classifiers import Classifiers


def main(atlas, label, is_scpoli=False):
    """Create KNN and XGBoost classifiers for every requested label of an atlas.

    Reads ``data/{atlas}.h5ad``, wraps the stored latent embedding in a new
    AnnData that carries the original ``obs``, and calls
    ``Classifiers.create_classifier`` twice per label (KNN first, then
    XGBoost) with the output directory ``models_new/{atlas}/{atlas}_{label}``.

    Parameters
    ----------
    atlas: atlas name; selects the input file and the output directory
    label: one obs column name, or a list of obs column names
    is_scpoli: take the embedding from ``obsm["X_latent_qzm_scpoli"]``
        instead of ``obsm["X_latent_qzm"]``
    """
    adata = sc.read(f"data/{atlas}.h5ad")

    embedding_key = "X_latent_qzm_scpoli" if is_scpoli else "X_latent_qzm"
    reference_latent = sc.AnnData(adata.obsm[embedding_key], adata.obs)

    # A lone label is processed exactly like a one-element list.
    label_keys = label if isinstance(label, list) else [label]

    # (classifier_xgb, classifier_knn) flag pairs, in build order: KNN, then XGB.
    classifier_flags = ((False, True), (True, False))

    for label_key in label_keys:
        output_dir = f"models_new/{atlas}/{atlas}_{label_key}"
        for use_xgb, use_knn in classifier_flags:
            clf = Classifiers(use_xgb, use_knn, None)
            clf.create_classifier(reference_latent, True, "", label_key, output_dir)


if __name__ == "__main__":

    # Maps each atlas name to the label column (or list of label columns)
    # for which classifiers are created.

    # atlas_dict = {"fetal_brain": "subregion_class"}

    hnoca_new_labels = [
        'annot_level_1',
        'annot_level_2',
        'annot_level_3_rev2',
        'annot_level_4_rev2',
        'annot_region_rev2',
        'annot_ntt_rev2',
    ]

    atlas_dict = {
        "hnoca_new": hnoca_new_labels,
        "hnoca_ce": 'annot_level_2_ce',
    }

    # Alternative configurations kept for reference:
    # atlas_dict = {"hlca": "ann_level_4",
    #               "hlca_retrained": "ann_finest_level",
    #               "gb": "CellID",
    #               "fetal_immune":"celltype_annotation",
    #               "nsclc": "cell_type",
    #               "retina": "CellType",
    #               "pancreas_scpoli":"cell_type",
    #               "pancreas_scanvi":"cell_type",
    #               "pbmc":"cell_type_for_integration",
    #               "hypomap": "Author_CellType",
    #               }
    # atlas_dict_scpoli = {
    #     "hlca_retrained": "ann_finest_level",
    #     "pancreas_scpoli":"cell_type",
    #     "pbmc":"cell_type_for_integration",
    #     # "hnoca":"annot_level_2",
    #     # "heoca": "cell_type"
    # }

    # Every configured atlas is processed with the scPoli embedding.
    for atlas_name, label_spec in atlas_dict.items():
        main(atlas_name, label_spec, is_scpoli=True)
        print(f"successfully created classifier for {atlas_name}")
    # for atlas, label in atlas_dict_scpoli.items():
    #     main(atlas, label, is_scpoli=True)
    #     print(f"successfully created classifier for {atlas}")
116 changes: 116 additions & 0 deletions mapping/classifiers/uncert_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import scanpy as sc
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestNeighbors
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import scarches as sca
import pickle
import os


from sklearn.mixture import GaussianMixture

def train_mahalanobis(atlas, adata_ref, embedding_name, cell_type_key, pretrained=True):
    """Fit a Gaussian mixture on a reference embedding, one component per label value.

    Parameters
    ----------
    atlas: atlas name; selects the ``models_uncert/<atlas>/`` output folder
    adata_ref: reference data providing ``obs[cell_type_key]`` and
        ``obsm[embedding_name]``
    embedding_name: key of the embedding inside ``adata_ref.obsm``
    cell_type_key: obs column whose number of unique values is used as
        ``n_components``; also used as prefix of the pickle file name
    pretrained: if True, pickle the fitted model to disk and return None;
        otherwise return the fitted model without writing anything

    Returns
    -------
    The fitted ``GaussianMixture`` when ``pretrained`` is False, otherwise None.
    """
    n_components = adata_ref.obs[cell_type_key].nunique()
    print(n_components)

    embedding = adata_ref.obsm[embedding_name]

    # NOTE: the GMM fit was flagged by the author as needing a lot of RAM;
    # KMeans(n_clusters=...) was noted as a lighter alternative but is not used.
    mixture = GaussianMixture(n_components=n_components)
    mixture.fit(embedding)

    if not pretrained:
        return mixture

    # The output folder is not created here; the caller has to provide it.
    model_path = "models_uncert/" + atlas + "/" + cell_type_key + "_mahalanobis_distance.pickle"
    with open(model_path, "wb") as file:
        pickle.dump(mixture, file, pickle.HIGHEST_PROTOCOL)

def train_euclidian(atlas, adata_ref, embedding_name, pretrained=True, n_neighbors=15):
    """Build a scArches weighted-KNN trainer on a reference embedding.

    Parameters
    ----------
    atlas: atlas name; selects the ``models_uncert/<atlas>/`` output folder
    adata_ref: reference data handed to ``sca.utils.weighted_knn_trainer``
    embedding_name: embedding key handed to ``sca.utils.weighted_knn_trainer``
    pretrained: if True, pickle the trainer to
        ``models_uncert/<atlas>/euclidian_distance.pickle`` and return None;
        otherwise return the trainer without writing anything
    n_neighbors: neighbour count forwarded to the trainer

    Returns
    -------
    The trainer object when ``pretrained`` is False, otherwise None.
    """
    knn_trainer = sca.utils.weighted_knn_trainer(
        adata_ref,
        embedding_name,
        n_neighbors=n_neighbors,
    )

    if not pretrained:
        return knn_trainer

    # The output folder is not created here; the caller has to provide it.
    model_path = "models_uncert/" + atlas + "/" + "euclidian_distance.pickle"
    with open(model_path, "wb") as file:
        pickle.dump(knn_trainer, file, pickle.HIGHEST_PROTOCOL)


def main(atlas, adata_ref, cell_type_key_list=None, is_scpoli=False):
    """Train the uncertainty models for one atlas.

    Always trains the KNN model via ``train_euclidian``; additionally trains
    one model via ``train_mahalanobis`` for every entry of
    ``cell_type_key_list``.

    Parameters
    ----------
    atlas: atlas name, forwarded to the training functions
    adata_ref: reference data, forwarded to the training functions
    cell_type_key_list: a single obs column name, a list of obs column names,
        or None (the default) to train no per-label model
    is_scpoli: use the embedding ``"X_latent_qzm_scpoli"`` instead of
        ``"X_latent_qzm"``
    """
    embedding_name = "X_latent_qzm_scpoli" if is_scpoli else "X_latent_qzm"

    train_euclidian(atlas, adata_ref, embedding_name)

    # The default None used to be iterated directly and raised a TypeError
    # after the KNN model had already been trained; treat it as "no keys".
    if cell_type_key_list is None:
        cell_type_key_list = []
    elif isinstance(cell_type_key_list, str):
        # A bare string would otherwise be iterated character by character.
        cell_type_key_list = [cell_type_key_list]

    for cell_type_key in cell_type_key_list:
        print(cell_type_key)
        train_mahalanobis(atlas, adata_ref, embedding_name, cell_type_key)



if __name__ == "__main__":

    # cell_type_key_list = ['annot_level_1',
    #                       'annot_level_2',
    #                       'annot_level_3_rev2',
    #                       'annot_level_4_rev2',
    #                       'annot_region_rev2',
    #                       'annot_ntt_rev2',]

    # Atlases whose models are trained on the "X_latent_qzm" embedding;
    # maps atlas name -> label column (or list of label columns).
    atlas_dict = {
        "hlca": "scanvi_label",
        # "gb": "CellID",
        # "fetal_immune":"celltype_annotation",
        # "nsclc": "cell_type",
        # "retina": "CellType",
        # "pancreas_scanvi":"cell_type",
        # "hypomap": "Author_CellType",
    }

    # Atlases whose models are trained on the "X_latent_qzm_scpoli" embedding
    # (currently none are enabled).
    atlas_dict_scpoli = {
        # "hlca_retrained": "ann_finest_level",
        # "pancreas_scpoli":"cell_type",
        # "pbmc":"cell_type_for_integration",
        # "hnoca":"annot_level_2",
        # "heoca": "cell_type"
    }

    # scPoli atlases are processed first, then the remaining ones.
    training_runs = ((atlas_dict_scpoli, True), (atlas_dict, False))

    for atlas_labels, uses_scpoli in training_runs:
        for atlas, cell_type_key_list in atlas_labels.items():

            # The training functions write into this folder without creating it.
            directory = "models_uncert/" + atlas + "/"
            if not os.path.exists(directory):
                os.makedirs(directory)

            adata_ref = sc.read(f"data/{atlas}.h5ad")
            main(atlas, adata_ref, cell_type_key_list, is_scpoli=uses_scpoli)
4 changes: 3 additions & 1 deletion mapping/gcsfuse_run.sh
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ ls -l $MNT_DIR;
echo "the pwd is now: "
pwd;

ABSOLUTE_PATH=$(pwd)


#run the api
gunicorn --bind :$PORT --workers $WORKERS --threads $THREADS --timeout 0 --chdir ./scarches_api/ api:app
gunicorn --bind :$PORT --workers $WORKERS --threads $THREADS --timeout 0 --chdir $ABSOLUTE_PATH/scarches_api/ api:app
Loading