From ed6cdade736b59a3348a30857046e1a1ee979132 Mon Sep 17 00:00:00 2001 From: Chelsea Date: Tue, 16 Apr 2024 18:06:04 +0200 Subject: [PATCH 01/63] retrain classifiers --- mapping/classifiers/classifiers.py | 35 ++++++++++++++++++---- mapping/classifiers/classify.py | 48 ++++++++++++++++++++++++++++++ mapping/scarches_api/models.py | 11 +++++++ 3 files changed, 88 insertions(+), 6 deletions(-) create mode 100644 mapping/classifiers/classify.py diff --git a/mapping/classifiers/classifiers.py b/mapping/classifiers/classifiers.py index 1ee336f..5b3c0a2 100644 --- a/mapping/classifiers/classifiers.py +++ b/mapping/classifiers/classifiers.py @@ -85,15 +85,19 @@ def predict_labels(self, query=scanpy.AnnData(), query_latent=scanpy.AnnData(), label_key: Cell type label classifier_directory: Output directory for classifier and evaluation files ''' - def create_classifier(self, adata, latent_rep=False, model_path="", label_key="CellType", classifier_directory="path/to/classifier_output"): + def create_classifier(self, adata, latent_rep=False, model_path="", label_key="CellType", classifier_directory="path/to/classifier_output", validate_on_query=False): if not os.path.exists(classifier_directory): os.makedirs(classifier_directory, exist_ok=True) + + if query is not None: + validate_on_query=True train_data = Classifiers.__get_train_data( self, adata=adata, latent_rep=latent_rep, model_path=model_path + ) X_train, X_test, y_train, y_test = Classifiers.__split_train_data( @@ -101,8 +105,10 @@ def create_classifier(self, adata, latent_rep=False, model_path="", label_key="C train_data=train_data, input_adata=adata, label_key=label_key, - classifier_directory=classifier_directory + classifier_directory=classifier_directory, + validate_on_query=validate_on_query ) + xgbc, knnc = Classifiers.__train_classifier( self, @@ -152,7 +158,10 @@ def __get_train_data(self, adata, latent_rep=True, model_path=None): else: raise Exception("Choose model type 'scVI' or 'scANVI'") - latent_rep = scanpy.AnnData(model.get_latent_representation(), adata.obs) + if "latent_rep" in adata.obsm: + latent_rep = scanpy.AnnData(adata.obsm["latent_rep"], adata.obs) + else: + latent_rep = scanpy.AnnData(model.get_latent_representation(), adata.obs) train_data = pd.DataFrame( data = latent_rep.X, @@ -161,22 +170,36 @@ def __get_train_data(self, adata, latent_rep=True, model_path=None): return train_data + + ''' Parameters ---------- input_adata: adata to read the labels from ''' - def __split_train_data(self, train_data, input_adata, label_key, classifier_directory): + def __split_train_data(self, train_data, input_adata, label_key, classifier_directory, validate_on_query=False): train_data['cell_type'] = input_adata.obs[label_key] + train_data['type'] = input_adata.obs["type"] + #Enable if at least one class has only 1 sample -> Error in stratification for validation set train_data = train_data.groupby('cell_type').filter(lambda x: len(x) > 1) le = LabelEncoder() le.fit(train_data["cell_type"]) - train_data['cell_type'] = le.transform(train_data["cell_type"]) + train_data['cell_type'] = le.transform(train_data["cell_type"]) - X_train, X_test, y_train, y_test = train_test_split(train_data.drop(columns='cell_type'), train_data['cell_type'], test_size=0.2, random_state=42, stratify=train_data['cell_type']) + if validate_on_query: + X_train = train_data.drop(columns='cell_type') + X_test = train_data[train_data["type"]=="query"] + X_train = train_data[train_data["type"]=="reference"] + y_train = X_train['cell_type'] + y_test = 
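# --- Illustrative sketch, not part of the patch: the two split modes that
# __split_train_data now supports. With validate_on_query=True the hold-out
# set is the query partition instead of a random stratified 20% of the data.
# `df` stands in for train_data and needs "cell_type" and "type" columns.
from sklearn.model_selection import train_test_split

def split(df, validate_on_query=False):
    if validate_on_query:
        train = df[df["type"] == "reference"].drop(columns="type")
        test = df[df["type"] == "query"].drop(columns="type")
        return (train.drop(columns="cell_type"), test.drop(columns="cell_type"),
                train["cell_type"], test["cell_type"])
    # default: random 80/20 split, stratified so every label lands in both sets
    return train_test_split(df.drop(columns=["cell_type", "type"]), df["cell_type"],
                            test_size=0.2, random_state=42, stratify=df["cell_type"])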
X_test['cell_type'] + X_train = X_train.drop(columns='cell_type') + X_test = X_test.drop(columns='cell_type') + + else: + X_train, X_test, y_train, y_test = train_test_split(train_data.drop(columns='cell_type'), train_data['cell_type'], test_size=0.2, random_state=42, stratify=train_data['cell_type']) #Save label encoder with open(classifier_directory + "/classifier_encoding.pickle", "wb") as file: diff --git a/mapping/classifiers/classify.py b/mapping/classifiers/classify.py new file mode 100644 index 0000000..ee348df --- /dev/null +++ b/mapping/classifiers/classify.py @@ -0,0 +1,48 @@ +import scanpy as sc +import scarches as sca +from scarches.dataset.trvae.data_handling import remove_sparsity +from classifiers import Classifiers + + +def main(atlas, label): + + adata=sc.read(f"data/{atlas}.h5ad") + + reference_latent=sc.AnnData(adata.obsm["latent_rep"], adata.obs) + + #create knn classifier + clf = Classifiers(False, True, None) + clf.create_classifier(reference_latent, False, "", label, f"models_new/{atlas}_{label}") + + #create xgb classifier + clf = Classifiers(True, False, None) + clf.create_classifier(reference_latent, False, "", label, f"models_new/{atlas}_{label}") + + +if __name__ == "__main__": + + atlas_dict = {"hlca": "ann_level_4", + "hlca_retrained": "ann_finest_level", + "gb": "CellID", + "fetal_immune":"celltype_annotation", + "nsclc": "cell_type", + "retina": "CellType", + "pancreas_scpoli":"cell_type", + "pancreas_scanvi":"cell_type", + "pbmc":"cell_type_for_integration", + "hypomap": "Author_CellType", + } + atlas_dict_scpoli = { + "hlca_retrained": "ann_finest_level", + "pancreas_scpoli":"cell_type", + "pbmc":"cell_type_for_integration", + # "hnoca":"annot_level_2", + # "heoca": "cell_type" + } + + for atlas, label in atlas_dict.items(): + main(atlas, label) + print(f"successfully created classifier for {atlas}") + for atlas, label in atlas_dict_scpoli.items(): + main(atlas, label, is_scpoli=True) + print(f"successfully created classifier for {atlas}") \ No newline at end of file diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index c48ed25..587590c 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -9,6 +9,7 @@ import numpy as np from scipy.sparse import csr_matrix, csc_matrix from anndata import experimental +from utils import utils import utils.parameters as parameters from utils.utils import get_from_config @@ -132,6 +133,16 @@ def _acquire_data(self): Preprocess.bool_to_categorical(self._reference_adata) Preprocess.bool_to_categorical(self._query_adata) + ref_vars = self._reference_adata.var_names + query_vars = self._query_adata.var_names + + intersection = ref_vars.intersection(query_vars) + inter_len = len(intersection) + ratio = inter_len / len(ref_vars) + + # utils.notify_backend(parameters.WEBHOOK, ratio) + + #Remove later - for testing only From 5394378882a4d18e53069ef40477f803e3b47a52 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Fri, 19 Apr 2024 14:10:26 +0200 Subject: [PATCH 02/63] add ratio and add all query obs to result --- mapping/scarches_api/models.py | 26 ++++++++++++++++++------ mapping/scarches_api/utils/parameters.py | 1 + 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index 587590c..b3571ce 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -140,7 +140,7 @@ def _acquire_data(self): inter_len = len(intersection) ratio = inter_len / len(ref_vars) - # 
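# --- Illustrative sketch, not part of the patch: what the ratio above
# measures -- the fraction of reference genes also present in the query.
# A low value is an early warning that mapping quality may suffer.
# `ref` and `query` are hypothetical AnnData objects.
import anndata as ad
import numpy as np

ref = ad.AnnData(np.zeros((1, 4)))
ref.var_names = ["A", "B", "C", "D"]
query = ad.AnnData(np.zeros((1, 3)))
query.var_names = ["B", "C", "E"]

ratio = len(ref.var_names.intersection(query.var_names)) / len(ref.var_names)
print(ratio)  # 0.5 -> half of the reference genes are covered by the query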
utils.notify_backend(parameters.WEBHOOK, ratio) + utils.notify_backend(parameters.WEBHOOK_RATIO, {"ratio":ratio}) @@ -202,14 +202,22 @@ def _concat_data(self): print("concatenating in memory") #self._query_adata.X = csr_matrix(self._query_adata.X.copy()) - self._combined_adata = self._reference_adata.concatenate(self._query_adata, batch_key='bkey') + self._combined_adata = self._reference_adata.concatenate(self._query_adata, batch_key='bkey', join="outer") self._combined_adata.obsm["latent_rep"] = self.latent_full_from_mean_var #self._compute_latent_representation(explicit_representation=self._combined_adata) + query_obs=set(self._query_adata.obs.columns) + ref_obs=set(self._reference_adata.obs.columns) + inter = ref_obs.intersection(query_obs) + new_columns = query_obs.union(inter) + self._combined_adata.obs=self._combined_adata.obs[list(new_columns)] + del self._query_adata del self._reference_adata gc.collect() + + return print("concatenating on disk") @@ -229,19 +237,25 @@ def _concat_data(self): self._reference_adata.write_h5ad(temp_reference.name) self._query_adata.write_h5ad(temp_query.name) + #Concatenate on disk to save memory + experimental.concat_on_disk([temp_reference.name, temp_query.name], temp_combined.name, join="outer") + + query_obs=set(self._query_adata.obs.columns) + ref_obs=set(self._reference_adata.obs.columns) + inter = ref_obs.intersection(query_obs) + new_columns = query_obs.union(inter) + del self._reference_adata del self._query_adata gc.collect() - #Concatenate on disk to save memory - experimental.concat_on_disk([temp_reference.name, temp_query.name], temp_combined.name) - - print("successfully concatenated") #Read concatenated data back in self._combined_adata = scanpy.read_h5ad(temp_combined.name) + self._combined_adata.obs=self._combined_adata.obs[list(new_columns)] + print("read concatenated file") self._combined_adata.obsm["latent_rep"] = self.latent_full_from_mean_var diff --git a/mapping/scarches_api/utils/parameters.py b/mapping/scarches_api/utils/parameters.py index b6e6f10..03e96c4 100644 --- a/mapping/scarches_api/utils/parameters.py +++ b/mapping/scarches_api/utils/parameters.py @@ -66,6 +66,7 @@ SCANVI_PREDICT_CELLTYPES = 'predict' # sets the key for the webhook to call after the computation WEBHOOK = 'webhook' +WEBHOOK_RATIO = 'webhook_ratio' # sets the path/s3 key of the pretrained model PRETRAINED_MODEL_PATH = 'model_path' # set scpoli model attr key From baaf856d0cf4bfcf9023abb70ff96ac61eff08cf Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Sat, 20 Apr 2024 07:44:44 +0200 Subject: [PATCH 03/63] change webhook --- mapping/scarches_api/models.py | 3 ++- mapping/scarches_api/utils/parameters.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index b3571ce..381b115 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -140,7 +140,8 @@ def _acquire_data(self): inter_len = len(intersection) ratio = inter_len / len(ref_vars) - utils.notify_backend(parameters.WEBHOOK_RATIO, {"ratio":ratio}) + # utils.notify_backend(parameters.WEBHOOK_RATIO, {"ratio":ratio}) + utils.notify_backend(parameters.WEBHOOK, {"ratio":ratio}) diff --git a/mapping/scarches_api/utils/parameters.py b/mapping/scarches_api/utils/parameters.py index 03e96c4..617a031 100644 --- a/mapping/scarches_api/utils/parameters.py +++ b/mapping/scarches_api/utils/parameters.py @@ -66,7 +66,7 @@ SCANVI_PREDICT_CELLTYPES = 'predict' # sets the key for the webhook to call after 
the computation WEBHOOK = 'webhook' -WEBHOOK_RATIO = 'webhook_ratio' +# WEBHOOK_RATIO = 'webhook_ratio' # sets the path/s3 key of the pretrained model PRETRAINED_MODEL_PATH = 'model_path' # set scpoli model attr key From 942e213dae9bbd7737fd5cea7a075e22f25178a4 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Sat, 20 Apr 2024 08:42:15 +0200 Subject: [PATCH 04/63] fix webhook call --- mapping/scarches_api/models.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index 381b115..9b802b1 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -11,7 +11,7 @@ from anndata import experimental from utils import utils -import utils.parameters as parameters +from utils import parameters from utils.utils import get_from_config from utils.utils import fetch_file_from_s3 from utils.utils import read_h5ad_file_from_s3 @@ -37,6 +37,7 @@ def __init__(self, configuration) -> None: self._scpoli_var_names = get_from_config(configuration=configuration, key=parameters.SCPOLI_VAR_NAMES) self._reference_adata_path = get_from_config(configuration=configuration, key=parameters.REFERENCE_DATA_PATH) self._query_adata_path = get_from_config(configuration=configuration, key=parameters.QUERY_DATA_PATH) + self._webhook = utils.get_from_config(configuration, parameters.WEBHOOK) # self._use_gpu = get_from_config(configuration=configuration, key=parameters.USE_GPU) #Has to be empty for the load_query_data function to work properly (looking for "model.pt") @@ -141,7 +142,7 @@ def _acquire_data(self): ratio = inter_len / len(ref_vars) # utils.notify_backend(parameters.WEBHOOK_RATIO, {"ratio":ratio}) - utils.notify_backend(parameters.WEBHOOK, {"ratio":ratio}) + utils.notify_backend(self._webhook, {"ratio":ratio}) From d112a72a99a431f4b72708076de95292043cce60 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Sat, 20 Apr 2024 09:13:54 +0200 Subject: [PATCH 05/63] change to inner join --- mapping/scarches_api/models.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index 9b802b1..f0169cf 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -204,15 +204,15 @@ def _concat_data(self): print("concatenating in memory") #self._query_adata.X = csr_matrix(self._query_adata.X.copy()) - self._combined_adata = self._reference_adata.concatenate(self._query_adata, batch_key='bkey', join="outer") + self._combined_adata = self._reference_adata.concatenate(self._query_adata, batch_key='bkey', join="inner") self._combined_adata.obsm["latent_rep"] = self.latent_full_from_mean_var #self._compute_latent_representation(explicit_representation=self._combined_adata) - query_obs=set(self._query_adata.obs.columns) - ref_obs=set(self._reference_adata.obs.columns) - inter = ref_obs.intersection(query_obs) - new_columns = query_obs.union(inter) - self._combined_adata.obs=self._combined_adata.obs[list(new_columns)] + # query_obs=set(self._query_adata.obs.columns) + # ref_obs=set(self._reference_adata.obs.columns) + # inter = ref_obs.intersection(query_obs) + # new_columns = query_obs.union(inter) + # self._combined_adata.obs=self._combined_adata.obs[list(new_columns)] del self._query_adata del self._reference_adata @@ -240,12 +240,12 @@ def _concat_data(self): self._query_adata.write_h5ad(temp_query.name) #Concatenate on disk to save memory - experimental.concat_on_disk([temp_reference.name, 
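# --- Illustrative sketch, not part of the patch: anndata's experimental
# on-disk concatenation, which the code above uses so that neither object
# has to be held fully in memory. Assumes anndata >= 0.9; per the patch
# comment, concat_on_disk only handles CSR-encoded X, hence the in-memory
# branch for CSC inputs. File names here are placeholders.
from anndata import experimental
import scanpy as sc

experimental.concat_on_disk(
    ["reference.h5ad", "query.h5ad"],   # inputs are streamed chunk-wise
    "combined.h5ad",                    # result is written incrementally
)
combined = sc.read_h5ad("combined.h5ad")  # only the final object is loaded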
temp_query.name], temp_combined.name, join="outer") + experimental.concat_on_disk([temp_reference.name, temp_query.name], temp_combined.name, join="inner") - query_obs=set(self._query_adata.obs.columns) - ref_obs=set(self._reference_adata.obs.columns) - inter = ref_obs.intersection(query_obs) - new_columns = query_obs.union(inter) + # query_obs=set(self._query_adata.obs.columns) + # ref_obs=set(self._reference_adata.obs.columns) + # inter = ref_obs.intersection(query_obs) + # new_columns = query_obs.union(inter) del self._reference_adata del self._query_adata @@ -256,7 +256,7 @@ def _concat_data(self): #Read concatenated data back in self._combined_adata = scanpy.read_h5ad(temp_combined.name) - self._combined_adata.obs=self._combined_adata.obs[list(new_columns)] + # self._combined_adata.obs=self._combined_adata.obs[list(new_columns)] print("read concatenated file") From 993eb3d6c092afc4533bf9bb0369e9595d3c03a3 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Sat, 20 Apr 2024 17:51:19 +0200 Subject: [PATCH 06/63] add webhook --- mapping/scarches_api/models.py | 6 +++--- mapping/scarches_api/utils/parameters.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index f0169cf..5c6f658 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -37,7 +37,7 @@ def __init__(self, configuration) -> None: self._scpoli_var_names = get_from_config(configuration=configuration, key=parameters.SCPOLI_VAR_NAMES) self._reference_adata_path = get_from_config(configuration=configuration, key=parameters.REFERENCE_DATA_PATH) self._query_adata_path = get_from_config(configuration=configuration, key=parameters.QUERY_DATA_PATH) - self._webhook = utils.get_from_config(configuration, parameters.WEBHOOK) + self._webhook = utils.get_from_config(configuration, parameters.WEBHOOK_RATIO) # self._use_gpu = get_from_config(configuration=configuration, key=parameters.USE_GPU) #Has to be empty for the load_query_data function to work properly (looking for "model.pt") @@ -204,7 +204,7 @@ def _concat_data(self): print("concatenating in memory") #self._query_adata.X = csr_matrix(self._query_adata.X.copy()) - self._combined_adata = self._reference_adata.concatenate(self._query_adata, batch_key='bkey', join="inner") + self._combined_adata = self._reference_adata.concatenate(self._query_adata, batch_key='bkey') self._combined_adata.obsm["latent_rep"] = self.latent_full_from_mean_var #self._compute_latent_representation(explicit_representation=self._combined_adata) @@ -240,7 +240,7 @@ def _concat_data(self): self._query_adata.write_h5ad(temp_query.name) #Concatenate on disk to save memory - experimental.concat_on_disk([temp_reference.name, temp_query.name], temp_combined.name, join="inner") + experimental.concat_on_disk([temp_reference.name, temp_query.name], temp_combined.name) # query_obs=set(self._query_adata.obs.columns) # ref_obs=set(self._reference_adata.obs.columns) diff --git a/mapping/scarches_api/utils/parameters.py b/mapping/scarches_api/utils/parameters.py index 617a031..03e96c4 100644 --- a/mapping/scarches_api/utils/parameters.py +++ b/mapping/scarches_api/utils/parameters.py @@ -66,7 +66,7 @@ SCANVI_PREDICT_CELLTYPES = 'predict' # sets the key for the webhook to call after the computation WEBHOOK = 'webhook' -# WEBHOOK_RATIO = 'webhook_ratio' +WEBHOOK_RATIO = 'webhook_ratio' # sets the path/s3 key of the pretrained model PRETRAINED_MODEL_PATH = 'model_path' # set scpoli model attr key From 
611f71c88eddc11513eb6b85a6bb512e4bb8844e Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Sun, 21 Apr 2024 14:58:54 +0200 Subject: [PATCH 07/63] move call to backend --- mapping/scarches_api/init.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mapping/scarches_api/init.py b/mapping/scarches_api/init.py index 53c1763..b583ed3 100644 --- a/mapping/scarches_api/init.py +++ b/mapping/scarches_api/init.py @@ -139,6 +139,12 @@ def query(user_config): mapping = ScPoli(configuration=configuration) mapping.run() + utils.notify_backend(mapping._webhook, {"ratio":mapping.ratio}) + + # atlas_name = utils.get_from_config(configuration, parameters.ATLAS) + + # sc.AnnData(mapping._combined_adata.obsm["latent_rep"], mapping._combined_adata.obs).write(f"results/{atlas_name}.h5ad") + if get_from_config(configuration, parameters.WEBHOOK) is not None and len( get_from_config(configuration, parameters.WEBHOOK)) > 0: utils.notify_backend(get_from_config(configuration, parameters.WEBHOOK), configuration) From 2da924e88e72edfce1ba0448180d246179f4b5e3 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Sun, 21 Apr 2024 16:47:29 +0200 Subject: [PATCH 08/63] fix error --- mapping/scarches_api/init.py | 1 - mapping/scarches_api/models.py | 1 - 2 files changed, 2 deletions(-) diff --git a/mapping/scarches_api/init.py b/mapping/scarches_api/init.py index b583ed3..ba68475 100644 --- a/mapping/scarches_api/init.py +++ b/mapping/scarches_api/init.py @@ -139,7 +139,6 @@ def query(user_config): mapping = ScPoli(configuration=configuration) mapping.run() - utils.notify_backend(mapping._webhook, {"ratio":mapping.ratio}) # atlas_name = utils.get_from_config(configuration, parameters.ATLAS) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index 5c6f658..8926b5f 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -141,7 +141,6 @@ def _acquire_data(self): inter_len = len(intersection) ratio = inter_len / len(ref_vars) - # utils.notify_backend(parameters.WEBHOOK_RATIO, {"ratio":ratio}) utils.notify_backend(self._webhook, {"ratio":ratio}) From c00750b2b477e25670fc497506bf53f36ef7136d Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Thu, 9 May 2024 07:59:50 +0200 Subject: [PATCH 09/63] add presence score and other metrics --- mapping/process/processing.py | 7 +- mapping/scarches_api/models.py | 56 +++- mapping/scarches_api/utils/metrics.py | 375 ++++++++++++++++++++++++++ mapping/scarches_api/utils/utils.py | 2 + 4 files changed, 426 insertions(+), 14 deletions(-) create mode 100644 mapping/scarches_api/utils/metrics.py diff --git a/mapping/process/processing.py b/mapping/process/processing.py index 85b29c7..7171135 100644 --- a/mapping/process/processing.py +++ b/mapping/process/processing.py @@ -337,6 +337,10 @@ def get_keys(atlas, target_adata): elif atlas == "heoca": cell_type_key = "cell_type" batch_key = "sample_id" + elif atlas == "fetal_brain": + cell_type_key = "subregion_class" + batch_key = "batch" + @@ -640,8 +644,7 @@ def __prepare_output(latent_adata: sc.AnnData, combined_adata: sc.AnnData, confi sc.tl.leiden(combined_adata) print("leiden") sc.tl.umap(combined_adata) - print("umap") - + print("umap") def __output_csv(obs_to_drop: list, latent_adata: sc.AnnData, combined_adata: sc.AnnData, config, predict_scanvi): Postprocess.__prepare_output(latent_adata, combined_adata, config) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index 8926b5f..f9d466b 100644 --- a/mapping/scarches_api/models.py +++ 
b/mapping/scarches_api/models.py @@ -7,11 +7,13 @@ import torch import gc import numpy as np +import scipy from scipy.sparse import csr_matrix, csc_matrix from anndata import experimental from utils import utils from utils import parameters +from utils.metrics import estimate_presence_score, cluster_preservation_score, build_mutual_nn, percent_query_with_anchor from utils.utils import get_from_config from utils.utils import fetch_file_from_s3 from utils.utils import read_h5ad_file_from_s3 @@ -108,6 +110,22 @@ def _map_query(self): self._query_adata.X = all_zeros.copy() + # Calculate presence score + + presence_score=estimate_presence_score( + self._reference_adata, + self._query_adata) + + self.presence_score = np.concatenate((presence_score["max"],[np.nan]*len(self._query_adata))) + + clust_pres_score=cluster_preservation_score(self._query_adata) + print(f"clust_pres_score: {clust_pres_score}") + + query_with_anchor=percent_query_with_anchor(self._reference_adata, self._query_adata) + print(f"query_with_anchor: {query_with_anchor}") + + # utils.notify_backend(self._webhook, {"clust_pres_score":clust_pres_score, "query_with_anchor":query_with_anchor}) + def _acquire_data(self): #Download query and reference from GCP @@ -141,10 +159,11 @@ def _acquire_data(self): inter_len = len(intersection) ratio = inter_len / len(ref_vars) - utils.notify_backend(self._webhook, {"ratio":ratio}) - + # utils.notify_backend(self._webhook, {"ratio":ratio}) - + self._query_adata.obs_names_make_unique() + self._query_adata.var_names_make_unique() + #Remove later - for testing only # self._reference_adata = scanpy.pp.subsample(self._reference_adata, 0.1, copy=True) @@ -198,20 +217,25 @@ def _concat_data(self): self._reference_adata.obs["query"]=["0"]*self._reference_adata.n_obs #Added because concat_on_disk only allows csr concat - if self._query_adata.X.format == "csc" or self._reference_adata.X.format == "csc": + + + if scipy.sparse.issparse(self._query_adata.X) and (self._query_adata.X.format == "csc" or self._reference_adata.X.format == "csc"): print("concatenating in memory") #self._query_adata.X = csr_matrix(self._query_adata.X.copy()) - self._combined_adata = self._reference_adata.concatenate(self._query_adata, batch_key='bkey') - self._combined_adata.obsm["latent_rep"] = self.latent_full_from_mean_var - #self._compute_latent_representation(explicit_representation=self._combined_adata) + self._combined_adata = self._reference_adata.concatenate(self._query_adata, batch_key='bkey',join="outer") - # query_obs=set(self._query_adata.obs.columns) - # ref_obs=set(self._reference_adata.obs.columns) - # inter = ref_obs.intersection(query_obs) - # new_columns = query_obs.union(inter) - # self._combined_adata.obs=self._combined_adata.obs[list(new_columns)] + query_obs=set(self._query_adata.obs.columns) + ref_obs=set(self._reference_adata.obs.columns) + inter = ref_obs.intersection(query_obs) + new_columns = query_obs.union(inter) + self._combined_adata.obs=self._combined_adata.obs[list(new_columns)] + + self._combined_adata.obsm["latent_rep"] = self.latent_full_from_mean_var + self._combined_adata.obs["presence_score"] = self.presence_score + + del self._query_adata del self._reference_adata @@ -260,6 +284,7 @@ def _concat_data(self): print("read concatenated file") self._combined_adata.obsm["latent_rep"] = self.latent_full_from_mean_var + self._combined_adata.obs["presence_score"] = self.presence_score print("added latent rep to adata") @@ -487,6 +512,13 @@ def _map_query(self): self._query_adata.X = 
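# --- Illustrative sketch, not part of the patch: the presence score is
# defined per *reference* cell, so it is padded with NaN for the query rows
# before being attached to the concatenated object (reference rows come
# first in the concatenation). Toy sizes below.
import numpy as np

n_ref, n_query = 5, 3
score_ref = np.random.rand(n_ref)                 # one score per reference cell
padded = np.concatenate((score_ref, [np.nan] * n_query))
assert padded.shape[0] == n_ref + n_query         # aligns with combined .obs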
all_zeros.copy() + presence_score = estimate_presence_score( + self._reference_adata, + self._query_adata) + + self.presence_score = np.concatenate((presence_score["max"],[np.nan]*len(self._query_adata))) + + def _compute_latent_representation(self, explicit_representation, mean=False): explicit_representation.obsm["latent_rep"] = self._model.get_latent(explicit_representation, mean=mean) diff --git a/mapping/scarches_api/utils/metrics.py b/mapping/scarches_api/utils/metrics.py new file mode 100644 index 0000000..1e4c267 --- /dev/null +++ b/mapping/scarches_api/utils/metrics.py @@ -0,0 +1,375 @@ +import scanpy as sc +import numpy as np +import pandas as pd +import anndata as ad +import torch +from pynndescent import NNDescent + +from scipy import sparse +from typing import Optional, Union, Mapping, Literal +import warnings +import sys +import os +import importlib.util +import argparse +from scipy.sparse import csr_matrix + +warnings.filterwarnings("ignore") + + + +def nn2adj(nn, n1=None, n2=None): + if n1 is None: + n1 = nn[1].shape[0] + if n2 is None: + n2 = np.max(nn[1].flatten()) + + df = pd.DataFrame( + { + "i": np.repeat(range(nn[0].shape[0]), nn[0].shape[1]), + "j": nn[0].flatten(), + "x": nn[1].flatten(), + } + ) + adj = sparse.csr_matrix( + (np.repeat(1, df.shape[0]), (df["i"], df["j"])), shape=(n1, n2) + ) + return adj + + +def build_nn( + ref, + query=None, + k=100, + weight: Literal["unweighted", "dist", "gaussian_kernel"] = "unweighted", + sigma=None, +): + if query is None: + query = ref + + if torch.cuda.is_available() and importlib.util.find_spec("cuml"): + print("GPU detected and cuml installed. Use cuML for neighborhood estimation.") + from cuml.neighbors import NearestNeighbors + + model = NearestNeighbors(n_neighbors=k) + model.fit(ref) + knn = model.kneighbors(query) + + else: + print( + "Failed calling cuML. Falling back to neighborhood estimation using CPU with pynndescent." 
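# --- Illustrative sketch, not part of the patch: the core of nn2adj -- turn
# an (n_query, k) array of neighbor indices into a binary sparse adjacency
# matrix of shape (n_query, n_ref). Toy numbers below.
import numpy as np
from scipy import sparse

indices = np.array([[0, 2],            # query 0 has ref neighbors {0, 2}
                    [1, 2]])           # query 1 has ref neighbors {1, 2}
rows = np.repeat(np.arange(indices.shape[0]), indices.shape[1])
cols = indices.flatten()
adj = sparse.csr_matrix((np.ones(rows.size), (rows, cols)), shape=(2, 3))
print(adj.toarray())  # [[1. 0. 1.]
                      #  [0. 1. 1.]]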
+ ) + index = NNDescent(ref, n_neighbors=k) + knn = index.query(query, k=k) + + adj = nn2adj(knn, n1=query.shape[0], n2=ref.shape[0]) + return adj + + +def build_mutual_nn(dat1, dat2=None, k1=15, k2=None): + if dat2 is None: + dat2 = dat1 + if k2 is None: + k2 = k1 + + adj_12 = build_nn(dat1, dat2, k=k2) + adj_21 = build_nn(dat2, dat1, k=k1) + + adj_mnn = adj_12.multiply(adj_21.T) + return adj_mnn + +def percent_query_with_anchor(ref_adata, query_adata): + ref = ref_adata.obsm["latent_rep"] + query = query_adata.obsm["latent_rep"] + adj_mnn=build_mutual_nn(ref,query) + has_anchor=adj_mnn.sum(0)>0 #all query cells that have an anchor (output dim: no query cells) + percentage = (has_anchor.sum()/adj_mnn.shape[1])*100 + return percentage + + +def get_transition_prob_mat(dat, k=50, symm=True): + adj = build_nn(dat, k=k) + if symm: + adj = adj + adj.transpose() + prob = sparse.diags(1 / np.array(adj.sum(1)).flatten()) @ adj.transpose() + return prob + + +def random_walk_with_restart(init, transition_prob, alpha=0.5, num_rounds=100): + init = np.array(init).flatten() + heat = init[:, None] + for i in range(num_rounds): + heat = init[:, None] * alpha + (1 - alpha) * ( + transition_prob.transpose() @ heat + ) + return heat + + +def get_wknn( + ref, + query, + ref2=None, + k: int = 100, + query2ref: bool = True, + ref2query: bool = True, + weighting_scheme: Literal[ + "n", "top_n", "jaccard", "jaccard_square", "gaussian", "dist" + ] = "jaccard_square", + top_n: Optional[int] = None, + return_adjs: bool = False, +): + """ + Compute the weighted k-nearest neighbors graph between the reference and query datasets + + Parameters + ---------- + ref : np.ndarray + The reference representation to build ref-query neighbor graph + query : np.ndarray + The query representation to build ref-query neighbor graph + ref2 : np.ndarray + The reference representation to build ref-ref neighbor graph + k : int + Number of neighbors per cell + query2ref : bool + Consider query-to-ref neighbors + ref2query : bool + Consider ref-to-query neighbors + weighting_scheme : str + How to weight edges in the ref-query neighbor graph + top_n : int + The number of top neighbors to consider + return_adjs : bool + Whether to return the adjacency matrices of ref-query, query-ref, and ref-ref for weighting + """ + adj_q2r = build_nn(ref=ref, query=query, k=k) + + adj_r2q = None + if ref2query: + adj_r2q = build_nn(ref=query, query=ref, k=k) + + if query2ref and not ref2query: + adj_knn = adj_q2r.T + elif ref2query and not query2ref: + adj_knn = adj_r2q + elif ref2query and query2ref: + adj_knn = ((adj_r2q + adj_q2r.T) > 0) + 0 + else: + warnings.warn( + "At least one of query2ref and ref2query should be True. Reset to default with both being True." 
+ ) + adj_knn = ((adj_r2q + adj_q2r.T) > 0) + 0 + + if ref2 is None: + ref2 = ref + adj_ref = build_nn(ref=ref2, k=k) + + + num_shared_neighbors = adj_q2r @ adj_ref + num_shared_neighbors_nn = num_shared_neighbors.multiply(adj_knn.T) # only keep weights if q and r are both nearest neigbours of eachother + + del num_shared_neighbors + + wknn = num_shared_neighbors_nn.copy() + if weighting_scheme == "top_n": + if top_n is None: + top_n = k // 4 if k > 4 else 1 + wknn = (wknn > top_n) * 1 + elif weighting_scheme == "jaccard": + wknn.data = wknn.data / (k + k - wknn.data) + elif weighting_scheme == "jaccard_square": + wknn.data = (wknn.data / (k + k - wknn.data)) ** 2 + + if return_adjs: + adjs = {"q2r": adj_q2r, "r2q": adj_r2q, "knn": adj_knn, "r2r": adj_ref} + return (wknn, adjs) + else: + return wknn + +#estimate presence of reference cell in query based on how many neighbours a query and ref cell shares +# uncertainty: How far is a query cell from the nearest ref cells. +def estimate_presence_score( + ref_adata, + query_adata, + wknn=None, + use_rep_ref_wknn="latent_rep", + use_rep_query_wknn="latent_rep", + k_wknn=15, + query2ref_wknn=True, + ref2query_wknn=False, + weighting_scheme_wknn="jaccard_square", + ref_trans_prop=None, + use_rep_ref_trans_prop=None, + k_ref_trans_prop=50, + symm_ref_trans_prop=True, + split_by=None, + do_random_walk=True, + alpha_random_walk=0.1, + num_rounds_random_walk=100, + log=True, +): + if wknn is None: + ref = ref_adata.obsm[use_rep_ref_wknn] + query = query_adata.obsm[use_rep_query_wknn] + wknn = get_wknn( + ref=ref, + query=query, + k=k_wknn, + query2ref=query2ref_wknn, + ref2query=ref2query_wknn, + weighting_scheme=weighting_scheme_wknn, + ) + + if ref_trans_prop is None and do_random_walk: + if use_rep_ref_trans_prop is None: + use_rep_ref_trans_prop = use_rep_ref_wknn + ref = ref_adata.obsm[use_rep_ref_trans_prop] + ref_trans_prop = get_transition_prob_mat(ref, k=k_ref_trans_prop) + + if split_by and split_by in query_adata.obs.columns: + presence_split = [ + np.array(wknn[query_adata.obs[split_by] == x, :].sum(axis=0)).flatten() + for x in query_adata.obs[split_by].unique() + ] + else: + presence_split = [np.array(wknn.sum(axis=0)).flatten()] + if do_random_walk: + presence_split_sm = [ + random_walk_with_restart( + init=x, + transition_prob=ref_trans_prop, + alpha=alpha_random_walk, + num_rounds=num_rounds_random_walk, + ) + for x in presence_split + ] + else: + presence_split_sm = [x[:, None] for x in presence_split] + + columns = ( + query_adata.obs[split_by].unique() + if split_by and split_by in query_adata.obs.columns + else ["query"] + ) + if len(columns) > 1: + df_presence = pd.DataFrame( + np.concatenate(presence_split_sm, axis=1), + columns=columns, + index=ref_adata.obs_names, + ) + else: + df_presence = pd.DataFrame({columns[0]: presence_split_sm[0].flatten()}).set_index( + ref_adata.obs_names + ) + + if log: + df_presence = df_presence.apply(lambda x: np.log1p(x), axis=0) + df_presence_norm = df_presence.apply( + lambda x: np.clip(x, np.percentile(x, 1), np.percentile(x, 99)), axis=0 + ).apply(lambda x: (x - np.min(x)) / (np.max(x) - np.min(x)), axis=0) + max_presence = df_presence_norm.max(1) + + + return { + "max": max_presence, + "per_group": df_presence_norm, + "ref_trans_prop": ref_trans_prop, + } + + +# def transfer_labels(ref_adata, query_adata, wknn, label_key="celltype"): +# scores = pd.DataFrame( +# wknn @ pd.get_dummies(ref_adata.obs[label_key]), +# columns=pd.get_dummies(ref_adata.obs[label_key]).columns, +# 
index=query_adata.obs_names, +# ) +# scores["best_label"] = scores.idxmax(1) +# scores["best_score"] = scores.max(1) +# return scores + + +# cluster preservation score + +import numpy as np +import scanpy as sc +from sklearn.neighbors import NearestNeighbors +from scipy.special import rel_entr +import pandas as pd + +def cluster_preservation_score(adata, ds_amount=5000, type='standard'): + """ + Calculate the cluster preservation score for an AnnData object after preprocessing. + Parameters: + - adata: AnnData object containing single-cell data. + - ds_amount: Maximum number of cells to include. + Returns: + - score: Cluster preservation score. + """ + dims = min(50, adata.uns.get('Azimuth_map_ndims', 50)) + + if type == 'standard': + # Following the standard preprocessing workflow + # sc.pp.scale(adata, zero_center=True) + if adata.n_obs > ds_amount: + adata = adata[np.random.choice(adata.obs_names, ds_amount, replace=False), :] + sc.tl.pca(adata, svd_solver='arpack', n_comps=dims) + sc.pp.neighbors(adata, n_neighbors=15, n_pcs=dims, method='umap', metric='euclidean') + sc.tl.leiden(adata, resolution=0.6) + # sc.pp.neighbors(adata, use_rep='latent_rep', key_added="integrated_neighbors") + + elif type == 'bridge': + # Bridge-specific preprocessing + sc.pp.scale(adata, zero_center=True) + if adata.n_obs > ds_amount: + adata = adata[np.random.choice(adata.obs_names, ds_amount, replace=False), :] + sc.pp.svd(adata, n_comps=dims) + sc.pp.neighbors(adata, n_neighbors=15, use_rep='X_svd', method='umap', metric='euclidean') + sc.tl.leiden(adata, resolution=0.6) + # sc.pp.neighbors(adata, use_rep='latent_rep', key_added="integrated_neighbors") + + else: + print("Incorrect type: Must be either 'standard' or 'bridge'") + return None + + # Entropy calculations for neighborhood preservation + nn_orig = NearestNeighbors(n_neighbors=15) + nn_orig.fit(adata.obsm['X_pca'] if type == 'standard' else adata.obsm['X_svd']) + _, orig_indices = nn_orig.kneighbors() + + nn_integrated = NearestNeighbors(n_neighbors=15) + nn_integrated.fit(adata.obsm['latent_rep']) + _, integrated_indices = nn_integrated.kneighbors() + + # Calculate entropy for each set of neighbors + def entropy_of_labels(indices): + labels = adata.obs['leiden'][indices].to_numpy() + _, counts = np.unique(labels, return_counts=True) + return rel_entr(counts, np.full_like(counts, fill_value=1/len(counts))).sum() + + orig_ent = np.array([entropy_of_labels(idx) for idx in orig_indices]) + integrated_ent = np.array([entropy_of_labels(idx) for idx in integrated_indices]) + + # Calculate the cluster preservation statistic + ids = adata.obs['leiden'].to_numpy() + orig_means = pd.Series(orig_ent).groupby(ids).mean() + integrated_means = pd.Series(integrated_ent).groupby(ids).mean() + stat = np.median(orig_means - integrated_means) + + if stat <= 0: + return 5.00 + else: + stat = -1 * np.log2(stat) + stat = np.clip(stat, 0.00, 5.00) + return stat + + def percentage_unknown(query, prediction_label, uncertainty_threshold=0.5): + query.obs[f"{prediction_label}_filtered_by_uncert>0.5"] = query.obs[ + prediction_label + ].mask( + query.obs["uncertainty_mahalanobis"] > uncertainty_threshold, + "Unknown", + ) + + number_unknown = (query.obs["uncertainty_mahalanobis"] > uncertainty_threshold).sum() + + return number_unknown/len(query)*100 diff --git a/mapping/scarches_api/utils/utils.py b/mapping/scarches_api/utils/utils.py index fe443f2..c48d7bf 100644 --- a/mapping/scarches_api/utils/utils.py +++ b/mapping/scarches_api/utils/utils.py @@ -588,6 +588,8 @@ def 
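# --- Illustrative sketch, not part of the patch: how the cluster
# preservation statistic above becomes a 0-5 score. A small median gap
# between neighborhood-label entropies in the query's own PCA space and in
# the mapped latent space means clusters were preserved; -log2 turns a
# small gap into a high score.
import numpy as np

def to_score(median_entropy_gap):
    if median_entropy_gap <= 0:
        return 5.00                   # latent neighborhoods at least as pure
    return float(np.clip(-np.log2(median_entropy_gap), 0.00, 5.00))

print(to_score(-0.1), to_score(0.5), to_score(0.03125))  # 5.0 1.0 5.0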
translate_atlas_to_directory(configuration): return "hnoca" elif atlas == "HEOCA": return "heoca" + elif atlas == "Fetal Brain": + return "fetal_brain" def set_keys(configuration): From 9c9f0279c71d0e797ac99445ba90cc3d48c21e46 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Thu, 9 May 2024 10:16:20 +0200 Subject: [PATCH 10/63] make max epochs standard --- mapping/scarches_api/models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index f9d466b..575189d 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -393,7 +393,7 @@ def _map_query(self): ) self._model = model - self._max_epochs = get_from_config(configuration=self._configuration, key=parameters.SCVI_QUERY_MAX_EPOCHS) + self._max_epochs = get_from_config(configuration=self._configuration, key=parameters.MAX_EPOCHS_QUERY) super()._map_query() @@ -433,7 +433,7 @@ def _map_query(self, supervised=False): model._labeled_indices = [] self._model = model - self._max_epochs = get_from_config(configuration=self._configuration, key=parameters.SCANVI_MAX_EPOCHS_QUERY) + self._max_epochs = get_from_config(configuration=self._configuration, key=parameters.MAX_EPOCHS_QUERY) super()._map_query() @@ -463,7 +463,7 @@ def _map_query(self): self._query_adata = model.adata self._model = model - self._max_epochs = get_from_config(configuration=self._configuration, key=parameters.SCPOLI_MAX_EPOCHS) + self._max_epochs = get_from_config(configuration=self._configuration, key=parameters.MAX_EPOCHS_QUERY) model.train( n_epochs=self._max_epochs, From 1c94cb0c44f544a6cab503ff43f04fd9a22d871c Mon Sep 17 00:00:00 2001 From: chelseabright96 <82758782+chelseabright96@users.noreply.github.com> Date: Thu, 9 May 2024 12:57:06 +0200 Subject: [PATCH 11/63] revert renaming --- mapping/scarches_api/models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index 575189d..f9d466b 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -393,7 +393,7 @@ def _map_query(self): ) self._model = model - self._max_epochs = get_from_config(configuration=self._configuration, key=parameters.MAX_EPOCHS_QUERY) + self._max_epochs = get_from_config(configuration=self._configuration, key=parameters.SCVI_QUERY_MAX_EPOCHS) super()._map_query() @@ -433,7 +433,7 @@ def _map_query(self, supervised=False): model._labeled_indices = [] self._model = model - self._max_epochs = get_from_config(configuration=self._configuration, key=parameters.MAX_EPOCHS_QUERY) + self._max_epochs = get_from_config(configuration=self._configuration, key=parameters.SCANVI_MAX_EPOCHS_QUERY) super()._map_query() @@ -463,7 +463,7 @@ def _map_query(self): self._query_adata = model.adata self._model = model - self._max_epochs = get_from_config(configuration=self._configuration, key=parameters.MAX_EPOCHS_QUERY) + self._max_epochs = get_from_config(configuration=self._configuration, key=parameters.SCPOLI_MAX_EPOCHS) model.train( n_epochs=self._max_epochs, From e89a7644f442431818190d159a21374a54fd6cd1 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Sat, 11 May 2024 16:40:08 +0200 Subject: [PATCH 12/63] uncomment notify backend --- mapping/scarches_api/models.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index f9d466b..8e1510a 100644 --- a/mapping/scarches_api/models.py +++ 
b/mapping/scarches_api/models.py @@ -159,7 +159,7 @@ def _acquire_data(self): inter_len = len(intersection) ratio = inter_len / len(ref_vars) - # utils.notify_backend(self._webhook, {"ratio":ratio}) + utils.notify_backend(self._webhook, {"ratio":ratio}) self._query_adata.obs_names_make_unique() self._query_adata.var_names_make_unique() @@ -217,8 +217,6 @@ def _concat_data(self): self._reference_adata.obs["query"]=["0"]*self._reference_adata.n_obs #Added because concat_on_disk only allows csr concat - - if scipy.sparse.issparse(self._query_adata.X) and (self._query_adata.X.format == "csc" or self._reference_adata.X.format == "csc"): print("concatenating in memory") @@ -433,6 +431,7 @@ def _map_query(self, supervised=False): model._labeled_indices = [] self._model = model + self._max_epochs = get_from_config(configuration=self._configuration, key=parameters.SCANVI_MAX_EPOCHS_QUERY) super()._map_query() From 9484259596cd3e52893cbcc2f05857693c0138d4 Mon Sep 17 00:00:00 2001 From: chelseabright96 <82758782+chelseabright96@users.noreply.github.com> Date: Sat, 11 May 2024 16:42:57 +0200 Subject: [PATCH 13/63] uncomment notify backend --- mapping/scarches_api/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index f9d466b..6defc15 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -159,7 +159,7 @@ def _acquire_data(self): inter_len = len(intersection) ratio = inter_len / len(ref_vars) - # utils.notify_backend(self._webhook, {"ratio":ratio}) + utils.notify_backend(self._webhook, {"ratio":ratio}) self._query_adata.obs_names_make_unique() self._query_adata.var_names_make_unique() From ed8b986ceded569fc2dc66a9c6730f3db11d3850 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Sun, 12 May 2024 18:20:47 +0200 Subject: [PATCH 14/63] notify backend metrics --- mapping/classifiers/classifiers.py | 18 ++++++++++++++---- mapping/scarches_api/models.py | 15 ++++++++------- mapping/scarches_api/utils/metrics.py | 22 +++++++++++----------- 3 files changed, 33 insertions(+), 22 deletions(-) diff --git a/mapping/classifiers/classifiers.py b/mapping/classifiers/classifiers.py index 5b3c0a2..9ae7f83 100644 --- a/mapping/classifiers/classifiers.py +++ b/mapping/classifiers/classifiers.py @@ -31,6 +31,7 @@ from sklearn.metrics import roc_auc_score from sklearn.preprocessing import LabelEncoder +from scarches_api.utils.metrics import percentage_unknown from sklearn.model_selection import train_test_split @@ -61,6 +62,7 @@ def predict_labels(self, query=scanpy.AnnData(), query_latent=scanpy.AnnData(), xgb_model.load_model(classifier_path) query.obs["prediction_xgb"] = le.inverse_transform(xgb_model.predict(query_latent.X)) + prediction_label = "prediction_xgb" if self.__classifier_knn: with open(encoding_path, "rb") as file: @@ -70,12 +72,22 @@ def predict_labels(self, query=scanpy.AnnData(), query_latent=scanpy.AnnData(), knn_model = pickle.load(file) query.obs["prediction_knn"] = le.inverse_transform(knn_model.predict(query_latent.X)) + prediction_label = "prediction_knn" if self.__classifier_native is not None: if self.__model_class == sca.models.SCANVI.__class__: query.obs["prediction_scanvi"] = self.__classifier_native.predict(query) + prediction_label = "prediction_scanvi" if self.__model_class == sca.models.scPoli.__class__: query.obs["prediction_scpoli"] = self.__classifier_native.classify(query, scale_uncertainties=True) + prediction_label = "prediction_scpoli" + + # calculate the 
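# --- Illustrative sketch, not part of the patch: the masking that
# percentage_unknown performs. Predictions whose mahalanobis uncertainty
# exceeds the threshold are relabeled "Unknown", and the metric is the
# share of such cells. `obs` stands in for query.obs with made-up values.
import pandas as pd

obs = pd.DataFrame({
    "prediction_knn": ["T cell", "B cell", "NK cell", "T cell"],
    "uncertainty_mahalanobis": [0.1, 0.7, 0.9, 0.2],
})
masked = obs["prediction_knn"].mask(obs["uncertainty_mahalanobis"] > 0.5, "Unknown")
percent_unknown = (obs["uncertainty_mahalanobis"] > 0.5).mean() * 100
print(list(masked), percent_unknown)  # ['T cell', 'Unknown', 'Unknown', 'T cell'] 50.0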
percentage of unknown cell types (cell types with uncertainty higher than 0.5) + percent_unknown = percentage_unknown(query, prediction_label) + print(percent_unknown) + + return percent_unknown + ''' Parameters @@ -89,8 +101,6 @@ def create_classifier(self, adata, latent_rep=False, model_path="", label_key="C if not os.path.exists(classifier_directory): os.makedirs(classifier_directory, exist_ok=True) - if query is not None: - validate_on_query=True train_data = Classifiers.__get_train_data( self, @@ -132,7 +142,7 @@ def create_classifier(self, adata, latent_rep=False, model_path="", label_key="C self, reports, classifier_directory=classifier_directory - ) + ) Classifiers.__save_eval_metrics_csv( self, @@ -179,7 +189,7 @@ def __get_train_data(self, adata, latent_rep=True, model_path=None): ''' def __split_train_data(self, train_data, input_adata, label_key, classifier_directory, validate_on_query=False): train_data['cell_type'] = input_adata.obs[label_key] - train_data['type'] = input_adata.obs["type"] + # train_data['type'] = input_adata.obs["type"] #Enable if at least one class has only 1 sample -> Error in stratification for validation set diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index 8e1510a..a3391d5 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -118,13 +118,11 @@ def _map_query(self): self.presence_score = np.concatenate((presence_score["max"],[np.nan]*len(self._query_adata))) - clust_pres_score=cluster_preservation_score(self._query_adata) - print(f"clust_pres_score: {clust_pres_score}") + self.clust_pres_score=cluster_preservation_score(self._query_adata) + print(f"clust_pres_score: {self.clust_pres_score}") - query_with_anchor=percent_query_with_anchor(self._reference_adata, self._query_adata) - print(f"query_with_anchor: {query_with_anchor}") - - # utils.notify_backend(self._webhook, {"clust_pres_score":clust_pres_score, "query_with_anchor":query_with_anchor}) + self.query_with_anchor=percent_query_with_anchor(self._reference_adata, self._query_adata) + print(f"query_with_anchor: {self.query_with_anchor}") def _acquire_data(self): @@ -207,7 +205,10 @@ def _transfer_labels(self): #Compute label transfer and save to respective .obs query_latent = scanpy.AnnData(self._query_adata.obsm["latent_rep"]) - clf.predict_labels(self._query_adata, query_latent, self._temp_clf_model_path, self._temp_clf_encoding_path) + percent_unknown = clf.predict_labels(self._query_adata, query_latent, self._temp_clf_model_path, self._temp_clf_encoding_path) + + utils.notify_backend(self._webhook, {"clust_pres_score":self.clust_pres_score, "query_with_anchor":self.query_with_anchor, "percentage_unknown": percent_unknown}) + def _concat_data(self): diff --git a/mapping/scarches_api/utils/metrics.py b/mapping/scarches_api/utils/metrics.py index 1e4c267..b898cc6 100644 --- a/mapping/scarches_api/utils/metrics.py +++ b/mapping/scarches_api/utils/metrics.py @@ -158,14 +158,14 @@ def get_wknn( warnings.warn( "At least one of query2ref and ref2query should be True. Reset to default with both being True." ) - adj_knn = ((adj_r2q + adj_q2r.T) > 0) + 0 + adj_knn = ((adj_r2q + adj_q2r.T) > 0) + 0 # 1 if either R_i or Q_j are considered a nn of the other if ref2 is None: ref2 = ref adj_ref = build_nn(ref=ref2, k=k) - num_shared_neighbors = adj_q2r @ adj_ref + num_shared_neighbors = adj_q2r @ adj_ref # no. 
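# --- Illustrative sketch, not part of the patch: why adj_q2r @ adj_ref
# counts shared neighbors. Entry (i, j) sums over reference cells m that
# are neighbors of query i (adj_q2r[i, m] = 1) and that list reference j
# among their own neighbors (adj_ref[m, j] = 1). Toy matrices below.
import numpy as np

adj_q2r = np.array([[1, 1, 0]])   # query 0 has ref neighbors {0, 1}
adj_ref = np.array([[1, 0],       # ref 0 lists ref 0
                    [1, 1],       # ref 1 lists refs 0 and 1
                    [0, 1]])      # ref 2 lists ref 1
print(adj_q2r @ adj_ref)          # [[2 1]] -> 2 shared with ref 0, 1 with ref 1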
neighbours that Q_i and R_j have in common num_shared_neighbors_nn = num_shared_neighbors.multiply(adj_knn.T) # only keep weights if q and r are both nearest neigbours of eachother del num_shared_neighbors @@ -362,14 +362,14 @@ def entropy_of_labels(indices): stat = np.clip(stat, 0.00, 5.00) return stat - def percentage_unknown(query, prediction_label, uncertainty_threshold=0.5): - query.obs[f"{prediction_label}_filtered_by_uncert>0.5"] = query.obs[ - prediction_label - ].mask( - query.obs["uncertainty_mahalanobis"] > uncertainty_threshold, - "Unknown", - ) +def percentage_unknown(query, prediction_label, uncertainty_threshold=0.5): + query.obs[f"{prediction_label}_filtered_by_uncert>0.5"] = query.obs[ + prediction_label + ].mask( + query.obs["uncertainty_mahalanobis"] > uncertainty_threshold, + "Unknown", + ) - number_unknown = (query.obs["uncertainty_mahalanobis"] > uncertainty_threshold).sum() + number_unknown = (query.obs["uncertainty_mahalanobis"] > uncertainty_threshold).sum() - return number_unknown/len(query)*100 + return number_unknown/len(query)*100 From d2723c579ac88bc07c484511c3aee7970cec56e6 Mon Sep 17 00:00:00 2001 From: chelseabright96 <82758782+chelseabright96@users.noreply.github.com> Date: Tue, 14 May 2024 09:28:05 +0200 Subject: [PATCH 15/63] add webhook for metrics --- mapping/scarches_api/models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index a3391d5..c5c7fe5 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -40,6 +40,7 @@ def __init__(self, configuration) -> None: self._reference_adata_path = get_from_config(configuration=configuration, key=parameters.REFERENCE_DATA_PATH) self._query_adata_path = get_from_config(configuration=configuration, key=parameters.QUERY_DATA_PATH) self._webhook = utils.get_from_config(configuration, parameters.WEBHOOK_RATIO) + self._webhook_metrics = utils.get_from_config(configuration, parameters.WEBHOOK_METRICS) # self._use_gpu = get_from_config(configuration=configuration, key=parameters.USE_GPU) #Has to be empty for the load_query_data function to work properly (looking for "model.pt") @@ -207,7 +208,7 @@ def _transfer_labels(self): percent_unknown = clf.predict_labels(self._query_adata, query_latent, self._temp_clf_model_path, self._temp_clf_encoding_path) - utils.notify_backend(self._webhook, {"clust_pres_score":self.clust_pres_score, "query_with_anchor":self.query_with_anchor, "percentage_unknown": percent_unknown}) + utils.notify_backend(self._webhook_metrics, {"clust_pres_score":self.clust_pres_score, "query_with_anchor":self.query_with_anchor, "percentage_unknown": percent_unknown}) def _concat_data(self): From 04bd466d9b17aff82e71c2be6908baead706a3c0 Mon Sep 17 00:00:00 2001 From: chelseabright96 <82758782+chelseabright96@users.noreply.github.com> Date: Tue, 14 May 2024 09:29:35 +0200 Subject: [PATCH 16/63] add webhook_metrics parameter --- mapping/scarches_api/utils/parameters.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mapping/scarches_api/utils/parameters.py b/mapping/scarches_api/utils/parameters.py index 03e96c4..f93e593 100644 --- a/mapping/scarches_api/utils/parameters.py +++ b/mapping/scarches_api/utils/parameters.py @@ -67,6 +67,7 @@ # sets the key for the webhook to call after the computation WEBHOOK = 'webhook' WEBHOOK_RATIO = 'webhook_ratio' +WEBHOOK_METRICS = 'webhook_metrics' # sets the path/s3 key of the pretrained model PRETRAINED_MODEL_PATH = 'model_path' # set 
scpoli model attr key @@ -104,4 +105,4 @@ # set if GPU is available # USE_GPU = 'use_gpu' # sets the used atlas -ATLAS = 'atlas' \ No newline at end of file +ATLAS = 'atlas' From b5bdbc5b13cbe0236855b935af1851f286bf50f4 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Tue, 14 May 2024 17:45:22 +0200 Subject: [PATCH 17/63] round metrics --- mapping/classifiers/classifiers.py | 6 +++--- mapping/scarches_api/models.py | 11 ++++++++++- mapping/scarches_api/utils/metrics.py | 6 +++--- mapping/scarches_api/utils/parameters.py | 1 + 4 files changed, 17 insertions(+), 7 deletions(-) diff --git a/mapping/classifiers/classifiers.py b/mapping/classifiers/classifiers.py index 9ae7f83..20b8dd9 100644 --- a/mapping/classifiers/classifiers.py +++ b/mapping/classifiers/classifiers.py @@ -75,10 +75,10 @@ def predict_labels(self, query=scanpy.AnnData(), query_latent=scanpy.AnnData(), prediction_label = "prediction_knn" if self.__classifier_native is not None: - if self.__model_class == sca.models.SCANVI.__class__: + if "SCANVI" in str(self.__model_class): query.obs["prediction_scanvi"] = self.__classifier_native.predict(query) prediction_label = "prediction_scanvi" - if self.__model_class == sca.models.scPoli.__class__: + else: query.obs["prediction_scpoli"] = self.__classifier_native.classify(query, scale_uncertainties=True) prediction_label = "prediction_scpoli" @@ -86,7 +86,7 @@ def predict_labels(self, query=scanpy.AnnData(), query_latent=scanpy.AnnData(), percent_unknown = percentage_unknown(query, prediction_label) print(percent_unknown) - return percent_unknown + return round(percent_unknown, 2) ''' diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index a3391d5..b09fe7a 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -40,6 +40,7 @@ def __init__(self, configuration) -> None: self._reference_adata_path = get_from_config(configuration=configuration, key=parameters.REFERENCE_DATA_PATH) self._query_adata_path = get_from_config(configuration=configuration, key=parameters.QUERY_DATA_PATH) self._webhook = utils.get_from_config(configuration, parameters.WEBHOOK_RATIO) + self._webhook_metrics = utils.get_from_config(configuration, parameters.WEBHOOK_METRICS) # self._use_gpu = get_from_config(configuration=configuration, key=parameters.USE_GPU) #Has to be empty for the load_query_data function to work properly (looking for "model.pt") @@ -156,6 +157,7 @@ def _acquire_data(self): intersection = ref_vars.intersection(query_vars) inter_len = len(intersection) ratio = inter_len / len(ref_vars) + print(ratio) utils.notify_backend(self._webhook, {"ratio":ratio}) @@ -207,7 +209,7 @@ def _transfer_labels(self): percent_unknown = clf.predict_labels(self._query_adata, query_latent, self._temp_clf_model_path, self._temp_clf_encoding_path) - utils.notify_backend(self._webhook, {"clust_pres_score":self.clust_pres_score, "query_with_anchor":self.query_with_anchor, "percentage_unknown": percent_unknown}) + utils.notify_backend(self._webhook_metrics, {"clust_pres_score":self.clust_pres_score, "query_with_anchor":self.query_with_anchor, "percentage_unknown": percent_unknown}) def _concat_data(self): @@ -517,7 +519,14 @@ def _map_query(self): self._query_adata) self.presence_score = np.concatenate((presence_score["max"],[np.nan]*len(self._query_adata))) + + self.clust_pres_score=cluster_preservation_score(self._query_adata) + print(f"clust_pres_score: {self.clust_pres_score}") + self.query_with_anchor=percent_query_with_anchor(self._reference_adata, 
self._query_adata) + print(f"query_with_anchor: {self.query_with_anchor}") + + def _compute_latent_representation(self, explicit_representation, mean=False): explicit_representation.obsm["latent_rep"] = self._model.get_latent(explicit_representation, mean=mean) diff --git a/mapping/scarches_api/utils/metrics.py b/mapping/scarches_api/utils/metrics.py index b898cc6..6430106 100644 --- a/mapping/scarches_api/utils/metrics.py +++ b/mapping/scarches_api/utils/metrics.py @@ -84,7 +84,7 @@ def percent_query_with_anchor(ref_adata, query_adata): adj_mnn=build_mutual_nn(ref,query) has_anchor=adj_mnn.sum(0)>0 #all query cells that have an anchor (output dim: no query cells) percentage = (has_anchor.sum()/adj_mnn.shape[1])*100 - return percentage + return round(percentage, 2) def get_transition_prob_mat(dat, k=50, symm=True): @@ -200,7 +200,7 @@ def estimate_presence_score( weighting_scheme_wknn="jaccard_square", ref_trans_prop=None, use_rep_ref_trans_prop=None, - k_ref_trans_prop=50, + k_ref_trans_prop=15, symm_ref_trans_prop=True, split_by=None, do_random_walk=True, @@ -298,7 +298,7 @@ def estimate_presence_score( def cluster_preservation_score(adata, ds_amount=5000, type='standard'): """ - Calculate the cluster preservation score for an AnnData object after preprocessing. + Calculate the cluster preservation score for a query after mapping. Parameters: - adata: AnnData object containing single-cell data. - ds_amount: Maximum number of cells to include. diff --git a/mapping/scarches_api/utils/parameters.py b/mapping/scarches_api/utils/parameters.py index 03e96c4..a01c6bb 100644 --- a/mapping/scarches_api/utils/parameters.py +++ b/mapping/scarches_api/utils/parameters.py @@ -67,6 +67,7 @@ # sets the key for the webhook to call after the computation WEBHOOK = 'webhook' WEBHOOK_RATIO = 'webhook_ratio' +WEBHOOK_METRICS = 'webhook_metrics' # sets the path/s3 key of the pretrained model PRETRAINED_MODEL_PATH = 'model_path' # set scpoli model attr key From 8934be370cdd36e9b02c083b38ddfa6d3b7d4927 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Thu, 16 May 2024 10:48:32 +0200 Subject: [PATCH 18/63] fix spoli classifier error --- mapping/classifiers/classifiers.py | 4 ++- mapping/process/processing.py | 4 +-- mapping/scarches_api/models.py | 51 ++++++++++++------------------ 3 files changed, 25 insertions(+), 34 deletions(-) diff --git a/mapping/classifiers/classifiers.py b/mapping/classifiers/classifiers.py index 20b8dd9..14248af 100644 --- a/mapping/classifiers/classifiers.py +++ b/mapping/classifiers/classifiers.py @@ -79,7 +79,9 @@ def predict_labels(self, query=scanpy.AnnData(), query_latent=scanpy.AnnData(), query.obs["prediction_scanvi"] = self.__classifier_native.predict(query) prediction_label = "prediction_scanvi" else: - query.obs["prediction_scpoli"] = self.__classifier_native.classify(query, scale_uncertainties=True) + output=self.__classifier_native.classify(query, scale_uncertainties=True) + query.obs["prediction_scpoli"] = list(output.values())[0]["preds"] + query.obs["uncertainty_scpoli"] = list(output.values())[0]["uncert"] prediction_label = "prediction_scpoli" # calculate the percentage of unknown cell types (cell types with uncertainty higher than 0.5) diff --git a/mapping/process/processing.py b/mapping/process/processing.py index 7171135..1ecf283 100644 --- a/mapping/process/processing.py +++ b/mapping/process/processing.py @@ -641,8 +641,8 @@ def __prepare_output(latent_adata: sc.AnnData, combined_adata: sc.AnnData, confi sc.pp.neighbors(combined_adata, n_neighbors, 
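# --- Illustrative sketch, not part of the patch: the shape of the scPoli
# classify() result that the fix above unpacks -- a dict keyed by the
# cell-type column, each entry carrying "preds" and "uncert" arrays.
# Values here are made up.
output = {
    "cell_type": {
        "preds": ["T cell", "B cell"],
        "uncert": [0.05, 0.41],
    }
}
first = list(output.values())[0]   # results for the (single) label column
preds, uncert = first["preds"], first["uncert"]
print(preds, uncert)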
use_rep="latent_rep") print("neighbors") - sc.tl.leiden(combined_adata) - print("leiden") + # sc.tl.leiden(combined_adata) + # print("leiden") sc.tl.umap(combined_adata) print("umap") diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index b09fe7a..9505e05 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -17,6 +17,7 @@ from utils.utils import get_from_config from utils.utils import fetch_file_from_s3 from utils.utils import read_h5ad_file_from_s3 +import pandas as pd from process.processing import Preprocess from process.processing import Postprocess @@ -97,20 +98,6 @@ def _map_query(self): #Save out the latent representation for QUERY self._compute_latent_representation(explicit_representation=self._query_adata) - #save .X and var_names of query in new adata for later concatenation after cellxgene - self.adata_query_X = scanpy.AnnData(self._query_adata.X.copy()) - self.adata_query_X.var_names = self._query_adata.var_names - - #we can then zero out .X in original query - if self._query_adata.X.format == "csc": - all_zeros = csc_matrix(self._query_adata.X.shape) - else: - all_zeros = csr_matrix(self._query_adata.X.shape) - - del self._query_adata.X - - self._query_adata.X = all_zeros.copy() - # Calculate presence score presence_score=estimate_presence_score( @@ -213,6 +200,18 @@ def _transfer_labels(self): def _concat_data(self): + + #save .X and var_names of query in new adata for later concatenation after cellxgene + self.adata_query_X = scanpy.AnnData(self._query_adata.X.copy()) + self.adata_query_X.var_names = self._query_adata.var_names + + #we can then zero out .X in original query + if self._query_adata.X.format == "csc": + all_zeros = csc_matrix(self._query_adata.X.shape) + else: + all_zeros = csr_matrix(self._query_adata.X.shape) + + self._query_adata.X = all_zeros.copy() self.latent_full_from_mean_var = np.concatenate((self._reference_adata.obsm["latent_rep"], self._query_adata.obsm["latent_rep"])) @@ -266,10 +265,10 @@ def _concat_data(self): #Concatenate on disk to save memory experimental.concat_on_disk([temp_reference.name, temp_query.name], temp_combined.name) - # query_obs=set(self._query_adata.obs.columns) - # ref_obs=set(self._reference_adata.obs.columns) - # inter = ref_obs.intersection(query_obs) - # new_columns = query_obs.union(inter) + query_obs_columns=set(self._query_adata.obs.columns) + ref_obs_columns=set(self._reference_adata.obs.columns) + columns_only_query = query_obs_columns.difference(ref_obs_columns) + query_obs = self._query_adata.obs[columns_only_query].copy() del self._reference_adata del self._query_adata @@ -286,6 +285,8 @@ def _concat_data(self): self._combined_adata.obsm["latent_rep"] = self.latent_full_from_mean_var self._combined_adata.obs["presence_score"] = self.presence_score + + self._combined_adata.obs=pd.concat([self._combined_adata.obs,query_obs], axis=1) print("added latent rep to adata") @@ -467,7 +468,7 @@ def _map_query(self): self._model = model self._max_epochs = get_from_config(configuration=self._configuration, key=parameters.SCPOLI_MAX_EPOCHS) - model.train( + self._model.train( n_epochs=self._max_epochs, pretraining_epochs=40, eta=10 @@ -502,18 +503,6 @@ def _map_query(self): - #save .X and var_names of query in new adata for later concatenation after cellxgene - self.adata_query_X = scanpy.AnnData(self._query_adata.X.copy()) - self.adata_query_X.var_names = self._query_adata.var_names - - #we can then zero out .X in original query - if self._query_adata.X.format == 
"csc": - all_zeros = csc_matrix(self._query_adata.X.shape) - else: - all_zeros = csr_matrix(self._query_adata.X.shape) - - self._query_adata.X = all_zeros.copy() - presence_score = estimate_presence_score( self._reference_adata, self._query_adata) From 8f9162124748124e80777c226f4f03df0923bd30 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Sat, 18 May 2024 05:53:47 +0200 Subject: [PATCH 19/63] add stress score --- mapping/process/processing.py | 4 ++-- mapping/scarches_api/models.py | 11 +++++++++-- mapping/scarches_api/utils/metrics.py | 11 ++++++++++- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/mapping/process/processing.py b/mapping/process/processing.py index 1ecf283..5539871 100644 --- a/mapping/process/processing.py +++ b/mapping/process/processing.py @@ -294,8 +294,8 @@ def get_keys(atlas, target_adata): although they contain the same information. """ #TODO: keys are stored on db, get rid of this hardcoding!!!! - #Set unlabeled key to always be "Unlabeled" - unlabeled_key = "Unlabeled" + #Set unlabeled key to always be "unlabeled" + unlabeled_key = "unlabeled" if atlas == 'pbmc': diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index 9505e05..8234ca3 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -13,7 +13,7 @@ from utils import utils from utils import parameters -from utils.metrics import estimate_presence_score, cluster_preservation_score, build_mutual_nn, percent_query_with_anchor +from utils.metrics import estimate_presence_score, cluster_preservation_score, build_mutual_nn, percent_query_with_anchor, stress_score from utils.utils import get_from_config from utils.utils import fetch_file_from_s3 from utils.utils import read_h5ad_file_from_s3 @@ -285,6 +285,8 @@ def _concat_data(self): self._combined_adata.obsm["latent_rep"] = self.latent_full_from_mean_var self._combined_adata.obs["presence_score"] = self.presence_score + + self._combined_adata.obs_names_make_unique() self._combined_adata.obs=pd.concat([self._combined_adata.obs,query_obs], axis=1) @@ -298,6 +300,11 @@ def _compute_latent_representation(self, explicit_representation): explicit_representation.obsm["latent_rep"] = self._model.get_latent_representation(explicit_representation) def _save_data(self): + + if self._atlas=="hnoca": + print("calculating stress score") + stress_score(self._combined_adata) + print(self._combined_adata.obs["Hallmark_Glycolysis"]) combined_downsample = self.downsample_adata() #Save output @@ -448,7 +455,7 @@ def _acquire_data(self): def _compute_latent_representation(self, explicit_representation): #Setup adata before quering model for latent representation - scarches.models.SCANVI.setup_anndata(explicit_representation, labels_key=self._cell_type_key, unlabeled_category="Unlabeled", batch_key=self._batch_key) + scarches.models.SCANVI.setup_anndata(explicit_representation, labels_key=self._cell_type_key, unlabeled_category="unlabeled", batch_key=self._batch_key) super()._compute_latent_representation(explicit_representation=explicit_representation) diff --git a/mapping/scarches_api/utils/metrics.py b/mapping/scarches_api/utils/metrics.py index 6430106..67b046d 100644 --- a/mapping/scarches_api/utils/metrics.py +++ b/mapping/scarches_api/utils/metrics.py @@ -312,7 +312,7 @@ def cluster_preservation_score(adata, ds_amount=5000, type='standard'): # sc.pp.scale(adata, zero_center=True) if adata.n_obs > ds_amount: adata = adata[np.random.choice(adata.obs_names, ds_amount, replace=False), :] - sc.tl.pca(adata, 
svd_solver='arpack', n_comps=dims) + sc.pp.pca(adata, svd_solver='arpack', n_comps=dims, use_highly_variable=False) sc.pp.neighbors(adata, n_neighbors=15, n_pcs=dims, method='umap', metric='euclidean') sc.tl.leiden(adata, resolution=0.6) # sc.pp.neighbors(adata, use_rep='latent_rep', key_added="integrated_neighbors") @@ -347,7 +347,9 @@ def entropy_of_labels(indices): return rel_entr(counts, np.full_like(counts, fill_value=1/len(counts))).sum() orig_ent = np.array([entropy_of_labels(idx) for idx in orig_indices]) + print(f"orig_ent: {orig_ent}") integrated_ent = np.array([entropy_of_labels(idx) for idx in integrated_indices]) + print(f"integrated_ent: {integrated_ent}") # Calculate the cluster preservation statistic ids = adata.obs['leiden'].to_numpy() @@ -373,3 +375,10 @@ def percentage_unknown(query, prediction_label, uncertainty_threshold=0.5): number_unknown = (query.obs["uncertainty_mahalanobis"] > uncertainty_threshold).sum() return number_unknown/len(query)*100 + + +def stress_score(adata): + + msigdb_glycolysis = np.array(pd.read_csv('https://www.gsea-msigdb.org/gsea/msigdb/human/download_geneset.jsp?geneSetName=HALLMARK_GLYCOLYSIS&fileType=TSV', sep='\t', header=None, index_col=0).loc['GENE_SYMBOLS',1].split(',')) + msigdb_glycolysis = np.intersect1d(msigdb_glycolysis, adata.var_names) + sc.tl.score_genes(adata, msigdb_glycolysis, score_name='Hallmark_Glycolysis') \ No newline at end of file From 470b04a975e276b21f7e2240c3b1aa32da70d95f Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Tue, 21 May 2024 07:47:54 +0200 Subject: [PATCH 20/63] change stress score --- mapping/scarches_api/models.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index 8234ca3..02dcf15 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -165,6 +165,11 @@ def _eval_mapping(self): classification_uncert_euclidean(self._configuration, reference_latent, query_latent, self._query_adata, "X", self._cell_type_key, False) classification_uncert_mahalanobis(self._configuration, reference_latent, query_latent, self._query_adata, "X", self._cell_type_key, False) + #stress score + if self._atlas=="hnoca": + print("calculating stress score") + stress_score(self._query_adata) + print(self._query_adata.obs["Hallmark_Glycolysis"]) def _transfer_labels(self): if not self._clf_native and not self._clf_knn and not self._clf_xgb: @@ -300,11 +305,6 @@ def _compute_latent_representation(self, explicit_representation): explicit_representation.obsm["latent_rep"] = self._model.get_latent_representation(explicit_representation) def _save_data(self): - - if self._atlas=="hnoca": - print("calculating stress score") - stress_score(self._combined_adata) - print(self._combined_adata.obs["Hallmark_Glycolysis"]) combined_downsample = self.downsample_adata() #Save output From 62a068472bc720169acb3bd5f8f58596d75012b8 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Thu, 23 May 2024 08:31:02 +0200 Subject: [PATCH 21/63] fix cluster pres score --- mapping/scarches_api/utils/metrics.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/mapping/scarches_api/utils/metrics.py b/mapping/scarches_api/utils/metrics.py index 67b046d..2c96292 100644 --- a/mapping/scarches_api/utils/metrics.py +++ b/mapping/scarches_api/utils/metrics.py @@ -13,6 +13,7 @@ import importlib.util import argparse from scipy.sparse import csr_matrix +from sklearn.neighbors import NearestNeighbors 
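The stress_score helper added above fetches the HALLMARK_GLYCOLYSIS gene set from MSigDB over the network on every call, so scoring fails without connectivity. A minimal sketch of a cached variant, assuming a writable local cache path; the cache location and function names are illustrative, not part of the codebase:

import os
import numpy as np
import pandas as pd
import scanpy as sc

GENESET_URL = ("https://www.gsea-msigdb.org/gsea/msigdb/human/download_geneset.jsp"
               "?geneSetName=HALLMARK_GLYCOLYSIS&fileType=TSV")
GENESET_CACHE = "assets/hallmark_glycolysis.txt"  # hypothetical cache path

def load_glycolysis_genes():
    # Reuse the cached gene list when present; otherwise fetch and parse
    # the TSV exactly as stress_score does, then persist it for next time.
    if os.path.exists(GENESET_CACHE):
        return np.loadtxt(GENESET_CACHE, dtype=str)
    genes = np.array(pd.read_csv(GENESET_URL, sep="\t", header=None,
                                 index_col=0).loc["GENE_SYMBOLS", 1].split(","))
    os.makedirs(os.path.dirname(GENESET_CACHE), exist_ok=True)
    np.savetxt(GENESET_CACHE, genes, fmt="%s")
    return genes

def stress_score_cached(adata):
    genes = np.intersect1d(load_glycolysis_genes(), adata.var_names)
    sc.tl.score_genes(adata, genes, score_name="Hallmark_Glycolysis_Score")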
warnings.filterwarnings("ignore") @@ -290,11 +291,7 @@ def estimate_presence_score( # cluster preservation score -import numpy as np -import scanpy as sc -from sklearn.neighbors import NearestNeighbors -from scipy.special import rel_entr -import pandas as pd + def cluster_preservation_score(adata, ds_amount=5000, type='standard'): """ @@ -344,7 +341,7 @@ def cluster_preservation_score(adata, ds_amount=5000, type='standard'): def entropy_of_labels(indices): labels = adata.obs['leiden'][indices].to_numpy() _, counts = np.unique(labels, return_counts=True) - return rel_entr(counts, np.full_like(counts, fill_value=1/len(counts))).sum() + return (counts*np.log(counts/(1/len(counts)))).sum() orig_ent = np.array([entropy_of_labels(idx) for idx in orig_indices]) print(f"orig_ent: {orig_ent}") From 0439ab6e294431e3e7b7134664b7728ba0968a48 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Sat, 1 Jun 2024 18:28:09 +0200 Subject: [PATCH 22/63] add distributed training to scvi models and use speed improvement branch scpoli --- mapping/Dockerfile | 3 ++- mapping/scarches_api/models.py | 21 +++++++++++++++++++-- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/mapping/Dockerfile b/mapping/Dockerfile index 177ad31..4cca1a0 100644 --- a/mapping/Dockerfile +++ b/mapping/Dockerfile @@ -19,7 +19,8 @@ RUN apt-get update && \ #pip install git+https://github.com/theislab/pertpy RUN apt-get install -y git && \ - pip install git+https://github.com/theislab/scarches.git + # pip install git+https://github.com/theislab/scarches.git + pip install git+https://github.com/theislab/scarches.git@speed_improvement_merge ENV APP_HOME /app WORKDIR $APP_HOME diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index 02dcf15..ed74587 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -7,10 +7,12 @@ import torch import gc import numpy as np +import time import scipy from scipy.sparse import csr_matrix, csc_matrix from anndata import experimental from utils import utils +from scvi.dataloaders import BatchDistributedSampler from utils import parameters from utils.metrics import estimate_presence_score, cluster_preservation_score, build_mutual_nn, percent_query_with_anchor, stress_score @@ -32,6 +34,8 @@ class ArchmapBaseModel(): def __init__(self, configuration) -> None: self._configuration = configuration + start_time = time.time() + self._atlas = get_from_config(configuration=configuration, key=parameters.ATLAS) self._model_type = get_from_config(configuration=configuration, key=parameters.MODEL) self._model_path = get_from_config(configuration=configuration, key=parameters.PRETRAINED_MODEL_PATH) @@ -69,6 +73,9 @@ def __init__(self, configuration) -> None: self._clf_model_path = get_from_config(configuration=configuration, key=parameters.CLASSIFIER_PATH) self._clf_encoding_path = get_from_config(configuration=configuration, key=parameters.ENCODING_PATH) + end_time = time.time() + print(f"time {end_time-start_time}") + def run(self): self._map_query() self._eval_mapping() @@ -79,14 +86,24 @@ def run(self): def _map_query(self): #Map the query onto reference + start_time = time.time() + # threshold = 10000 self._model.train( max_epochs=self._max_epochs, plan_kwargs=dict(weight_decay=0.0), check_val_every_n_epoch=10, - # use_gpu=self._use_gpu + datasplitter_kwargs = dict(distributed_sampler = True), + strategy='ddp_find_unused_parameters_true', + accelerator="cpu", + devices=4 ) + end_time = time.time() + print(f"time {end_time-start_time}") + + + if 
"X_latent_qzm" in self._reference_adata.obsm and "X_latent_qzv" in self._reference_adata.obsm: print("__________getting X_latent_qzm from minified atlas for scvi-tools models___________") qzm = self._reference_adata.obsm["X_latent_qzm"] @@ -474,7 +491,7 @@ def _map_query(self): self._model = model self._max_epochs = get_from_config(configuration=self._configuration, key=parameters.SCPOLI_MAX_EPOCHS) - + self._model.train( n_epochs=self._max_epochs, pretraining_epochs=40, From 4b7ba0b25d9d22bc86f5c1ccaa17885a96e58e81 Mon Sep 17 00:00:00 2001 From: chelseabright96 <82758782+chelseabright96@users.noreply.github.com> Date: Sun, 2 Jun 2024 07:50:21 +0200 Subject: [PATCH 23/63] change number of devices --- mapping/scarches_api/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index ed74587..4dce515 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -96,7 +96,7 @@ def _map_query(self): datasplitter_kwargs = dict(distributed_sampler = True), strategy='ddp_find_unused_parameters_true', accelerator="cpu", - devices=4 + devices=-1 ) end_time = time.time() From 597012e436ab004a7b1ffeda4afa231581e8e3e7 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Sun, 2 Jun 2024 08:21:35 +0200 Subject: [PATCH 24/63] change to absolute path --- mapping/gcsfuse_run.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) mode change 100644 => 100755 mapping/gcsfuse_run.sh diff --git a/mapping/gcsfuse_run.sh b/mapping/gcsfuse_run.sh old mode 100644 new mode 100755 index 728765c..b29ae6a --- a/mapping/gcsfuse_run.sh +++ b/mapping/gcsfuse_run.sh @@ -43,6 +43,8 @@ ls -l $MNT_DIR; echo "the pwd is now: " pwd; +ABSOLUTE_PATH=$(pwd) + #run the api -gunicorn --bind :$PORT --workers $WORKERS --threads $THREADS --timeout 0 --chdir ./scarches_api/ api:app \ No newline at end of file +gunicorn --bind :$PORT --workers $WORKERS --threads $THREADS --timeout 0 --chdir $ABSOLUTE_PATH/scarches_api/ api:app \ No newline at end of file From 5a654cf3146ce278ac945ac3371cd0612404b057 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Sun, 2 Jun 2024 09:39:12 +0200 Subject: [PATCH 25/63] change number of devices --- mapping/scarches_api/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index 4dce515..ed74587 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -96,7 +96,7 @@ def _map_query(self): datasplitter_kwargs = dict(distributed_sampler = True), strategy='ddp_find_unused_parameters_true', accelerator="cpu", - devices=-1 + devices=4 ) end_time = time.time() From fdc909d24b89edd3fc808522e91e6e79d5c9523d Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Mon, 3 Jun 2024 14:44:20 +0200 Subject: [PATCH 26/63] add X back for cxg --- mapping/scarches_api/init.py | 122 ++++---------------------- mapping/scarches_api/models.py | 109 ++++++++++++++++++----- mapping/scarches_api/utils/metrics.py | 4 +- mapping/scarches_api/utils/utils.py | 50 +++++++++++ 4 files changed, 152 insertions(+), 133 deletions(-) diff --git a/mapping/scarches_api/init.py b/mapping/scarches_api/init.py index ba68475..2d2105c 100644 --- a/mapping/scarches_api/init.py +++ b/mapping/scarches_api/init.py @@ -1,7 +1,7 @@ import os import time -startTime = time.time() + from utils import utils, parameters @@ -108,7 +108,7 @@ def query(user_config): :param user_config: keys of config parsed from the rest api :return: config """ 
- + start_time2 = time.time() print("got config " + str(user_config)) start_time = time.time() configuration = merge_configs(user_config) @@ -139,6 +139,13 @@ def query(user_config): mapping = ScPoli(configuration=configuration) mapping.run() + end_time2 = time.time() + + + print(f"time end: {end_time2}-{start_time2}") + + + #TODO: add obsm and other keys from query_adata_raw to adata_combined for download # atlas_name = utils.get_from_config(configuration, parameters.ATLAS) @@ -147,113 +154,14 @@ def query(user_config): if get_from_config(configuration, parameters.WEBHOOK) is not None and len( get_from_config(configuration, parameters.WEBHOOK)) > 0: utils.notify_backend(get_from_config(configuration, parameters.WEBHOOK), configuration) - if ("counts" not in mapping._combined_adata.layers or mapping._combined_adata.layers["counts"].size == 0): - if not mapping._reference_adata_path.endswith("data.h5ad"): - raise ValueError("The reference data should be named data.h5ad") - else: - count_matrix_path = mapping._reference_adata_path[:-len("data.h5ad")] + "data_only_count.h5ad" - combined_adata = mapping._combined_adata - - cxg_with_count_path = get_from_config(configuration, parameters.OUTPUT_PATH)[:-len("cxg.h5ad")] + "cxg_with_count.h5ad" - count_matrix_size_gb = get_file_size_in_gb(count_matrix_path) - temp_output = tempfile.mktemp( suffix=".h5ad") - - if count_matrix_size_gb < 10: - print("Count matrix size less than 10 gb.") - count_matrix = read_h5ad_file_from_s3(count_matrix_path) - #Added because concat_on_disk only allows csr concat - if count_matrix.X.format == "csc" or mapping.adata_query_X.X.format == "csc": - print("Concatenating query and reference count matrices in memory") - combined_data_X = count_matrix.concatenate(mapping.adata_query_X) - - del count_matrix - del mapping.adata_query_X - gc.collect() - - else: - print("Concatenating query and reference count matrices on disk") - #Create temp files on disk - temp_reference = tempfile.NamedTemporaryFile(suffix=".h5ad") - temp_query = tempfile.NamedTemporaryFile(suffix=".h5ad") - temp_combined = tempfile.NamedTemporaryFile(suffix=".h5ad") - - #Write data to temp files - count_matrix.write_h5ad(temp_reference.name) - mapping.adata_query_X.write_h5ad(temp_query.name) - - del count_matrix - del mapping.adata_query_X - gc.collect() - - experimental.concat_on_disk([temp_reference.name, temp_query.name], temp_combined.name) - combined_data_X = sc.read_h5ad(temp_combined.name) - - combined_adata.X = combined_data_X.X - sc.write(temp_output, combined_adata) - - else: - print("Count matrix size larger than 10 gb.") - temp_query = tempfile.NamedTemporaryFile(suffix=".h5ad") - mapping.adata_query_X.write_h5ad(temp_query.name) - del mapping.adata_query_X - gc.collect() - temp_output=replace_X_on_disk(combined_adata,temp_output, temp_query.name, count_matrix_path) - - print("Concatenated matrices") - print("cxg_with_count_path written to: " + temp_output) - print("storing cxg_with_count_path to gcp with output path: " + cxg_with_count_path) - utils.store_file_in_s3(temp_output, cxg_with_count_path) - print("Stored adata with counts on cloud") - utils.notify_backend(get_from_config(configuration, parameters.WEBHOOK), configuration) - - return configuration - -def replace_X_on_disk(combined_adata,temp_output, query_X_file, ref_count_matrix_path): - """ - Writes combined_adata to disk, fetches another .h5ad file specified by ref_count_matrix_path. - Concatenates the .X of the fetched file with query_X_file. 
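The count-matrix handling removed from init.py here moves into the model classes below; its core is anndata's out-of-core concatenation. A minimal standalone sketch of that pattern, with an illustrative wrapper name:

import tempfile
from anndata import experimental

def concat_h5ad_on_disk(adata_ref, adata_query):
    # Write both objects to temporary .h5ad files and let concat_on_disk
    # join them without materialising either matrix in memory. Note that
    # concat_on_disk only supports CSR-encoded sparse .X.
    ref = tempfile.NamedTemporaryFile(suffix=".h5ad")
    query = tempfile.NamedTemporaryFile(suffix=".h5ad")
    out = tempfile.NamedTemporaryFile(suffix=".h5ad", delete=False)
    adata_ref.write_h5ad(ref.name)
    adata_query.write_h5ad(query.name)
    experimental.concat_on_disk([ref.name, query.name], out.name)
    return out.name  # path of the combined file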
- Writes concatenated adata metadata (eg .obs and .vars and .varm and .obsm and uns) to adata with .X on disk - - :param combined_adata: The Anndata object in memory. - :param temp_output: Temp file to write combined_adata to. - :param query_X_file: File with .X for query. - :param ref_count_matrix_path: The S3 key for the .h5ad file to be fetched. - - Returns: File path to saved adata with concatenated metadata and .X - """ + cxg_with_count_path = get_from_config(configuration, parameters.OUTPUT_PATH)[:-len("cxg.h5ad")] + "cxg_with_count.h5ad" + print("storing cxg_with_count_path to gcp with output path: " + cxg_with_count_path) + utils.store_file_in_s3(mapping.temp_output_combined, cxg_with_count_path) + print("Stored adata with counts on cloud") + utils.notify_backend(get_from_config(configuration, parameters.WEBHOOK), configuration) - # Write combined_adata to disk - combined_adata.write(temp_output) - print(f"combined_adata written to {temp_output}") - # Fetch the new file and get its path - temp_ref_count_matrix_path = fetch_file_to_temp_path_from_s3(ref_count_matrix_path) - if temp_ref_count_matrix_path is None: - print("No file fetched. Exiting.") - return - print("Fetched reference adata with counts") - temp_combined = tempfile.NamedTemporaryFile(suffix=".h5ad", delete=False) - - experimental.concat_on_disk([temp_ref_count_matrix_path, query_X_file], temp_combined.name) - print("Concatenated reference and query counts") - - # write concatenated adata metadata (eg .obs and .vars and .varm and .obsm and uns) to adata with .X on disk - with h5py.File(temp_output, "r") as f: - with h5py.File(temp_combined.name, 'r+') as target_file: - - elems=list(f.keys()) - if "X" in elems: - elems.remove("X") - - for elem in elems: - v=read_elem(f[elem]) - if isinstance(v,dict) and not bool(v): - continue - - write_elem(target_file, f"{elem}", v) - print("Added concatenated metadata to anndata with full .X on disk") - - return temp_combined.name + return configuration diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index 02dcf15..6d3286f 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -11,12 +11,13 @@ from scipy.sparse import csr_matrix, csc_matrix from anndata import experimental from utils import utils +import scanpy as sc from utils import parameters from utils.metrics import estimate_presence_score, cluster_preservation_score, build_mutual_nn, percent_query_with_anchor, stress_score from utils.utils import get_from_config from utils.utils import fetch_file_from_s3 -from utils.utils import read_h5ad_file_from_s3 +from utils.utils import read_h5ad_file_from_s3, get_file_size_in_gb, replace_X_on_disk import pandas as pd from process.processing import Preprocess @@ -87,6 +88,8 @@ def _map_query(self): # use_gpu=self._use_gpu ) + utils.notify_backend(get_from_config(self._configuration, parameters.WEBHOOK), self._configuration) + if "X_latent_qzm" in self._reference_adata.obsm and "X_latent_qzv" in self._reference_adata.obsm: print("__________getting X_latent_qzm from minified atlas for scvi-tools models___________") qzm = self._reference_adata.obsm["X_latent_qzm"] @@ -118,28 +121,17 @@ def _acquire_data(self): self._reference_adata = read_h5ad_file_from_s3(self._reference_adata_path) self._reference_adata.obs["type"] = "reference" - self._query_adata = read_h5ad_file_from_s3(self._query_adata_path) - self._query_adata.obs["type"] = "query" - - - # #Check if cell_type_key exists in query - # if self._cell_type_key not in 
self._query_adata.obs.columns: - # self._query_adata.obs[self._cell_type_key] = "Unlabeled" + self._query_adata_raw = read_h5ad_file_from_s3(self._query_adata_path) + self._query_adata_raw.obs["type"] = "query" - # #Check if batch_key exists in query - # if self._batch_key not in self._query_adata.obs.columns: - # self._query_adata.obs[self._batch_key] = "query_batch" - - #Store counts in layer if not stored already - if "counts" not in self._query_adata.layers.keys(): - self._query_adata.layers['counts'] = self._query_adata.X + self._query_adata_raw.obs_names_make_unique() #Convert bool to categorical to avoid write error during concatenation Preprocess.bool_to_categorical(self._reference_adata) - Preprocess.bool_to_categorical(self._query_adata) + Preprocess.bool_to_categorical(self._query_adata_raw) ref_vars = self._reference_adata.var_names - query_vars = self._query_adata.var_names + query_vars = self._query_adata_raw.var_names intersection = ref_vars.intersection(query_vars) inter_len = len(intersection) @@ -148,12 +140,18 @@ def _acquire_data(self): utils.notify_backend(self._webhook, {"ratio":ratio}) - self._query_adata.obs_names_make_unique() - self._query_adata.var_names_make_unique() - + + # save only necessary data for mapping to new adata + self._query_adata = self._query_adata_raw.copy() + del self._query_adata.varm + del self._query_adata.obsm + del self._query_adata.layers + del self._query_adata.uns + del self._query_adata.obsp + del self._query_adata.varp + + self._query_adata.layers['counts'] = self._query_adata.X - #Remove later - for testing only - # self._reference_adata = scanpy.pp.subsample(self._reference_adata, 0.1, copy=True) def _eval_mapping(self): #Create AnnData objects off the latent representation @@ -169,7 +167,7 @@ def _eval_mapping(self): if self._atlas=="hnoca": print("calculating stress score") stress_score(self._query_adata) - print(self._query_adata.obs["Hallmark_Glycolysis"]) + print(self._query_adata.obs["Hallmark_Glycolysis_Score"]) def _transfer_labels(self): if not self._clf_native and not self._clf_knn and not self._clf_xgb: @@ -305,11 +303,74 @@ def _compute_latent_representation(self, explicit_representation): explicit_representation.obsm["latent_rep"] = self._model.get_latent_representation(explicit_representation) def _save_data(self): - + # add .X to self._combined_adata + + print("adding X from cloud") + self.add_X_from_cloud() + combined_downsample = self.downsample_adata() + #Save output Postprocess.output(None, combined_downsample, self._configuration) + def add_X_from_cloud(self): + if True or get_from_config(self._configuration, parameters.WEBHOOK) is not None and len( + get_from_config(self._configuration, parameters.WEBHOOK)) > 0: + + utils.notify_backend(get_from_config(self._configuration, parameters.WEBHOOK), self._configuration) + if not self._reference_adata_path.endswith("data.h5ad"): + raise ValueError("The reference data should be named data.h5ad") + else: + count_matrix_path = self._reference_adata_path[:-len("data.h5ad")] + "data_only_count.h5ad" + + combined_adata = self._combined_adata + count_matrix_size_gb = get_file_size_in_gb(count_matrix_path) + self.temp_output_combined = tempfile.mktemp( suffix=".h5ad") + + if count_matrix_size_gb < 10: + print("Count matrix size less than 10 gb.") + count_matrix = read_h5ad_file_from_s3(count_matrix_path) + #Added because concat_on_disk only allows csr concat + if count_matrix.X.format == "csc" or self.adata_query_X.X.format == "csc": + print("Concatenating query and reference 
count matrices in memory") + combined_data_X = count_matrix.concatenate(self.adata_query_X) + + del count_matrix + del self.adata_query_X + gc.collect() + + else: + print("Concatenating query and reference count matrices on disk") + #Create temp files on disk + temp_reference = tempfile.NamedTemporaryFile(suffix=".h5ad") + temp_query = tempfile.NamedTemporaryFile(suffix=".h5ad") + temp_combined = tempfile.NamedTemporaryFile(suffix=".h5ad") + + #Write data to temp files + count_matrix.write_h5ad(temp_reference.name) + self.adata_query_X.write_h5ad(temp_query.name) + + del count_matrix + del self.adata_query_X + gc.collect() + + experimental.concat_on_disk([temp_reference.name, temp_query.name], temp_combined.name) + combined_data_X = sc.read_h5ad(temp_combined.name) + + combined_adata.X = combined_data_X.X + sc.write(self.temp_output_combined, combined_adata) + + else: + print("Count matrix size larger than 10 gb.") + temp_query = tempfile.NamedTemporaryFile(suffix=".h5ad") + self.adata_query_X.write_h5ad(temp_query.name) + del self.adata_query_X + gc.collect() + self.temp_output_combined =replace_X_on_disk(combined_adata,self.temp_output_combined, temp_query.name, count_matrix_path) + + self._combined_adata = combined_adata + + def downsample_adata(self, query_ratio=5): """ Downsamples the reference data to be proportional to the query data. diff --git a/mapping/scarches_api/utils/metrics.py b/mapping/scarches_api/utils/metrics.py index 2c96292..95678a0 100644 --- a/mapping/scarches_api/utils/metrics.py +++ b/mapping/scarches_api/utils/metrics.py @@ -302,7 +302,7 @@ def cluster_preservation_score(adata, ds_amount=5000, type='standard'): Returns: - score: Cluster preservation score. """ - dims = min(50, adata.uns.get('Azimuth_map_ndims', 50)) + dims = 50 if type == 'standard': # Following the standard preprocessing workflow @@ -378,4 +378,4 @@ def stress_score(adata): msigdb_glycolysis = np.array(pd.read_csv('https://www.gsea-msigdb.org/gsea/msigdb/human/download_geneset.jsp?geneSetName=HALLMARK_GLYCOLYSIS&fileType=TSV', sep='\t', header=None, index_col=0).loc['GENE_SYMBOLS',1].split(',')) msigdb_glycolysis = np.intersect1d(msigdb_glycolysis, adata.var_names) - sc.tl.score_genes(adata, msigdb_glycolysis, score_name='Hallmark_Glycolysis') \ No newline at end of file + sc.tl.score_genes(adata, msigdb_glycolysis, score_name='Hallmark_Glycolysis_Score') \ No newline at end of file diff --git a/mapping/scarches_api/utils/utils.py b/mapping/scarches_api/utils/utils.py index c48d7bf..24c43a6 100644 --- a/mapping/scarches_api/utils/utils.py +++ b/mapping/scarches_api/utils/utils.py @@ -14,6 +14,9 @@ import pandas as pd # from scarches.dataset.trvae.data_handling import remove_sparsity import traceback +from anndata import experimental +from anndata.experimental import write_elem, read_elem +import h5py UNWANTED_LABELS = ['leiden', '', '_scvi_labels', '_scvi_batch'] @@ -665,3 +668,50 @@ def fetch_file_to_temp_path_from_s3(key): return filename + +def replace_X_on_disk(combined_adata,temp_output, query_X_file, ref_count_matrix_path): + """ + Writes combined_adata to disk, fetches another .h5ad file specified by ref_count_matrix_path. + Concatenates the .X of the fetched file with query_X_file. + Writes concatenated adata metadata (eg .obs and .vars and .varm and .obsm and uns) to adata with .X on disk + + :param combined_adata: The Anndata object in memory. + :param temp_output: Temp file to write combined_adata to. + :param query_X_file: File with .X for query. 
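Because concat_on_disk only accepts CSR, add_X_from_cloud above falls back to an in-memory concatenate whenever either matrix is CSC. An alternative sketch that normalises to CSR up front so the on-disk path can always be taken; the one-off conversion cost is the trade-off:

from scipy.sparse import issparse

def ensure_csr(adata):
    # concat_on_disk rejects CSC; convert in place before writing the
    # temporary .h5ad files used for the on-disk concatenation.
    if issparse(adata.X) and adata.X.format == "csc":
        adata.X = adata.X.tocsr()
    return adata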
+ :param ref_count_matrix_path: The S3 key for the .h5ad file to be fetched. + + Returns: File path to saved adata with concatenated metadata and .X + """ + + # Write combined_adata to disk + combined_adata.write(temp_output) + print(f"combined_adata written to {temp_output}") + # Fetch the new file and get its path + temp_ref_count_matrix_path = fetch_file_to_temp_path_from_s3(ref_count_matrix_path) + if temp_ref_count_matrix_path is None: + print("No file fetched. Exiting.") + return + print("Fetched reference adata with counts") + temp_combined = tempfile.NamedTemporaryFile(suffix=".h5ad", delete=False) + + experimental.concat_on_disk([temp_ref_count_matrix_path, query_X_file], temp_combined.name) + print("Concatenated reference and query counts") + + # write concatenated adata metadata (eg .obs and .vars and .varm and .obsm and uns) to adata with .X on disk + with h5py.File(temp_output, "r") as f: + with h5py.File(temp_combined.name, 'r+') as target_file: + + elems=list(f.keys()) + if "X" in elems: + elems.remove("X") + + for elem in elems: + v=read_elem(f[elem]) + if isinstance(v,dict) and not bool(v): + continue + + write_elem(target_file, f"{elem}", v) + print("Added concatenated metadata to anndata with full .X on disk") + + return temp_combined.name + From 88f403d2a2788fc6ad95cf4ae3df2487aa513303 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Tue, 4 Jun 2024 07:00:57 +0200 Subject: [PATCH 27/63] remove threading --- mapping/scarches_api/api.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mapping/scarches_api/api.py b/mapping/scarches_api/api.py index 2315861..f3f3328 100644 --- a/mapping/scarches_api/api.py +++ b/mapping/scarches_api/api.py @@ -21,8 +21,9 @@ def query(): run_async = get_from_config(config, parameters.RUN_ASYNCHRONOUSLY) if run_async is not None and run_async: actual_config = scarches.merge_configs(config) - thread = Thread(target=scarches.query, args=(config,)) - thread.start() + # thread = Thread(target=scarches.query, args=(config,)) + # thread.start() + actual_config = scarches.query(config) return actual_config, 200 else: actual_configuration = scarches.query(config) From da89dede8075c52669e951111b481aa0f1daabcd Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Tue, 4 Jun 2024 12:26:44 +0200 Subject: [PATCH 28/63] update get keys function --- mapping/process/processing.py | 164 ++++++++++++++++++--------------- mapping/scarches_api/models.py | 3 +- 2 files changed, 92 insertions(+), 75 deletions(-) diff --git a/mapping/process/processing.py b/mapping/process/processing.py index 5539871..ff6b0d8 100644 --- a/mapping/process/processing.py +++ b/mapping/process/processing.py @@ -287,77 +287,72 @@ def bool_to_categorical(adata): if adata.var[col].dtype.name == "bool" or adata.var[col].dtype.name == "object": adata.var[col] = adata.var[col].astype("category") - def get_keys(atlas, target_adata): - """ - Sets the batch(condition) and cell_type keys, according to the atlas chosen. - This is necessary as the reference files all have these keys under different names, - although they contain the same information. - """ - #TODO: keys are stored on db, get rid of this hardcoding!!!! 
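The TODO above motivates the rewrite that follows: the label and batch columns a model was trained with can be read back from its saved registry instead of being hardcoded per atlas. A sketch under the assumption that scvi-tools' private loader keeps the signature this patch relies on:

from scvi.model.base import _utils

def keys_from_saved_model(model_path):
    # attr_dict is the first element returned by _load_saved_files, as in
    # get_keys; field_registries records the original obs column names.
    attr_dict = _utils._load_saved_files(model_path, False, None, "cpu")[0]
    fields = attr_dict["registry_"]["field_registries"]
    cell_type_key = fields["labels"]["state_registry"]["original_key"]
    batch_key = fields["batch"]["state_registry"]["original_key"]
    unlabeled_key = attr_dict.get("unlabeled_category_") or "unlabeled"
    return cell_type_key, batch_key, unlabeled_key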
- #Set unlabeled key to always be "unlabeled" - unlabeled_key = "unlabeled" - - - if atlas == 'pbmc': - cell_type_key = 'cell_type_for_integration' - batch_key = 'sample_ID_lataq' - elif atlas == 'heart': - cell_type_key = 'cell_type' - batch_key = 'donor' - elif atlas == 'hlca': - cell_type_key = 'scanvi_label' - batch_key = 'dataset' - elif atlas == 'hlca_retrained': - cell_type_key = 'ann_finest_level' - batch_key = 'sample' - elif atlas == 'retina': - cell_type_key = 'CellType' - batch_key = 'batch' - elif atlas == 'fetal_immune': - cell_type_key = 'celltype_annotation' - batch_key = 'bbk' - elif atlas == "nsclc": - cell_type_key = 'cell_type' - batch_key = 'sample' - elif atlas == "gb": - cell_type_key = 'CellID' - batch_key = 'author' - elif atlas == "hypomap": - cell_type_key = 'Author_CellType' - batch_key = 'Batch_ID' - elif atlas == "pancreas": - cell_type_key = "cell_type" - batch_key = "batch_integration" - elif atlas == "HRCA": - cell_type_key = "cell_type_scarches" - batch_key = "batch_donor_asset" - elif atlas == "hnoca": - cell_type_key = "snapseed_pca_rss_level_123" - batch_key = "batch" - elif atlas == "heoca": - cell_type_key = "cell_type" - batch_key = "sample_id" - elif atlas == "fetal_brain": - cell_type_key = "subregion_class" - batch_key = "batch" + # def get_keys(atlas, target_adata): + # """ + # Sets the batch(condition) and cell_type keys, according to the atlas chosen. + # This is necessary as the reference files all have these keys under different names, + # although they contain the same information. + # """ + # #TODO: keys are stored on db, get rid of this hardcoding!!!! + # #Set unlabeled key to always be "unlabeled" + # unlabeled_key = "unlabeled" + + + # if atlas == 'pbmc': + # cell_type_key = 'cell_type_for_integration' + # batch_key = 'sample_ID_lataq' + # elif atlas == 'heart': + # cell_type_key = 'cell_type' + # batch_key = 'donor' + # elif atlas == 'hlca': + # cell_type_key = 'scanvi_label' + # batch_key = 'dataset' + # elif atlas == 'hlca_retrained': + # cell_type_key = 'ann_finest_level' + # batch_key = 'sample' + # elif atlas == 'retina': + # cell_type_key = 'CellType' + # batch_key = 'batch' + # elif atlas == 'fetal_immune': + # cell_type_key = 'celltype_annotation' + # batch_key = 'bbk' + # elif atlas == "nsclc": + # cell_type_key = 'cell_type' + # batch_key = 'sample' + # elif atlas == "gb": + # cell_type_key = 'CellID' + # batch_key = 'author' + # elif atlas == "hypomap": + # cell_type_key = 'Author_CellType' + # batch_key = 'Batch_ID' + # elif atlas == "pancreas": + # cell_type_key = "cell_type" + # batch_key = "batch_integration" + # elif atlas == "HRCA": + # cell_type_key = "cell_type_scarches" + # batch_key = "batch_donor_asset" + # elif atlas == "hnoca": + # cell_type_key = "snapseed_pca_rss_level_123" + # batch_key = "batch" + # elif atlas == "heoca": + # cell_type_key = "cell_type" + # batch_key = "sample_id" + # elif atlas == "fetal_brain": + # cell_type_key = "subregion_class" + # batch_key = "batch" - #Check if provided query contains respective labels - if cell_type_key not in target_adata.obs.columns or batch_key not in target_adata.obs.columns: - raise ValueError("Please double check if cell_type and batch keys in query match the requirements stated on the website") + # #Check if provided query contains respective labels + # if cell_type_key not in target_adata.obs.columns or batch_key not in target_adata.obs.columns: + # raise ValueError("Please double check if cell_type and batch keys in query match the requirements stated on the 
website") - return cell_type_key, batch_key, unlabeled_key + # return cell_type_key, batch_key, unlabeled_key - def __get_keys_model(configuration): - #Get relative model path - model_path = "assets/" + utils.get_from_config(configuration, parameters.MODEL) + "/" + utils.get_from_config(configuration, parameters.ATLAS) + "/" - - #Get label names the model was set up with - attr_dict = _utils._load_saved_files(model_path, False, None, "cpu")[0] + def _get_keys(target_adata, configuration): # try: # attr_dict = _utils._load_saved_files(model_path, False, None, "cpu")[0] @@ -374,20 +369,41 @@ def __get_keys_model(configuration): condition_key_model = None unlabeled_key_model = None - if("registry_" not in attr_dict): - data_registry = attr_dict["scvi_setup_dict_"]["categorical_mappings"] + model_type = utils.get_from_config(configuration, parameters.MODEL) - cell_type_key_model = data_registry["_scvi_labels"]["original_key"] - condition_key_model = data_registry["_scvi_batch"]["original_key"] - else: - data_registry = attr_dict["registry_"]["field_registries"] + if model_type in ["scANVI","scVI"]: + + model_path = utils.get_from_config(configuration, parameters.PRETRAINED_MODEL_PATH) + attr_dict = _utils._load_saved_files(model_path, False, None, "cpu")[0] - cell_type_key_model = data_registry["labels"]["state_registry"]["original_key"] - condition_key_model = data_registry["batch"]["state_registry"]["original_key"] + if("registry_" not in attr_dict): + data_registry = attr_dict["scvi_setup_dict_"]["categorical_mappings"] - if "unlabeled_category_" in attr_dict: - if attr_dict["unlabeled_category_"] is not None: - unlabeled_key_model = attr_dict["unlabeled_category_"] + cell_type_key_model = data_registry["_scvi_labels"]["original_key"] + condition_key_model = data_registry["_scvi_batch"]["original_key"] + else: + data_registry = attr_dict["registry_"]["field_registries"] + + cell_type_key_model = data_registry["labels"]["state_registry"]["original_key"] + condition_key_model = data_registry["batch"]["state_registry"]["original_key"] + + if "unlabeled_category_" in attr_dict: + if attr_dict["unlabeled_category_"] is not None: + unlabeled_key_model = attr_dict["unlabeled_category_"] + + else: + unlabeled_key_model = "unlabeled" + + + else: + attr_dict = utils.get_from_config(configuration, parameters.SCPOLI_ATTR) + cell_type_key_model = attr_dict["cell_type_keys_"][-1] + condition_key_model = attr_dict["condition_keys_"][-1] + unlabeled_key_model = "unlabeled" + + #Check if provided query contains respective labels + if cell_type_key_model not in target_adata.obs.columns or condition_key_model not in target_adata.obs.columns: + raise ValueError("Please double check if cell_type and batch keys in query match the requirements stated on the website") return cell_type_key_model, condition_key_model, unlabeled_key_model diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index ed74587..2ba906e 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -65,7 +65,8 @@ def __init__(self, configuration) -> None: self._batch_key = None self._unlabeled_key = None - self._cell_type_key, self._batch_key, self._unlabeled_key = Preprocess.get_keys(self._atlas, self._query_adata) + # self._cell_type_key, self._batch_key, self._unlabeled_key = Preprocess.get_keys(self._atlas, self._query_adata) + self._cell_type_key, self._batch_key, self._unlabeled_key = Preprocess.get_keys(self._query_adata, configuration) self._clf_native = 
get_from_config(configuration=configuration, key=parameters.CLASSIFIER_TYPE).pop("Native") self._clf_xgb = get_from_config(configuration=configuration, key=parameters.CLASSIFIER_TYPE).pop("XGBoost") From fd0a356fbd7b82143ad2b11c9e8db80c9d79b9f0 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Tue, 4 Jun 2024 14:00:38 +0200 Subject: [PATCH 29/63] remove hardcoding of get_keys --- mapping/process/processing.py | 24 ++++++------------------ mapping/user_upload/user_upload.py | 2 +- 2 files changed, 7 insertions(+), 19 deletions(-) diff --git a/mapping/process/processing.py b/mapping/process/processing.py index ff6b0d8..4f20fac 100644 --- a/mapping/process/processing.py +++ b/mapping/process/processing.py @@ -352,15 +352,9 @@ def bool_to_categorical(adata): - def _get_keys(target_adata, configuration): + def get_keys(target_adata, configuration): - # try: - # attr_dict = _utils._load_saved_files(model_path, False, None, "cpu")[0] - # except: - # if utils.get_from_config(configuration, parameters.MODEL) == "scANVI": - # sca.models.SCANVI.convert_legacy_save(model_path, model_path, True) - # if utils.get_from_config(configuration, parameters.MODEL) == "scVI": - # sca.models.SCVI.convert_legacy_save(model_path, model_path, True) + #Get model data registry and labels #Data management can be different among models, no clear indication in docs @@ -373,19 +367,13 @@ def _get_keys(target_adata, configuration): if model_type in ["scANVI","scVI"]: - model_path = utils.get_from_config(configuration, parameters.PRETRAINED_MODEL_PATH) + model_path = "." attr_dict = _utils._load_saved_files(model_path, False, None, "cpu")[0] - if("registry_" not in attr_dict): - data_registry = attr_dict["scvi_setup_dict_"]["categorical_mappings"] - - cell_type_key_model = data_registry["_scvi_labels"]["original_key"] - condition_key_model = data_registry["_scvi_batch"]["original_key"] - else: - data_registry = attr_dict["registry_"]["field_registries"] + data_registry = attr_dict["registry_"] - cell_type_key_model = data_registry["labels"]["state_registry"]["original_key"] - condition_key_model = data_registry["batch"]["state_registry"]["original_key"] + cell_type_key_model = data_registry["field_registries"]["labels"]["state_registry"]["original_key"] + condition_key_model = data_registry["field_registries"]["batch"]["state_registry"]["original_key"] if "unlabeled_category_" in attr_dict: if attr_dict["unlabeled_category_"] is not None: diff --git a/mapping/user_upload/user_upload.py b/mapping/user_upload/user_upload.py index 75aabbf..240b9e6 100644 --- a/mapping/user_upload/user_upload.py +++ b/mapping/user_upload/user_upload.py @@ -69,7 +69,7 @@ def __check_model_registry(self): def __check_atlas_labels(self, local_model_path): tuple = os.path.split(local_model_path) - cell_type_key, batch_key, unlabeled_key = Preprocess.get_keys_model(tuple[0]) + cell_type_key, batch_key, unlabeled_key = Preprocess.get_keys(tuple[0]) #Cell type keys only relevant to scANVI models if self.__model_type == "scANVI": From db96cec738ea48316bc9d427503f0cd16cd7af0a Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Tue, 4 Jun 2024 17:03:33 +0200 Subject: [PATCH 30/63] fix unknown key issue --- mapping/process/processing.py | 62 +++++++++++++++++++++++++++++----- mapping/scarches_api/models.py | 7 ++-- 2 files changed, 58 insertions(+), 11 deletions(-) diff --git a/mapping/process/processing.py b/mapping/process/processing.py index 4f20fac..2a242dd 100644 --- a/mapping/process/processing.py +++ b/mapping/process/processing.py @@ -352,7 
+352,7 @@ def bool_to_categorical(adata): - def get_keys(target_adata, configuration): + def get_keys(atlas, target_adata, configuration): @@ -363,6 +363,50 @@ def get_keys(target_adata, configuration): condition_key_model = None unlabeled_key_model = None + if atlas == 'pbmc': + cell_type_key = 'cell_type_for_integration' + batch_key = 'sample_ID_lataq' + elif atlas == 'heart': + cell_type_key = 'cell_type' + batch_key = 'donor' + elif atlas == 'hlca': + cell_type_key = 'scanvi_label' + batch_key = 'dataset' + elif atlas == 'hlca_retrained': + cell_type_key = 'ann_finest_level' + batch_key = 'sample' + elif atlas == 'retina': + cell_type_key = 'CellType' + batch_key = 'batch' + elif atlas == 'fetal_immune': + cell_type_key = 'celltype_annotation' + batch_key = 'bbk' + elif atlas == "nsclc": + cell_type_key = 'cell_type' + batch_key = 'sample' + elif atlas == "gb": + cell_type_key = 'CellID' + batch_key = 'author' + elif atlas == "hypomap": + cell_type_key = 'Author_CellType' + batch_key = 'Batch_ID' + elif atlas == "pancreas": + cell_type_key = "cell_type" + batch_key = "batch_integration" + elif atlas == "HRCA": + cell_type_key = "cell_type_scarches" + batch_key = "batch_donor_asset" + elif atlas == "hnoca": + cell_type_key = "snapseed_pca_rss_level_123" + batch_key = "batch" + elif atlas == "heoca": + cell_type_key = "cell_type" + batch_key = "sample_id" + elif atlas == "fetal_brain": + cell_type_key = "subregion_class" + batch_key = "batch" + + model_type = utils.get_from_config(configuration, parameters.MODEL) if model_type in ["scANVI","scVI"]: @@ -370,10 +414,10 @@ def get_keys(target_adata, configuration): model_path = "." attr_dict = _utils._load_saved_files(model_path, False, None, "cpu")[0] - data_registry = attr_dict["registry_"] + # data_registry = attr_dict["registry_"] - cell_type_key_model = data_registry["field_registries"]["labels"]["state_registry"]["original_key"] - condition_key_model = data_registry["field_registries"]["batch"]["state_registry"]["original_key"] + # cell_type_key_model = data_registry["field_registries"]["labels"]["state_registry"]["original_key"] + # condition_key_model = data_registry["field_registries"]["batch"]["state_registry"]["original_key"] if "unlabeled_category_" in attr_dict: if attr_dict["unlabeled_category_"] is not None: @@ -384,16 +428,16 @@ def get_keys(target_adata, configuration): else: - attr_dict = utils.get_from_config(configuration, parameters.SCPOLI_ATTR) - cell_type_key_model = attr_dict["cell_type_keys_"][-1] - condition_key_model = attr_dict["condition_keys_"][-1] + # attr_dict = utils.get_from_config(configuration, parameters.SCPOLI_ATTR) + # cell_type_key_model = attr_dict["cell_type_keys_"][-1] + # condition_key_model = attr_dict["condition_keys_"][-1] unlabeled_key_model = "unlabeled" #Check if provided query contains respective labels - if cell_type_key_model not in target_adata.obs.columns or condition_key_model not in target_adata.obs.columns: + if cell_type_key not in target_adata.obs.columns or batch_key not in target_adata.obs.columns: raise ValueError("Please double check if cell_type and batch keys in query match the requirements stated on the website") - return cell_type_key_model, condition_key_model, unlabeled_key_model + return cell_type_key, batch_key, unlabeled_key_model def __get_keys_user(configuration): #Get parameters from user input diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index 49845b1..e3719f0 100644 --- a/mapping/scarches_api/models.py +++ 
b/mapping/scarches_api/models.py @@ -67,7 +67,7 @@ def __init__(self, configuration) -> None: self._unlabeled_key = None # self._cell_type_key, self._batch_key, self._unlabeled_key = Preprocess.get_keys(self._atlas, self._query_adata) - self._cell_type_key, self._batch_key, self._unlabeled_key = Preprocess.get_keys(self._query_adata, configuration) + self._cell_type_key, self._batch_key, self._unlabeled_key = Preprocess.get_keys(self._atlas, self._query_adata, configuration) self._clf_native = get_from_config(configuration=configuration, key=parameters.CLASSIFIER_TYPE).pop("Native") self._clf_xgb = get_from_config(configuration=configuration, key=parameters.CLASSIFIER_TYPE).pop("XGBoost") @@ -463,7 +463,7 @@ def _cleanup(self): class ScVI(ArchmapBaseModel): def _map_query(self): - #Align genes and gene order to model + #Align genes and gene order to model scarches.models.SCVI.prepare_query_anndata(self._query_adata, self._temp_model_path) #Setup adata internals for mapping @@ -496,6 +496,9 @@ def _compute_latent_representation(self, explicit_representation): class ScANVI(ArchmapBaseModel): def _map_query(self, supervised=False): #Align genes and gene order to model + if self._cell_type_key in self._query_adata.obs.columns: + self._query_adata.obs[f"{self._cell_type_key}_user_input"] = self._query_adata.obs[self._cell_type_key] + self._query_adata.obs[self._cell_type_key] = [self._unlabeled_key]*len(self._query_adata) scarches.models.SCANVI.prepare_query_anndata(self._query_adata, self._temp_model_path) #Setup adata internals for mapping From 5f5715cf65f8ab3b3566a554d9ad906161784789 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Thu, 6 Jun 2024 10:51:38 +0200 Subject: [PATCH 31/63] change lr fetal brain --- mapping/scarches_api/models.py | 48 ++++++++++++++++++---------------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index e3719f0..3ceb823 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -91,9 +91,13 @@ def _map_query(self): start_time = time.time() # threshold = 10000 + if self._atlas == "fetal_brain": + lr=0.01 + else: + lr=0.001 self._model.train( max_epochs=self._max_epochs, - plan_kwargs=dict(weight_decay=0.0), + plan_kwargs=dict(weight_decay=0.0,lr=lr), check_val_every_n_epoch=10, # datasplitter_kwargs = dict(distributed_sampler = True), # strategy='ddp_find_unused_parameters_true', @@ -117,17 +121,17 @@ def _map_query(self): # Calculate presence score - presence_score=estimate_presence_score( - self._reference_adata, - self._query_adata) + # presence_score=estimate_presence_score( + # self._reference_adata, + # self._query_adata) - self.presence_score = np.concatenate((presence_score["max"],[np.nan]*len(self._query_adata))) + # self.presence_score = np.concatenate((presence_score["max"],[np.nan]*len(self._query_adata))) - self.clust_pres_score=cluster_preservation_score(self._query_adata) - print(f"clust_pres_score: {self.clust_pres_score}") + # self.clust_pres_score=cluster_preservation_score(self._query_adata) + # print(f"clust_pres_score: {self.clust_pres_score}") - self.query_with_anchor=percent_query_with_anchor(self._reference_adata, self._query_adata) - print(f"query_with_anchor: {self.query_with_anchor}") + # self.query_with_anchor=percent_query_with_anchor(self._reference_adata, self._query_adata) + # print(f"query_with_anchor: {self.query_with_anchor}") def _acquire_data(self): @@ -152,7 +156,7 @@ def _acquire_data(self): ratio = inter_len / 
len(ref_vars) print(ratio) - utils.notify_backend(self._webhook, {"ratio":ratio}) + # utils.notify_backend(self._webhook, {"ratio":ratio}) # save only necessary data for mapping to new adata @@ -213,7 +217,7 @@ def _transfer_labels(self): percent_unknown = clf.predict_labels(self._query_adata, query_latent, self._temp_clf_model_path, self._temp_clf_encoding_path) - utils.notify_backend(self._webhook_metrics, {"clust_pres_score":self.clust_pres_score, "query_with_anchor":self.query_with_anchor, "percentage_unknown": percent_unknown}) + # utils.notify_backend(self._webhook_metrics, {"clust_pres_score":self.clust_pres_score, "query_with_anchor":self.query_with_anchor, "percentage_unknown": percent_unknown}) def _concat_data(self): @@ -250,7 +254,7 @@ def _concat_data(self): self._combined_adata.obs=self._combined_adata.obs[list(new_columns)] self._combined_adata.obsm["latent_rep"] = self.latent_full_from_mean_var - self._combined_adata.obs["presence_score"] = self.presence_score + # self._combined_adata.obs["presence_score"] = self.presence_score @@ -301,7 +305,7 @@ def _concat_data(self): print("read concatenated file") self._combined_adata.obsm["latent_rep"] = self.latent_full_from_mean_var - self._combined_adata.obs["presence_score"] = self.presence_score + # self._combined_adata.obs["presence_score"] = self.presence_score self._combined_adata.obs_names_make_unique() @@ -331,7 +335,7 @@ def add_X_from_cloud(self): if True or get_from_config(self._configuration, parameters.WEBHOOK) is not None and len( get_from_config(self._configuration, parameters.WEBHOOK)) > 0: - utils.notify_backend(get_from_config(self._configuration, parameters.WEBHOOK), self._configuration) + # utils.notify_backend(get_from_config(self._configuration, parameters.WEBHOOK), self._configuration) if not self._reference_adata_path.endswith("data.h5ad"): raise ValueError("The reference data should be named data.h5ad") else: @@ -588,17 +592,17 @@ def _map_query(self): - presence_score = estimate_presence_score( - self._reference_adata, - self._query_adata) + # presence_score = estimate_presence_score( + # self._reference_adata, + # self._query_adata) - self.presence_score = np.concatenate((presence_score["max"],[np.nan]*len(self._query_adata))) + # self.presence_score = np.concatenate((presence_score["max"],[np.nan]*len(self._query_adata))) - self.clust_pres_score=cluster_preservation_score(self._query_adata) - print(f"clust_pres_score: {self.clust_pres_score}") + # self.clust_pres_score=cluster_preservation_score(self._query_adata) + # print(f"clust_pres_score: {self.clust_pres_score}") - self.query_with_anchor=percent_query_with_anchor(self._reference_adata, self._query_adata) - print(f"query_with_anchor: {self.query_with_anchor}") + # self.query_with_anchor=percent_query_with_anchor(self._reference_adata, self._query_adata) + # print(f"query_with_anchor: {self.query_with_anchor}") From 4d6d209f369f610a824602ee2816ad17fbd00cfe Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Thu, 6 Jun 2024 10:53:13 +0200 Subject: [PATCH 32/63] uncomment --- mapping/scarches_api/models.py | 42 +++++++++++++++++----------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index 3ceb823..979924a 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -121,17 +121,17 @@ def _map_query(self): # Calculate presence score - # presence_score=estimate_presence_score( - # self._reference_adata, - # self._query_adata) + 
presence_score=estimate_presence_score( + self._reference_adata, + self._query_adata) - # self.presence_score = np.concatenate((presence_score["max"],[np.nan]*len(self._query_adata))) + self.presence_score = np.concatenate((presence_score["max"],[np.nan]*len(self._query_adata))) - # self.clust_pres_score=cluster_preservation_score(self._query_adata) - # print(f"clust_pres_score: {self.clust_pres_score}") + self.clust_pres_score=cluster_preservation_score(self._query_adata) + print(f"clust_pres_score: {self.clust_pres_score}") - # self.query_with_anchor=percent_query_with_anchor(self._reference_adata, self._query_adata) - # print(f"query_with_anchor: {self.query_with_anchor}") + self.query_with_anchor=percent_query_with_anchor(self._reference_adata, self._query_adata) + print(f"query_with_anchor: {self.query_with_anchor}") def _acquire_data(self): @@ -156,7 +156,7 @@ def _acquire_data(self): ratio = inter_len / len(ref_vars) print(ratio) - # utils.notify_backend(self._webhook, {"ratio":ratio}) + utils.notify_backend(self._webhook, {"ratio":ratio}) # save only necessary data for mapping to new adata @@ -217,7 +217,7 @@ def _transfer_labels(self): percent_unknown = clf.predict_labels(self._query_adata, query_latent, self._temp_clf_model_path, self._temp_clf_encoding_path) - # utils.notify_backend(self._webhook_metrics, {"clust_pres_score":self.clust_pres_score, "query_with_anchor":self.query_with_anchor, "percentage_unknown": percent_unknown}) + utils.notify_backend(self._webhook_metrics, {"clust_pres_score":self.clust_pres_score, "query_with_anchor":self.query_with_anchor, "percentage_unknown": percent_unknown}) def _concat_data(self): @@ -254,7 +254,7 @@ def _concat_data(self): self._combined_adata.obs=self._combined_adata.obs[list(new_columns)] self._combined_adata.obsm["latent_rep"] = self.latent_full_from_mean_var - # self._combined_adata.obs["presence_score"] = self.presence_score + self._combined_adata.obs["presence_score"] = self.presence_score @@ -305,7 +305,7 @@ def _concat_data(self): print("read concatenated file") self._combined_adata.obsm["latent_rep"] = self.latent_full_from_mean_var - # self._combined_adata.obs["presence_score"] = self.presence_score + self._combined_adata.obs["presence_score"] = self.presence_score self._combined_adata.obs_names_make_unique() @@ -335,7 +335,7 @@ def add_X_from_cloud(self): if True or get_from_config(self._configuration, parameters.WEBHOOK) is not None and len( get_from_config(self._configuration, parameters.WEBHOOK)) > 0: - # utils.notify_backend(get_from_config(self._configuration, parameters.WEBHOOK), self._configuration) + utils.notify_backend(get_from_config(self._configuration, parameters.WEBHOOK), self._configuration) if not self._reference_adata_path.endswith("data.h5ad"): raise ValueError("The reference data should be named data.h5ad") else: @@ -592,17 +592,17 @@ def _map_query(self): - # presence_score = estimate_presence_score( - # self._reference_adata, - # self._query_adata) + presence_score = estimate_presence_score( + self._reference_adata, + self._query_adata) - # self.presence_score = np.concatenate((presence_score["max"],[np.nan]*len(self._query_adata))) + self.presence_score = np.concatenate((presence_score["max"],[np.nan]*len(self._query_adata))) - # self.clust_pres_score=cluster_preservation_score(self._query_adata) - # print(f"clust_pres_score: {self.clust_pres_score}") + self.clust_pres_score=cluster_preservation_score(self._query_adata) + print(f"clust_pres_score: {self.clust_pres_score}") - # 
self.query_with_anchor=percent_query_with_anchor(self._reference_adata, self._query_adata) - # print(f"query_with_anchor: {self.query_with_anchor}") + self.query_with_anchor=percent_query_with_anchor(self._reference_adata, self._query_adata) + print(f"query_with_anchor: {self.query_with_anchor}") From 0fe1fcb64e0bd3b8d8f6315079b44a044d56cc7d Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Sun, 9 Jun 2024 18:02:35 +0200 Subject: [PATCH 33/63] increase lr for fetal brain --- mapping/scarches_api/models.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index 979924a..f016a52 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -92,9 +92,11 @@ def _map_query(self): # threshold = 10000 if self._atlas == "fetal_brain": - lr=0.01 + lr=0.1 else: lr=0.001 + + print(f"lr: {lr}") self._model.train( max_epochs=self._max_epochs, plan_kwargs=dict(weight_decay=0.0,lr=lr), @@ -326,6 +328,8 @@ def _save_data(self): print("adding X from cloud") self.add_X_from_cloud() + print("add all genes") + combined_downsample = self.downsample_adata() #Save output From de6d6d424a4aa0fba0dc82e909fd84562fac1a44 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Mon, 17 Jun 2024 15:58:24 +0200 Subject: [PATCH 34/63] speed up metric calculations --- mapping/classifiers/classify.py | 80 +++++++++++++------- mapping/scarches_api/models.py | 80 ++++++++++---------- mapping/scarches_api/uncert/uncert_metric.py | 3 +- mapping/scarches_api/utils/metrics.py | 57 +++++--------- 4 files changed, 110 insertions(+), 110 deletions(-) diff --git a/mapping/classifiers/classify.py b/mapping/classifiers/classify.py index ee348df..4dbc91f 100644 --- a/mapping/classifiers/classify.py +++ b/mapping/classifiers/classify.py @@ -8,41 +8,65 @@ def main(atlas, label): adata=sc.read(f"data/{atlas}.h5ad") - reference_latent=sc.AnnData(adata.obsm["latent_rep"], adata.obs) + reference_latent=sc.AnnData(adata.obsm["X_latent_qzm"], adata.obs) - #create knn classifier - clf = Classifiers(False, True, None) - clf.create_classifier(reference_latent, False, "", label, f"models_new/{atlas}_{label}") + if isinstance(label, list): + for l in label: + #create knn classifier + clf = Classifiers(False, True, None) + clf.create_classifier(reference_latent, True, "", label, f"models_new/{atlas}/{atlas}_{l}") - #create xgb classifier - clf = Classifiers(True, False, None) - clf.create_classifier(reference_latent, False, "", label, f"models_new/{atlas}_{label}") + #create xgb classifier + clf = Classifiers(True, False, None) + clf.create_classifier(reference_latent, True, "", label, f"models_new/{atlas}/{atlas}_{l}") + + + else: + #create knn classifier + clf = Classifiers(False, True, None) + clf.create_classifier(reference_latent, True, "", label,f"models_new/{atlas}/{atlas}_{label}") + + #create xgb classifier + clf = Classifiers(True, False, None) + clf.create_classifier(reference_latent, True, "", label, f"models_new/{atlas}/{atlas}_{label}") if __name__ == "__main__": - atlas_dict = {"hlca": "ann_level_4", - "hlca_retrained": "ann_finest_level", - "gb": "CellID", - "fetal_immune":"celltype_annotation", - "nsclc": "cell_type", - "retina": "CellType", - "pancreas_scpoli":"cell_type", - "pancreas_scanvi":"cell_type", - "pbmc":"cell_type_for_integration", - "hypomap": "Author_CellType", - } - atlas_dict_scpoli = { - "hlca_retrained": "ann_finest_level", - "pancreas_scpoli":"cell_type", - "pbmc":"cell_type_for_integration", - # 
"hnoca":"annot_level_2", - # "heoca": "cell_type" - } + # atlas_dict = {"fetal_brain": "subregion_class"} + + atlas_dict = {"hnoca": + ['annot_level_1', + 'annot_level_2', + 'annot_level_3_rev2', + 'annot_level_4_rev2', + 'annot_region_rev2', + 'annot_ntt_rev2',], + + "hnoca_ce": 'annot_level_2_ce'} + + # atlas_dict = {"hlca": "ann_level_4", + # "hlca_retrained": "ann_finest_level", + # "gb": "CellID", + # "fetal_immune":"celltype_annotation", + # "nsclc": "cell_type", + # "retina": "CellType", + # "pancreas_scpoli":"cell_type", + # "pancreas_scanvi":"cell_type", + # "pbmc":"cell_type_for_integration", + # "hypomap": "Author_CellType", + # } + # atlas_dict_scpoli = { + # "hlca_retrained": "ann_finest_level", + # "pancreas_scpoli":"cell_type", + # "pbmc":"cell_type_for_integration", + # # "hnoca":"annot_level_2", + # # "heoca": "cell_type" + # } for atlas, label in atlas_dict.items(): main(atlas, label) print(f"successfully created classifier for {atlas}") - for atlas, label in atlas_dict_scpoli.items(): - main(atlas, label, is_scpoli=True) - print(f"successfully created classifier for {atlas}") \ No newline at end of file + # for atlas, label in atlas_dict_scpoli.items(): + # main(atlas, label, is_scpoli=True) + # print(f"successfully created classifier for {atlas}") \ No newline at end of file diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index f016a52..1f31aae 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -16,7 +16,7 @@ from scvi.dataloaders import BatchDistributedSampler from utils import parameters -from utils.metrics import estimate_presence_score, cluster_preservation_score, build_mutual_nn, percent_query_with_anchor, stress_score +from utils.metrics import estimate_presence_score, cluster_preservation_score, percent_query_with_anchor, stress_score, get_wknn from utils.utils import get_from_config from utils.utils import fetch_file_from_s3 from utils.utils import read_h5ad_file_from_s3, get_file_size_in_gb, replace_X_on_disk @@ -79,16 +79,18 @@ def __init__(self, configuration) -> None: print(f"time {end_time-start_time}") def run(self): + start_time = time.time() self._map_query() self._eval_mapping() self._transfer_labels() self._concat_data() self._save_data() + end_time = time.time() + print(f"time {end_time-start_time}") self._cleanup() def _map_query(self): #Map the query onto reference - start_time = time.time() # threshold = 10000 if self._atlas == "fetal_brain": @@ -96,7 +98,6 @@ def _map_query(self): else: lr=0.001 - print(f"lr: {lr}") self._model.train( max_epochs=self._max_epochs, plan_kwargs=dict(weight_decay=0.0,lr=lr), @@ -107,9 +108,6 @@ def _map_query(self): # devices=4 ) - end_time = time.time() - print(f"time {end_time-start_time}") - if "X_latent_qzm" in self._reference_adata.obsm and "X_latent_qzv" in self._reference_adata.obsm: print("__________getting X_latent_qzm from minified atlas for scvi-tools models___________") qzm = self._reference_adata.obsm["X_latent_qzm"] @@ -121,19 +119,6 @@ def _map_query(self): #Save out the latent representation for QUERY self._compute_latent_representation(explicit_representation=self._query_adata) - # Calculate presence score - - presence_score=estimate_presence_score( - self._reference_adata, - self._query_adata) - - self.presence_score = np.concatenate((presence_score["max"],[np.nan]*len(self._query_adata))) - - self.clust_pres_score=cluster_preservation_score(self._query_adata) - print(f"clust_pres_score: {self.clust_pres_score}") - - 
self.query_with_anchor=percent_query_with_anchor(self._reference_adata, self._query_adata) - print(f"query_with_anchor: {self.query_with_anchor}") def _acquire_data(self): @@ -180,7 +165,7 @@ def _eval_mapping(self): reference_latent.obs = self._reference_adata.obs #Calculate mapping uncertainty and write into .obs - classification_uncert_euclidean(self._configuration, reference_latent, query_latent, self._query_adata, "X", self._cell_type_key, False) + self.knn_ref_trainer= classification_uncert_euclidean(self._configuration, reference_latent, query_latent, self._query_adata, "X", self._cell_type_key, False) classification_uncert_mahalanobis(self._configuration, reference_latent, query_latent, self._query_adata, "X", self._cell_type_key, False) #stress score @@ -256,10 +241,6 @@ def _concat_data(self): self._combined_adata.obs=self._combined_adata.obs[list(new_columns)] self._combined_adata.obsm["latent_rep"] = self.latent_full_from_mean_var - self._combined_adata.obs["presence_score"] = self.presence_score - - - del self._query_adata del self._reference_adata gc.collect() @@ -307,7 +288,6 @@ def _concat_data(self): print("read concatenated file") self._combined_adata.obsm["latent_rep"] = self.latent_full_from_mean_var - self._combined_adata.obs["presence_score"] = self.presence_score self._combined_adata.obs_names_make_unique() @@ -328,9 +308,43 @@ def _save_data(self): print("adding X from cloud") self.add_X_from_cloud() - print("add all genes") combined_downsample = self.downsample_adata() + + # Calculate presence score + + ref_downsample = combined_downsample[combined_downsample.obs["query"]=="0"] + query_downsample = combined_downsample[combined_downsample.obs["query"]=="1"] + + ref_latent_downsample = combined_downsample[combined_downsample.obs["query"]=="0"].obsm["latent_rep"] + query_latent_downsample = combined_downsample[combined_downsample.obs["query"]=="1"].obsm["latent_rep"] + + self.knn_ref = self.knn_ref_trainer.fit_transform(ref_latent_downsample) + + wknn, adjs = get_wknn( + ref=ref_latent_downsample, + query=query_latent_downsample, + k=15, + # adj_q2r=self.knn_q2r, + adj_ref=self.knn_ref, + return_adjs=True + ) + + presence_score = estimate_presence_score( + ref_downsample, + query_downsample, + wknn = wknn) + + + self.presence_score = np.concatenate((presence_score["max"],[np.nan]*len(query_downsample))) + + combined_downsample.obs["presence_score"] = self.presence_score + + self.clust_pres_score=cluster_preservation_score(query_downsample) + print(f"clust_pres_score: {self.clust_pres_score}") + + self.query_with_anchor=percent_query_with_anchor(adjs["r2q"], adjs["q2r"]) + print(f"query_with_anchor: {self.query_with_anchor}") #Save output Postprocess.output(None, combined_downsample, self._configuration) @@ -593,20 +607,6 @@ def _map_query(self): else: self._compute_latent_representation(explicit_representation=self._reference_adata) self._compute_latent_representation(explicit_representation=self._query_adata) - - - - presence_score = estimate_presence_score( - self._reference_adata, - self._query_adata) - - self.presence_score = np.concatenate((presence_score["max"],[np.nan]*len(self._query_adata))) - - self.clust_pres_score=cluster_preservation_score(self._query_adata) - print(f"clust_pres_score: {self.clust_pres_score}") - - self.query_with_anchor=percent_query_with_anchor(self._reference_adata, self._query_adata) - print(f"query_with_anchor: {self.query_with_anchor}") diff --git a/mapping/scarches_api/uncert/uncert_metric.py 
b/mapping/scarches_api/uncert/uncert_metric.py index 7b9ec5b..6ee3b99 100644 --- a/mapping/scarches_api/uncert/uncert_metric.py +++ b/mapping/scarches_api/uncert/uncert_metric.py @@ -153,7 +153,6 @@ def classification_uncert_euclidean( cell_type_key, trainer ) - #Important to store as numpy array in obs for cellbygene visualization if(len(uncertainties.columns) > 1): for entry in uncertainties.columns: @@ -162,7 +161,7 @@ def classification_uncert_euclidean( else: adata_query_raw.obs[cell_type_key + '_uncertainty_euclidean'] = uncertainties.to_numpy(dtype="float32") - return uncertainties + return trainer # # Test differential abundance analysis on neighbourhoods with Milo. # def classification_uncert_milo( diff --git a/mapping/scarches_api/utils/metrics.py b/mapping/scarches_api/utils/metrics.py index 95678a0..3fb9d5c 100644 --- a/mapping/scarches_api/utils/metrics.py +++ b/mapping/scarches_api/utils/metrics.py @@ -77,12 +77,10 @@ def build_mutual_nn(dat1, dat2=None, k1=15, k2=None): adj_21 = build_nn(dat2, dat1, k=k1) adj_mnn = adj_12.multiply(adj_21.T) - return adj_mnn + return adj_mnn, adj_12, adj_21 -def percent_query_with_anchor(ref_adata, query_adata): - ref = ref_adata.obsm["latent_rep"] - query = query_adata.obsm["latent_rep"] - adj_mnn=build_mutual_nn(ref,query) +def percent_query_with_anchor( adj_r2q, adj_q2r): + adj_mnn = adj_r2q.multiply(adj_q2r.T) has_anchor=adj_mnn.sum(0)>0 #all query cells that have an anchor (output dim: no query cells) percentage = (has_anchor.sum()/adj_mnn.shape[1])*100 return round(percentage, 2) @@ -109,10 +107,9 @@ def random_walk_with_restart(init, transition_prob, alpha=0.5, num_rounds=100): def get_wknn( ref, query, - ref2=None, k: int = 100, - query2ref: bool = True, - ref2query: bool = True, + adj_q2r=None, + adj_ref=None, weighting_scheme: Literal[ "n", "top_n", "jaccard", "jaccard_square", "gaussian", "dist" ] = "jaccard_square", @@ -143,28 +140,17 @@ def get_wknn( return_adjs : bool Whether to return the adjacency matrices of ref-query, query-ref, and ref-ref for weighting """ - adj_q2r = build_nn(ref=ref, query=query, k=k) - - adj_r2q = None - if ref2query: - adj_r2q = build_nn(ref=query, query=ref, k=k) - - if query2ref and not ref2query: - adj_knn = adj_q2r.T - elif ref2query and not query2ref: - adj_knn = adj_r2q - elif ref2query and query2ref: - adj_knn = ((adj_r2q + adj_q2r.T) > 0) + 0 - else: - warnings.warn( - "At least one of query2ref and ref2query should be True. Reset to default with both being True." - ) - adj_knn = ((adj_r2q + adj_q2r.T) > 0) + 0 # 1 if either R_i or Q_j are considered a nn of the other + if adj_q2r is None: + adj_q2r = build_nn(ref=ref, query=query, k=k) + + adj_r2q = build_nn(ref=query, query=ref, k=k) - if ref2 is None: - ref2 = ref - adj_ref = build_nn(ref=ref2, k=k) + adj_knn = ((adj_r2q + adj_q2r.T) > 0) + 0 + adj_knn = ((adj_r2q + adj_q2r.T) > 0) + 0 # 1 if either R_i or Q_j are considered a nn of the other + + if adj_ref is None: + adj_ref = build_nn(ref=ref, k=k) num_shared_neighbors = adj_q2r @ adj_ref # no. 
neighbours that Q_i and R_j have in common
     num_shared_neighbors_nn = num_shared_neighbors.multiply(adj_knn.T) # only keep weights if q and r are both nearest neighbours of each other
@@ -182,7 +168,8 @@ def get_wknn(
         wknn.data = (wknn.data / (k + k - wknn.data)) ** 2
 
     if return_adjs:
-        adjs = {"q2r": adj_q2r, "r2q": adj_r2q, "knn": adj_knn, "r2r": adj_ref}
+        adjs = {"q2r": adj_q2r, "r2q": adj_r2q}
+        # adjs = {"q2r": adj_q2r, "r2q": adj_r2q, "knn": adj_knn, "r2r": adj_ref}
         return (wknn, adjs)
     else:
         return wknn
@@ -209,17 +196,7 @@ def estimate_presence_score(
     num_rounds_random_walk=100,
     log=True,
 ):
-    if wknn is None:
-        ref = ref_adata.obsm[use_rep_ref_wknn]
-        query = query_adata.obsm[use_rep_query_wknn]
-        wknn = get_wknn(
-            ref=ref,
-            query=query,
-            k=k_wknn,
-            query2ref=query2ref_wknn,
-            ref2query=ref2query_wknn,
-            weighting_scheme=weighting_scheme_wknn,
-        )
+    
 
     if ref_trans_prop is None and do_random_walk:
         if use_rep_ref_trans_prop is None:

From 156803c7ecf60432bc1af09ac0f0e15adfc31dde Mon Sep 17 00:00:00 2001
From: ChelseaBright96
Date: Mon, 17 Jun 2024 18:39:53 +0200
Subject: [PATCH 35/63] change max_epochs for fetal brain

---
 mapping/scarches_api/models.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py
index 1f31aae..e3961fc 100644
--- a/mapping/scarches_api/models.py
+++ b/mapping/scarches_api/models.py
@@ -95,6 +95,7 @@ def _map_query(self):
         # threshold = 10000
         if self._atlas == "fetal_brain":
             lr=0.1
+            self._max_epochs = 20
         else:
             lr=0.001
 

From ef155011e36b09ac420748b1a115d799020f88e7 Mon Sep 17 00:00:00 2001
From: ChelseaBright96
Date: Tue, 18 Jun 2024 10:24:23 +0200
Subject: [PATCH 36/63] fix output error

---
 mapping/scarches_api/uncert/uncert_metric.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mapping/scarches_api/uncert/uncert_metric.py b/mapping/scarches_api/uncert/uncert_metric.py
index 6ee3b99..896cbaa 100644
--- a/mapping/scarches_api/uncert/uncert_metric.py
+++ b/mapping/scarches_api/uncert/uncert_metric.py
@@ -146,7 +146,7 @@ def classification_uncert_euclidean(
     for col in cell_type_cols:
         adata_ref_latent.obs[col] = adata_ref_latent.obs[col].astype(str)
 
-    _, uncertainties = sca.utils.weighted_knn_transfer(
+    _, uncertainties, _ = sca.utils.weighted_knn_transfer(
         adata_query_latent,
         embedding_name,
         adata_ref_latent.obs,

From f6669aeff8a937b802dc96460c1cca6b61e7b62e Mon Sep 17 00:00:00 2001
From: ChelseaBright96
Date: Tue, 18 Jun 2024 12:28:26 +0200
Subject: [PATCH 37/63] move notify backend metrics

---
 mapping/scarches_api/models.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py
index e3961fc..e4569a5 100644
--- a/mapping/scarches_api/models.py
+++ b/mapping/scarches_api/models.py
@@ -203,9 +203,7 @@ def _transfer_labels(self):
         #Compute label transfer and save to respective .obs
         query_latent = scanpy.AnnData(self._query_adata.obsm["latent_rep"])
 
-        percent_unknown = clf.predict_labels(self._query_adata, query_latent, self._temp_clf_model_path, self._temp_clf_encoding_path)
-
-        utils.notify_backend(self._webhook_metrics, {"clust_pres_score":self.clust_pres_score, "query_with_anchor":self.query_with_anchor, "percentage_unknown": percent_unknown})
+        self.percent_unknown = clf.predict_labels(self._query_adata, query_latent, self._temp_clf_model_path, self._temp_clf_encoding_path)
 
 
     def _concat_data(self):
@@ -346,6 +344,8 @@ def _save_data(self):
 
         self.query_with_anchor=percent_query_with_anchor(adjs["r2q"], adjs["q2r"])
         print(f"query_with_anchor: {self.query_with_anchor}")
+        
+        utils.notify_backend(self._webhook_metrics, {"clust_pres_score":self.clust_pres_score, "query_with_anchor":self.query_with_anchor, "percentage_unknown": self.percent_unknown})
 
         #Save output
         Postprocess.output(None, combined_downsample, self._configuration)
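For reference, after the move above the three mapping-quality metrics travel in a single webhook
payload once _save_data() has computed them. A minimal sketch of that call (the values are
illustrative only; utils.notify_backend is assumed to POST the dict as JSON to the configured
webhook):

    # Sketch, not part of the patch: shape of the metrics payload.
    metrics_payload = {
        "clust_pres_score": 0.87,      # cluster preservation score of the mapped query
        "query_with_anchor": 92.5,     # percent of query cells with a mutual-NN anchor in the reference
        "percentage_unknown": 3.1,     # percent of query cells the classifier labels as unknown
    }
    utils.notify_backend(self._webhook_metrics, metrics_payload)
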
adjs["q2r"]) print(f"query_with_anchor: {self.query_with_anchor}") + + utils.notify_backend(self._webhook_metrics, {"clust_pres_score":self.clust_pres_score, "query_with_anchor":self.query_with_anchor, "percentage_unknown": self.percent_unknown}) #Save output Postprocess.output(None, combined_downsample, self._configuration) From 4451c67a6ad970844e6ad7f800239ada33fafdd7 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Wed, 19 Jun 2024 12:53:05 +0200 Subject: [PATCH 38/63] fix bug percent unknown --- mapping/scarches_api/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index e4569a5..94d1ad3 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -57,6 +57,7 @@ def __init__(self, configuration) -> None: self._query_adata = None self._reference_adata = None self._combined_adata = None + self.percent_unknown = "n/a" #Load and process required data self._acquire_data() From 7800d39a00272687f4c18be2bf0af71212af7ec3 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Thu, 20 Jun 2024 11:50:47 +0200 Subject: [PATCH 39/63] increase fetal brain epochs --- mapping/scarches_api/models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index 94d1ad3..1df92c0 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -68,7 +68,7 @@ def __init__(self, configuration) -> None: self._unlabeled_key = None # self._cell_type_key, self._batch_key, self._unlabeled_key = Preprocess.get_keys(self._atlas, self._query_adata) - self._cell_type_key, self._batch_key, self._unlabeled_key = Preprocess.get_keys(self._atlas, self._query_adata, configuration) + self._cell_type_key, self._cell_type_key_list, self._batch_key, self._unlabeled_key = Preprocess.get_keys(self._atlas, self._query_adata, configuration) self._clf_native = get_from_config(configuration=configuration, key=parameters.CLASSIFIER_TYPE).pop("Native") self._clf_xgb = get_from_config(configuration=configuration, key=parameters.CLASSIFIER_TYPE).pop("XGBoost") @@ -96,7 +96,7 @@ def _map_query(self): # threshold = 10000 if self._atlas == "fetal_brain": lr=0.1 - self._max_epochs = 20 + self._max_epochs = 40 else: lr=0.001 @@ -167,7 +167,7 @@ def _eval_mapping(self): reference_latent.obs = self._reference_adata.obs #Calculate mapping uncertainty and write into .obs - self.knn_ref_trainer= classification_uncert_euclidean(self._configuration, reference_latent, query_latent, self._query_adata, "X", self._cell_type_key, False) + self.knn_ref_trainer= classification_uncert_euclidean(self._configuration, reference_latent, query_latent, self._query_adata, "X", self._cell_type_key, self._cell_type_key_list, False) classification_uncert_mahalanobis(self._configuration, reference_latent, query_latent, self._query_adata, "X", self._cell_type_key, False) #stress score From aeab4e47c00c87f43ce7050a8aae12e1129a463a Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Thu, 20 Jun 2024 13:28:21 +0200 Subject: [PATCH 40/63] add cell_type_key_list --- mapping/process/processing.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/mapping/process/processing.py b/mapping/process/processing.py index 2a242dd..6d57d67 100644 --- a/mapping/process/processing.py +++ b/mapping/process/processing.py @@ -362,6 +362,7 @@ def get_keys(atlas, target_adata, configuration): cell_type_key_model = None condition_key_model = None unlabeled_key_model = None 
+ cell_type_key_list = None if atlas == 'pbmc': cell_type_key = 'cell_type_for_integration' @@ -398,6 +399,12 @@ def get_keys(atlas, target_adata, configuration): batch_key = "batch_donor_asset" elif atlas == "hnoca": cell_type_key = "snapseed_pca_rss_level_123" + cell_type_key_list = ['annot_level_1', + 'annot_level_2', + 'annot_level_3_rev2', + 'annot_level_4_rev2', + 'annot_region_rev2', + 'annot_ntt_rev2',] batch_key = "batch" elif atlas == "heoca": cell_type_key = "cell_type" @@ -437,7 +444,7 @@ def get_keys(atlas, target_adata, configuration): if cell_type_key not in target_adata.obs.columns or batch_key not in target_adata.obs.columns: raise ValueError("Please double check if cell_type and batch keys in query match the requirements stated on the website") - return cell_type_key, batch_key, unlabeled_key_model + return cell_type_key, cell_type_key_list, batch_key, unlabeled_key_model def __get_keys_user(configuration): #Get parameters from user input From c0b40f7bc4ff01113c28e88dbc734bf1624aa88a Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Thu, 20 Jun 2024 20:23:54 +0200 Subject: [PATCH 41/63] fix bug --- mapping/scarches_api/models.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index 1df92c0..a616a7e 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -70,6 +70,10 @@ def __init__(self, configuration) -> None: # self._cell_type_key, self._batch_key, self._unlabeled_key = Preprocess.get_keys(self._atlas, self._query_adata) self._cell_type_key, self._cell_type_key_list, self._batch_key, self._unlabeled_key = Preprocess.get_keys(self._atlas, self._query_adata, configuration) + # if self._cell_type_key_list is None: + # self._cell_type_key_list = [self._cell_type_key] + + self._clf_native = get_from_config(configuration=configuration, key=parameters.CLASSIFIER_TYPE).pop("Native") self._clf_xgb = get_from_config(configuration=configuration, key=parameters.CLASSIFIER_TYPE).pop("XGBoost") self._clf_knn = get_from_config(configuration=configuration, key=parameters.CLASSIFIER_TYPE).pop("kNN") @@ -167,7 +171,7 @@ def _eval_mapping(self): reference_latent.obs = self._reference_adata.obs #Calculate mapping uncertainty and write into .obs - self.knn_ref_trainer= classification_uncert_euclidean(self._configuration, reference_latent, query_latent, self._query_adata, "X", self._cell_type_key, self._cell_type_key_list, False) + self.knn_ref_trainer= classification_uncert_euclidean(self._configuration, reference_latent, query_latent, self._query_adata, "X", False) classification_uncert_mahalanobis(self._configuration, reference_latent, query_latent, self._query_adata, "X", self._cell_type_key, False) #stress score From f57f30d0d8223ae3956cffa5d75f8e016d693405 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Thu, 20 Jun 2024 22:28:41 +0200 Subject: [PATCH 42/63] fix bug --- mapping/scarches_api/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index a616a7e..213dd13 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -171,7 +171,7 @@ def _eval_mapping(self): reference_latent.obs = self._reference_adata.obs #Calculate mapping uncertainty and write into .obs - self.knn_ref_trainer= classification_uncert_euclidean(self._configuration, reference_latent, query_latent, self._query_adata, "X", False) + self.knn_ref_trainer= 
classification_uncert_euclidean(self._configuration, reference_latent, query_latent, self._query_adata, "X", False, False)
         classification_uncert_mahalanobis(self._configuration, reference_latent, query_latent, self._query_adata, "X", self._cell_type_key, False)
 
         #stress score

From d8a402795e25f11c1ddf0b0c088bed8c42fbcf5c Mon Sep 17 00:00:00 2001
From: ChelseaBright96
Date: Fri, 21 Jun 2024 10:01:52 +0200
Subject: [PATCH 43/63] pass cell_type_key to classification_uncert_euclidean

---
 mapping/scarches_api/models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py
index 213dd13..226a2ed 100644
--- a/mapping/scarches_api/models.py
+++ b/mapping/scarches_api/models.py
@@ -171,7 +171,7 @@ def _eval_mapping(self):
         reference_latent.obs = self._reference_adata.obs
 
         #Calculate mapping uncertainty and write into .obs
-        self.knn_ref_trainer= classification_uncert_euclidean(self._configuration, reference_latent, query_latent, self._query_adata, "X", False, False)
+        self.knn_ref_trainer= classification_uncert_euclidean(self._configuration, reference_latent, query_latent, self._query_adata, "X", self._cell_type_key, False)
         classification_uncert_mahalanobis(self._configuration, reference_latent, query_latent, self._query_adata, "X", self._cell_type_key, False)
 
         #stress score

From b50940f0caa80a86f4c07cbf9d65d513151dbbcd Mon Sep 17 00:00:00 2001
From: chelseabright96 <82758782+chelseabright96@users.noreply.github.com>
Date: Fri, 21 Jun 2024 17:58:48 +0200
Subject: [PATCH 44/63] fix duplicate var_names bug in query mapping

---
 mapping/scarches_api/models.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py
index 226a2ed..db998ae 100644
--- a/mapping/scarches_api/models.py
+++ b/mapping/scarches_api/models.py
@@ -491,7 +491,8 @@ def _cleanup(self):
 
 class ScVI(ArchmapBaseModel):
     def _map_query(self):
-        #Align genes and gene order to model
+        #Align genes and gene order to model 
+        self._query_adata.var_names_make_unique()
         scarches.models.SCVI.prepare_query_anndata(self._query_adata, self._temp_model_path)
 
         #Setup adata internals for mapping
@@ -527,6 +528,9 @@ def _map_query(self, supervised=False):
         if self._cell_type_key in self._query_adata.obs.columns:
             self._query_adata.obs[f"{self._cell_type_key}_user_input"] = self._query_adata.obs[self._cell_type_key]
             self._query_adata.obs[self._cell_type_key] = [self._unlabeled_key]*len(self._query_adata)
+        
+        self._query_adata.var_names_make_unique()
+
         scarches.models.SCANVI.prepare_query_anndata(self._query_adata, self._temp_model_path)
 
         #Setup adata internals for mapping

From 82e8e311a1a3e02f92e3c344b8c2607727fbf7d8 Mon Sep 17 00:00:00 2001
From: chelseabright96 <82758782+chelseabright96@users.noreply.github.com>
Date: Sat, 22 Jun 2024 08:24:30 +0200
Subject: [PATCH 45/63] add scanvi pred to ref

---
 mapping/scarches_api/models.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py
index db998ae..3420424 100644
--- a/mapping/scarches_api/models.py
+++ b/mapping/scarches_api/models.py
@@ -259,6 +259,7 @@ def _concat_data(self):
         self._reference_adata.obs['uncertainty_mahalanobis'] = pandas.Series(dtype="float32")
         self._reference_adata.obs['prediction_xgb'] = pandas.Series(dtype="category")
         self._reference_adata.obs['prediction_knn'] = pandas.Series(dtype="category")
+        self._reference_adata.obs['prediction_scanvi'] = pandas.Series(dtype="category")
 
         #Create temp files on disk
         temp_reference =
tempfile.NamedTemporaryFile(suffix=".h5ad") From 3e7b5297cc86263f77cb179585102555d4862c06 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Sat, 22 Jun 2024 09:07:06 +0200 Subject: [PATCH 46/63] add classifiers for cell type list --- mapping/classifiers/classifiers.py | 20 +- mapping/classifiers/classify.py | 16 +- mapping/classifiers/uncert_train.py | 78 ++++++++ mapping/scarches_api/models.py | 98 ++++++---- mapping/scarches_api/uncert/uncert_metric.py | 186 ++++++++++--------- mapping/user_upload/user_upload.py | 20 +- 6 files changed, 276 insertions(+), 142 deletions(-) create mode 100644 mapping/classifiers/uncert_train.py diff --git a/mapping/classifiers/classifiers.py b/mapping/classifiers/classifiers.py index 14248af..ae5e434 100644 --- a/mapping/classifiers/classifiers.py +++ b/mapping/classifiers/classifiers.py @@ -51,7 +51,7 @@ def __init__(self, classifier_xgb=False, classifier_knn=False, classifier_native query: adata to save labels to query_latent: adata to read .X from for label prediction ''' - def predict_labels(self, query=scanpy.AnnData(), query_latent=scanpy.AnnData(), classifier_path="/path/to/classifier", encoding_path="/path/to/encoding"): + def predict_labels(self, query=scanpy.AnnData(), query_latent=scanpy.AnnData(), classifier_path="/path/to/classifier", encoding_path="/path/to/encoding", cell_type_key=None): le = LabelEncoder() if self.__classifier_xgb: @@ -61,8 +61,8 @@ def predict_labels(self, query=scanpy.AnnData(), query_latent=scanpy.AnnData(), xgb_model = XGBClassifier() xgb_model.load_model(classifier_path) - query.obs["prediction_xgb"] = le.inverse_transform(xgb_model.predict(query_latent.X)) - prediction_label = "prediction_xgb" + query.obs[f"{cell_type_key}_prediction_xgb"] = le.inverse_transform(xgb_model.predict(query_latent.X)) + prediction_label = f"{cell_type_key}_prediction_xgb" if self.__classifier_knn: with open(encoding_path, "rb") as file: @@ -71,18 +71,18 @@ def predict_labels(self, query=scanpy.AnnData(), query_latent=scanpy.AnnData(), with open(classifier_path, "rb") as file: knn_model = pickle.load(file) - query.obs["prediction_knn"] = le.inverse_transform(knn_model.predict(query_latent.X)) - prediction_label = "prediction_knn" + query.obs[f"{cell_type_key}_prediction_knn"] = le.inverse_transform(knn_model.predict(query_latent.X)) + prediction_label = f"{cell_type_key}_prediction_knn" if self.__classifier_native is not None: if "SCANVI" in str(self.__model_class): - query.obs["prediction_scanvi"] = self.__classifier_native.predict(query) - prediction_label = "prediction_scanvi" + query.obs[f"{cell_type_key}_prediction_scanvi"] = self.__classifier_native.predict(query) + prediction_label = f"{cell_type_key}_prediction_scanvi" else: output=self.__classifier_native.classify(query, scale_uncertainties=True) - query.obs["prediction_scpoli"] = list(output.values())[0]["preds"] - query.obs["uncertainty_scpoli"] = list(output.values())[0]["uncert"] - prediction_label = "prediction_scpoli" + query.obs[f"{cell_type_key}_prediction_scpoli"] = list(output.values())[0]["preds"] + query.obs[f"{cell_type_key}_uncertainty_scpoli"] = list(output.values())[0]["uncert"] + prediction_label = f"{cell_type_key}_prediction_scpoli" # calculate the percentage of unknown cell types (cell types with uncertainty higher than 0.5) percent_unknown = percentage_unknown(query, prediction_label) diff --git a/mapping/classifiers/classify.py b/mapping/classifiers/classify.py index 4dbc91f..6645c0d 100644 --- a/mapping/classifiers/classify.py +++ 
b/mapping/classifiers/classify.py @@ -4,21 +4,25 @@ from classifiers import Classifiers -def main(atlas, label): +def main(atlas, label, is_scpoli=False): adata=sc.read(f"data/{atlas}.h5ad") - reference_latent=sc.AnnData(adata.obsm["X_latent_qzm"], adata.obs) + if is_scpoli: + reference_latent=sc.AnnData(adata.obsm["X_latent_qzm_scpoli"], adata.obs) + + else: + reference_latent=sc.AnnData(adata.obsm["X_latent_qzm"], adata.obs) if isinstance(label, list): for l in label: #create knn classifier clf = Classifiers(False, True, None) - clf.create_classifier(reference_latent, True, "", label, f"models_new/{atlas}/{atlas}_{l}") + clf.create_classifier(reference_latent, True, "", l, f"models_new/{atlas}/{atlas}_{l}") #create xgb classifier clf = Classifiers(True, False, None) - clf.create_classifier(reference_latent, True, "", label, f"models_new/{atlas}/{atlas}_{l}") + clf.create_classifier(reference_latent, True, "", l, f"models_new/{atlas}/{atlas}_{l}") else: @@ -35,7 +39,7 @@ def main(atlas, label): # atlas_dict = {"fetal_brain": "subregion_class"} - atlas_dict = {"hnoca": + atlas_dict = {"hnoca_new": ['annot_level_1', 'annot_level_2', 'annot_level_3_rev2', @@ -65,7 +69,7 @@ def main(atlas, label): # } for atlas, label in atlas_dict.items(): - main(atlas, label) + main(atlas, label, is_scpoli=True) print(f"successfully created classifier for {atlas}") # for atlas, label in atlas_dict_scpoli.items(): # main(atlas, label, is_scpoli=True) diff --git a/mapping/classifiers/uncert_train.py b/mapping/classifiers/uncert_train.py new file mode 100644 index 0000000..7b09675 --- /dev/null +++ b/mapping/classifiers/uncert_train.py @@ -0,0 +1,78 @@ +import scanpy as sc +from sklearn.cluster import KMeans +from sklearn.neighbors import NearestNeighbors +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +import scarches as sca +import pickle + + +from sklearn.mixture import GaussianMixture + +def train_mahalanobis(atlas, adata_ref, embedding_name, cell_type_key, pretrained=True): + + + num_clusters = adata_ref.obs[cell_type_key].nunique() + print(num_clusters) + + train_emb = adata_ref.obsm[embedding_name] + + #Required too much RAM + gmm = GaussianMixture(n_components=num_clusters) + gmm.fit(train_emb) + + #Less RAM alternative + # kmeans = KMeans(n_clusters=num_clusters) + # kmeans.fit(train_emb) + + #Save or return model + if pretrained: + with open("models_uncert/" + atlas + "/" + cell_type_key + "_mahalanobis_distance.pickle", "wb") as file: + pickle.dump(gmm, file, pickle.HIGHEST_PROTOCOL) + else: + return gmm + +def train_euclidian(atlas, adata_ref, embedding_name, pretrained =True, n_neighbors = 15): + + trainer = sca.utils.weighted_knn_trainer( + adata_ref, + embedding_name, + n_neighbors = n_neighbors + ) + + #Save model + if pretrained: + with open("models_uncert/" + atlas + "/" + "euclidian_distance.pickle", "wb") as file: + pickle.dump(trainer, file, pickle.HIGHEST_PROTOCOL) + else: + return trainer + + +def main(atlas, adata_ref, cell_type_key_list=None, is_scpoli = False): + + if is_scpoli: + embedding_name = "X_latent_qzm_scpoli" + else: + embedding_name = "X_latent_qzm" + + train_euclidian(atlas, adata_ref, embedding_name) + + for cell_type_key in cell_type_key_list: + print(cell_type_key) + train_mahalanobis(atlas, adata_ref, embedding_name, cell_type_key) + + + +if __name__ == "__main__": + + atlas = "hnoca_new" + adata_ref = sc.read(f"data/{atlas}.h5ad") + cell_type_key_list = ['annot_level_1', + 'annot_level_2', + 'annot_level_3_rev2', + 'annot_level_4_rev2', + 
'annot_region_rev2', + 'annot_ntt_rev2',] + + main(atlas, adata_ref, cell_type_key_list, is_scpoli = True) \ No newline at end of file diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index 226a2ed..ae8e0ed 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -70,16 +70,14 @@ def __init__(self, configuration) -> None: # self._cell_type_key, self._batch_key, self._unlabeled_key = Preprocess.get_keys(self._atlas, self._query_adata) self._cell_type_key, self._cell_type_key_list, self._batch_key, self._unlabeled_key = Preprocess.get_keys(self._atlas, self._query_adata, configuration) - # if self._cell_type_key_list is None: - # self._cell_type_key_list = [self._cell_type_key] + if self._cell_type_key_list is None: + self._cell_type_key_list = [self._cell_type_key] - self._clf_native = get_from_config(configuration=configuration, key=parameters.CLASSIFIER_TYPE).pop("Native") - self._clf_xgb = get_from_config(configuration=configuration, key=parameters.CLASSIFIER_TYPE).pop("XGBoost") - self._clf_knn = get_from_config(configuration=configuration, key=parameters.CLASSIFIER_TYPE).pop("kNN") - self._clf_model_path = get_from_config(configuration=configuration, key=parameters.CLASSIFIER_PATH) - self._clf_encoding_path = get_from_config(configuration=configuration, key=parameters.ENCODING_PATH) - + self._clf_native = get_from_config(configuration=self._configuration, key=parameters.CLASSIFIER_TYPE).pop("Native") + self._clf_xgb = get_from_config(configuration=self._configuration, key=parameters.CLASSIFIER_TYPE).pop("XGBoost") + self._clf_knn = get_from_config(configuration=self._configuration, key=parameters.CLASSIFIER_TYPE).pop("kNN") + end_time = time.time() print(f"time {end_time-start_time}") @@ -100,7 +98,7 @@ def _map_query(self): # threshold = 10000 if self._atlas == "fetal_brain": lr=0.1 - self._max_epochs = 40 + # self._max_epochs = 40 else: lr=0.001 @@ -149,7 +147,7 @@ def _acquire_data(self): ratio = inter_len / len(ref_vars) print(ratio) - utils.notify_backend(self._webhook, {"ratio":ratio}) + # utils.notify_backend(self._webhook, {"ratio":ratio}) # save only necessary data for mapping to new adata @@ -171,8 +169,8 @@ def _eval_mapping(self): reference_latent.obs = self._reference_adata.obs #Calculate mapping uncertainty and write into .obs - self.knn_ref_trainer= classification_uncert_euclidean(self._configuration, reference_latent, query_latent, self._query_adata, "X", self._cell_type_key, False) - classification_uncert_mahalanobis(self._configuration, reference_latent, query_latent, self._query_adata, "X", self._cell_type_key, False) + self.knn_ref_trainer= classification_uncert_euclidean(self._configuration, reference_latent, query_latent, self._query_adata, "X", self._cell_type_key_list, False) + classification_uncert_mahalanobis(self._configuration, reference_latent, query_latent, self._query_adata, "X", self._cell_type_key_list, False) #stress score if self._atlas=="hnoca": @@ -183,33 +181,53 @@ def _eval_mapping(self): def _transfer_labels(self): if not self._clf_native and not self._clf_knn and not self._clf_xgb: return + + #Compute label transfer and save to respective .obs + query_latent = scanpy.AnnData(self._query_adata.obsm["latent_rep"]) + if self._clf_native: clf = Classifiers(self._clf_xgb, self._clf_knn, self._model, self._model.__class__) + for cell_type_key in self._cell_type_key: + self.percent_unknown = clf.predict_labels(self._query_adata, query_latent, self._temp_clf_model_path, self._temp_clf_encoding_path, 
cell_type_key) + + #Instantiate xgb or knn classifier if selected if self._clf_xgb or self._clf_knn: clf = Classifiers(self._clf_xgb, self._clf_knn, None, self._model.__class__) + self._clf_path = get_from_config(configuration=self._configuration, key=parameters.CLASSIFIER_PATH) - #Download classifiers and encoding from GCP if kNN or XGBoost - if self._clf_xgb: - self._temp_clf_encoding_path = tempfile.mktemp(suffix=".pickle") - fetch_file_from_s3(self._clf_encoding_path, self._temp_clf_encoding_path) + for cell_type_key in self._cell_type_key: - self._temp_clf_model_path = tempfile.mktemp(suffix=".ubj") - fetch_file_from_s3(self._clf_model_path, self._temp_clf_model_path) - elif self._clf_knn: - self._temp_clf_encoding_path = tempfile.mktemp(suffix=".pickle") - fetch_file_from_s3(self._clf_encoding_path, self._temp_clf_encoding_path) + self._clf_encoding_path = self._clf_path + cell_type_key + "/classifier_encoding.pickle" - self._temp_clf_model_path = tempfile.mktemp(suffix=".pickle") - fetch_file_from_s3(self._clf_model_path, self._temp_clf_model_path) + #Download classifiers and encoding from GCP if kNN or XGBoost + if self._clf_xgb: + self._temp_clf_encoding_path = tempfile.mktemp(suffix=".pickle") + fetch_file_from_s3(self._clf_encoding_path, self._temp_clf_encoding_path) - #Compute label transfer and save to respective .obs - query_latent = scanpy.AnnData(self._query_adata.obsm["latent_rep"]) - - self.percent_unknown = clf.predict_labels(self._query_adata, query_latent, self._temp_clf_model_path, self._temp_clf_encoding_path) + self._temp_clf_model_path = tempfile.mktemp(suffix=".ubj") + self._clf_model_path = self._clf_path + cell_type_key + "/classifier_xgb.ubj" + fetch_file_from_s3(self._clf_model_path, self._temp_clf_model_path) + + elif self._clf_knn: + self._temp_clf_encoding_path = tempfile.mktemp(suffix=".pickle") + fetch_file_from_s3(self._clf_encoding_path, self._temp_clf_encoding_path) + self._clf_model_path = self._clf_path + cell_type_key + "/classifier_knn.pickle" + self._temp_clf_model_path = tempfile.mktemp(suffix=".pickle") + fetch_file_from_s3(self._clf_model_path, self._temp_clf_model_path) + + self.percent_unknown = clf.predict_labels(self._query_adata, query_latent, self._temp_clf_model_path, self._temp_clf_encoding_path, cell_type_key) + + # remove temp files + if self._temp_clf_model_path is not None: + if os.path.exists(self._temp_clf_model_path): + os.remove(self._temp_clf_model_path) + if self._temp_clf_encoding_path is not None: + if os.path.exists(self._temp_clf_encoding_path): + os.remove(self._temp_clf_encoding_path) def _concat_data(self): @@ -255,10 +273,12 @@ def _concat_data(self): print("concatenating on disk") #Added because concat_on_disk only allows inner joins - self._reference_adata.obs[self._cell_type_key + '_uncertainty_euclidean'] = pandas.Series(dtype="float32") - self._reference_adata.obs['uncertainty_mahalanobis'] = pandas.Series(dtype="float32") - self._reference_adata.obs['prediction_xgb'] = pandas.Series(dtype="category") - self._reference_adata.obs['prediction_knn'] = pandas.Series(dtype="category") + for cell_type_key in self._cell_type_key_list: + self._reference_adata.obs[cell_type_key + '_uncertainty_euclidean'] = pandas.Series(dtype="float32") + self._reference_adata.obs[cell_type_key + '_uncertainty_mahalanobis'] = pandas.Series(dtype="float32") + self._reference_adata.obs[cell_type_key + 'prediction_xgb'] = pandas.Series(dtype="category") + self._reference_adata.obs[cell_type_key + 'prediction_knn'] = 
pandas.Series(dtype="category") + self._reference_adata.obs[cell_type_key + "_prediction_scanvi"] = pandas.Series(dtype="category") #Create temp files on disk temp_reference = tempfile.NamedTemporaryFile(suffix=".h5ad") @@ -309,11 +329,15 @@ def _compute_latent_representation(self, explicit_representation): def _save_data(self): # add .X to self._combined_adata + print(self._combined_adata.obs.columns) + print("adding X from cloud") self.add_X_from_cloud() + print(self._combined_adata.obs.columns) combined_downsample = self.downsample_adata() + print(self._combined_adata.obs.columns) # Calculate presence score @@ -350,7 +374,7 @@ def _save_data(self): self.query_with_anchor=percent_query_with_anchor(adjs["r2q"], adjs["q2r"]) print(f"query_with_anchor: {self.query_with_anchor}") - utils.notify_backend(self._webhook_metrics, {"clust_pres_score":self.clust_pres_score, "query_with_anchor":self.query_with_anchor, "percentage_unknown": self.percent_unknown}) + # utils.notify_backend(self._webhook_metrics, {"clust_pres_score":self.clust_pres_score, "query_with_anchor":self.query_with_anchor, "percentage_unknown": self.percent_unknown}) #Save output Postprocess.output(None, combined_downsample, self._configuration) @@ -359,7 +383,7 @@ def add_X_from_cloud(self): if True or get_from_config(self._configuration, parameters.WEBHOOK) is not None and len( get_from_config(self._configuration, parameters.WEBHOOK)) > 0: - utils.notify_backend(get_from_config(self._configuration, parameters.WEBHOOK), self._configuration) + # utils.notify_backend(get_from_config(self._configuration, parameters.WEBHOOK), self._configuration) if not self._reference_adata_path.endswith("data.h5ad"): raise ValueError("The reference data should be named data.h5ad") else: @@ -482,12 +506,6 @@ def _cleanup(self): if os.path.exists(os.path.join(self._temp_model_path, "var_names.csv")): os.remove(os.path.join(self._temp_model_path, "var_names.csv")) - if self._temp_clf_model_path is not None: - if os.path.exists(self._temp_clf_model_path): - os.remove(self._temp_clf_model_path) - if self._temp_clf_encoding_path is not None: - if os.path.exists(self._temp_clf_encoding_path): - os.remove(self._temp_clf_encoding_path) class ScVI(ArchmapBaseModel): def _map_query(self): @@ -527,6 +545,8 @@ def _map_query(self, supervised=False): if self._cell_type_key in self._query_adata.obs.columns: self._query_adata.obs[f"{self._cell_type_key}_user_input"] = self._query_adata.obs[self._cell_type_key] self._query_adata.obs[self._cell_type_key] = [self._unlabeled_key]*len(self._query_adata) + + self._query_adata.var_names_make_unique() scarches.models.SCANVI.prepare_query_anndata(self._query_adata, self._temp_model_path) #Setup adata internals for mapping diff --git a/mapping/scarches_api/uncert/uncert_metric.py b/mapping/scarches_api/uncert/uncert_metric.py index 896cbaa..40b6c45 100644 --- a/mapping/scarches_api/uncert/uncert_metric.py +++ b/mapping/scarches_api/uncert/uncert_metric.py @@ -10,6 +10,7 @@ #import milopy # import pertpy as pt from matplotlib.lines import Line2D +from utils.utils import fetch_file_from_s3 import pickle @@ -33,82 +34,91 @@ def mahalanobis(v, data): vector[centroid_index] = mahal return vector -def classification_uncert_mahalanobis( - configuration, - adata_ref_latent, - adata_query_latent, - adata_query_raw, - embedding_name, - cell_type_key, - pretrained - ): - """ Computes classification uncertainty, based on the Mahalanobis distance of each cell - to the cell cluster centroids - - Args: - adata_ref_latent (AnnData): 
Latent representation of the reference - adata_query_latent (AnnData): Latent representation of the query - - Returns: - uncertainties (pandas DataFrame): Classification uncertainties for all of the query cell types - """ - - #Load model - if pretrained: - atlas = get_from_config(configuration, utils.parameters.ATLAS) - - with open("/models/" + atlas + "/" + "mahalanobis_distance.pickle", "rb") as file: - kmeans = pickle.load(file) - else: - kmeans = train_mahalanobis(None, adata_ref_latent, embedding_name, cell_type_key, pretrained) +# def classification_uncert_mahalanobis( +# configuration, +# adata_ref_latent, +# adata_query_latent, +# adata_query_raw, +# embedding_name, +# cell_type_key, +# pretrained +# ): +# """ Computes classification uncertainty, based on the Mahalanobis distance of each cell +# to the cell cluster centroids + +# Args: +# adata_ref_latent (AnnData): Latent representation of the reference +# adata_query_latent (AnnData): Latent representation of the query + +# Returns: +# uncertainties (pandas DataFrame): Classification uncertainties for all of the query cell types +# """ + +# #Load model +# if pretrained: +# atlas = get_from_config(configuration, utils.parameters.ATLAS) + +# with open("/models/" + atlas + "/" + "mahalanobis_distance.pickle", "rb") as file: +# kmeans = pickle.load(file) +# else: +# kmeans = train_mahalanobis(None, adata_ref_latent, embedding_name, cell_type_key, pretrained) - uncertainties = pd.DataFrame(columns=["uncertainty"], index=adata_query_raw.obs_names) - centroids = kmeans.cluster_centers_ - adata_query = kmeans.transform(adata_query_latent.X) - - for query_cell_index in range(len(adata_query)): - query_cell = adata_query_latent.X[query_cell_index] - distance = mahalanobis(query_cell, centroids) - uncertainties.iloc[query_cell_index]['uncertainty'] = np.mean(distance) +# uncertainties = pd.DataFrame(columns=["uncertainty"], index=adata_query_raw.obs_names) +# centroids = kmeans.cluster_centers_ +# adata_query = kmeans.transform(adata_query_latent.X) + +# for query_cell_index in range(len(adata_query)): +# query_cell = adata_query_latent.X[query_cell_index] +# distance = mahalanobis(query_cell, centroids) +# uncertainties.iloc[query_cell_index]['uncertainty'] = np.mean(distance) - max_distance = np.max(uncertainties["uncertainty"]) - min_distance = np.min(uncertainties["uncertainty"]) - uncertainties["uncertainty"] = (uncertainties["uncertainty"] - min_distance) / (max_distance - min_distance + 1e-8) - adata_query_raw.obs['uncertainty_mahalanobis'] = uncertainties.to_numpy(dtype="float32") +# max_distance = np.max(uncertainties["uncertainty"]) +# min_distance = np.min(uncertainties["uncertainty"]) +# uncertainties["uncertainty"] = (uncertainties["uncertainty"] - min_distance) / (max_distance - min_distance + 1e-8) +# adata_query_raw.obs['uncertainty_mahalanobis'] = uncertainties.to_numpy(dtype="float32") - return uncertainties, centroids +# return uncertainties, centroids -def classification_uncert_mahalanobis2( +def classification_uncert_mahalanobis( configuration, adata_ref_latent, adata_query_latent, + adata_query_raw, embedding_name, - cell_type_key, - pretrained + cell_type_key_list, + pretrained=True ): #Load model - if pretrained: - atlas = get_from_config(configuration, utils.parameters.ATLAS) + for cell_type_key in cell_type_key_list: - with open("/models/" + atlas + "/" + "mahalanobis_distance.pickle", "rb") as file: - gmm = pickle.load(file) - else: - gmm = train_mahalanobis(None, adata_ref_latent, cell_type_key, pretrained) - - 
centroids = gmm.means_ - cluster_membership = gmm.predict_proba(adata_query_latent[embedding_name]) - - uncertainties = pd.DataFrame(columns=["uncertainty"], index=adata_query_latent.obs_names) - for query_cell_index, query_cell in enumerate(adata_query_latent.X): - distance = mahalanobis(query_cell, centroids) - weighed_distance = np.multiply(cluster_membership[query_cell_index], distance) - uncertainties.iloc[query_cell_index]['uncertainty'] = np.mean(weighed_distance) + if pretrained: + atlas = get_from_config(configuration, utils.parameters.ATLAS) + cloud_model_path = get_from_config(configuration, utils.parameters.PRETRAINED_MODEL_PATH)[:-len("model.pt")] + + uncert_model_path = "/models/" + atlas + "/uncertainty/" + cell_type_key + "_mahalanobis_distance.pickle" + fetch_file_from_s3(cloud_model_path, uncert_model_path) + + with open(uncert_model_path, "rb") as file: + gmm = pickle.load(file) + else: + gmm = train_mahalanobis(None, adata_ref_latent, cell_type_key, pretrained) - max_distance = np.max(uncertainties["uncertainty"]) - min_distance = np.min(uncertainties["uncertainty"]) - uncertainties["uncertainty"] = (uncertainties["uncertainty"] - min_distance) / (max_distance - min_distance + 1e-8) - adata_query_latent.obsm['uncertainty_mahalanobis'] = uncertainties - return uncertainties, centroids + centroids = gmm.means_ + cluster_membership = gmm.predict_proba(adata_query_latent[embedding_name]) + + uncertainties = pd.DataFrame(columns=["uncertainty"], index=adata_query_latent.obs_names) + for query_cell_index, query_cell in enumerate(adata_query_latent.X): + distance = mahalanobis(query_cell, centroids) + weighed_distance = np.multiply(cluster_membership[query_cell_index], distance) + uncertainties.iloc[query_cell_index]['uncertainty'] = np.mean(weighed_distance) + + max_distance = np.max(uncertainties["uncertainty"]) + min_distance = np.min(uncertainties["uncertainty"]) + uncertainties["uncertainty"] = (uncertainties["uncertainty"] - min_distance) / (max_distance - min_distance + 1e-8) + + adata_query_raw.obs[f"{cell_type_key}_uncertainty_mahalanobis"] = uncertainties.to_numpy(dtype="float32") + + # return uncertainties, centroids def classification_uncert_euclidean( configuration, @@ -116,8 +126,8 @@ def classification_uncert_euclidean( adata_query_latent, adata_query_raw, embedding_name, - cell_type_key, - pretrained + cell_type_key_list=None, + pretrained=True ): """Computes classification uncertainty, based on the Euclidean distance of each cell to its k-nearest neighbors. 
Additional adjustment by a Gaussian kernel is made @@ -132,34 +142,40 @@ def classification_uncert_euclidean( uncertainties (pandas DataFrame): Classification uncertainties for all of query cell_types """ + #Load model if pretrained: atlas = get_from_config(configuration, utils.parameters.ATLAS) + cloud_model_path = get_from_config(configuration, utils.parameters.PRETRAINED_MODEL_PATH)[:-len("model.pt")] - with open("/models/" + atlas + "/" + "euclidian_distance.pickle", "rb") as file: + uncert_model_path = "/models/" + atlas + "/euclidian_distance.pickle" + fetch_file_from_s3(cloud_model_path, uncert_model_path) + + with open(uncert_model_path, "rb") as file: trainer = pickle.load(file) else: trainer = train_euclidian(None, adata_ref_latent, embedding_name, pretrained) #Make sure cell_type is all strings (comparison with NaN and string doesnt work) - cell_type_cols = adata_ref_latent.obs.columns[adata_ref_latent.obs.columns.str.startswith(cell_type_key)] - for col in cell_type_cols: - adata_ref_latent.obs[col] = adata_ref_latent.obs[col].astype(str) - - _, uncertainties, _ = sca.utils.weighted_knn_transfer( - adata_query_latent, - embedding_name, - adata_ref_latent.obs, - cell_type_key, - trainer - ) - #Important to store as numpy array in obs for cellbygene visualization - if(len(uncertainties.columns) > 1): - for entry in uncertainties.columns: - name = str(entry + '_uncertainty_euclidean') - adata_query_raw.obs[name] = uncertainties[entry].to_numpy(dtype="float32") - else: - adata_query_raw.obs[cell_type_key + '_uncertainty_euclidean'] = uncertainties.to_numpy(dtype="float32") + for cell_type_key in cell_type_key_list: + cell_type_cols = adata_ref_latent.obs.columns[adata_ref_latent.obs.columns.str.startswith(cell_type_key)] + for col in cell_type_cols: + adata_ref_latent.obs[col] = adata_ref_latent.obs[col].astype(str) + + _, uncertainties = sca.utils.weighted_knn_transfer( + adata_query_latent, + embedding_name, + adata_ref_latent.obs, + cell_type_key, + trainer + ) + #Important to store as numpy array in obs for cellbygene visualization + if(len(uncertainties.columns) > 1): + for entry in uncertainties.columns: + name = str(entry + '_uncertainty_euclidean') + adata_query_raw.obs[name] = uncertainties[entry].to_numpy(dtype="float32") + else: + adata_query_raw.obs[cell_type_key + '_uncertainty_euclidean'] = uncertainties.to_numpy(dtype="float32") return trainer diff --git a/mapping/user_upload/user_upload.py b/mapping/user_upload/user_upload.py index 240b9e6..f101d06 100644 --- a/mapping/user_upload/user_upload.py +++ b/mapping/user_upload/user_upload.py @@ -51,6 +51,8 @@ def check_upload(self): local_model_path = self.__load_file_tmp(self.__model_path, "model.pt") local_reference_data_path = self.__load_file_tmp(self.__reference_data, "reference.h5ad") + self.__minify_adata() + self.__check_model_version() self.__check_model_registry() @@ -59,6 +61,10 @@ def check_upload(self): self.__check_classifier() + self.__train_weighted_knn() + + self.__train_gaussian_mixture() + self.__share_results() def __check_model_version(self): @@ -93,10 +99,10 @@ def __check_atlas_labels(self, local_model_path): def __check_atlas_genes(self, local_model_path, local_reference_data_path): tuple = os.path.split(local_model_path) var_names = _utils._load_saved_files(tuple[0], False, None, "cpu")[1] - reference_data = scanpy.read_h5ad(local_reference_data_path, backed="r") + self.reference_data = scanpy.read_h5ad(local_reference_data_path, backed="r") try: - reference_data_sub = reference_data[:,var_names] 
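# (Illustrative sketch, not part of the diff: keeping the reference AnnData on
#  self lets the later __minify_adata/__train_weighted_knn/__train_gaussian_mixture
#  stubs reuse it. The gene check itself amounts to set containment of the
#  model's var_names; a hypothetical standalone equivalent would be
#
#      missing = set(var_names) - set(self.reference_data.var_names)
#      genes_ok = len(missing) == 0
#
#  rather than relying on the exception raised by the fancy indexing below.)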
+ reference_data_sub = self.reference_data[:,var_names] except: #raise ValueError("var_names from reference are different to the model") self.__result["errors"]["atlas"].append("var_names from reference are different to the model") @@ -105,6 +111,16 @@ def __check_atlas_genes(self, local_model_path, local_reference_data_path): def __check_classifier(self): pass + def __minify_adata(self): + pass + + def __train_weighted_knn(self): + pass + + def __train_gaussian_mixture(self): + pass + + def __share_results(self): webhook = utils.get_from_config(self.__configuration, parameters.WEBHOOK) From 8f969ff1dd0f820f59114bd99adfd21f76c6cd3f Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Tue, 25 Jun 2024 14:00:53 +0200 Subject: [PATCH 47/63] add multi label classification and pretrain uncert --- mapping/classifiers/classifiers.py | 3 +- mapping/classifiers/uncert_train.py | 56 ++++++++++++++++---- mapping/process/processing.py | 2 +- mapping/scarches_api/models.py | 50 ++++++++++------- mapping/scarches_api/uncert/uncert_metric.py | 26 +++++---- mapping/scarches_api/utils/metrics.py | 6 +-- mapping/scarches_api/utils/parameters.py | 1 + 7 files changed, 97 insertions(+), 47 deletions(-) diff --git a/mapping/classifiers/classifiers.py b/mapping/classifiers/classifiers.py index ae5e434..dd82007 100644 --- a/mapping/classifiers/classifiers.py +++ b/mapping/classifiers/classifiers.py @@ -85,7 +85,8 @@ def predict_labels(self, query=scanpy.AnnData(), query_latent=scanpy.AnnData(), prediction_label = f"{cell_type_key}_prediction_scpoli" # calculate the percentage of unknown cell types (cell types with uncertainty higher than 0.5) - percent_unknown = percentage_unknown(query, prediction_label) + percent_unknown = percentage_unknown(query, cell_type_key, prediction_label) + percent_unknown = 0.0 print(percent_unknown) return round(percent_unknown, 2) diff --git a/mapping/classifiers/uncert_train.py b/mapping/classifiers/uncert_train.py index 7b09675..0cdd2a5 100644 --- a/mapping/classifiers/uncert_train.py +++ b/mapping/classifiers/uncert_train.py @@ -6,6 +6,7 @@ import matplotlib.pyplot as plt import scarches as sca import pickle +import os from sklearn.mixture import GaussianMixture @@ -28,6 +29,7 @@ def train_mahalanobis(atlas, adata_ref, embedding_name, cell_type_key, pretraine #Save or return model if pretrained: + with open("models_uncert/" + atlas + "/" + cell_type_key + "_mahalanobis_distance.pickle", "wb") as file: pickle.dump(gmm, file, pickle.HIGHEST_PROTOCOL) else: @@ -58,6 +60,9 @@ def main(atlas, adata_ref, cell_type_key_list=None, is_scpoli = False): train_euclidian(atlas, adata_ref, embedding_name) + if isinstance(cell_type_key_list,str): + cell_type_key_list = [cell_type_key_list] + for cell_type_key in cell_type_key_list: print(cell_type_key) train_mahalanobis(atlas, adata_ref, embedding_name, cell_type_key) @@ -66,13 +71,46 @@ def main(atlas, adata_ref, cell_type_key_list=None, is_scpoli = False): if __name__ == "__main__": - atlas = "hnoca_new" - adata_ref = sc.read(f"data/{atlas}.h5ad") - cell_type_key_list = ['annot_level_1', - 'annot_level_2', - 'annot_level_3_rev2', - 'annot_level_4_rev2', - 'annot_region_rev2', - 'annot_ntt_rev2',] - main(atlas, adata_ref, cell_type_key_list, is_scpoli = True) \ No newline at end of file + + # cell_type_key_list = ['annot_level_1', + # 'annot_level_2', + # 'annot_level_3_rev2', + # 'annot_level_4_rev2', + # 'annot_region_rev2', + # 'annot_ntt_rev2',] + + atlas_dict = {"hlca": "scanvi_label", + #"gb": "CellID", + # 
"fetal_immune":"celltype_annotation", + # "nsclc": "cell_type", + # "retina": "CellType", + # "pancreas_scanvi":"cell_type", + # "hypomap": "Author_CellType", + } + atlas_dict_scpoli = { + # "hlca_retrained": "ann_finest_level", + # "pancreas_scpoli":"cell_type", + # "pbmc":"cell_type_for_integration", + # "hnoca":"annot_level_2", + # "heoca": "cell_type" + + } + + for atlas, cell_type_key_list in atlas_dict_scpoli.items(): + + directory = "models_uncert/" + atlas + "/" + if not os.path.exists(directory): + os.makedirs(directory) + + adata_ref = sc.read(f"data/{atlas}.h5ad") + main(atlas, adata_ref, cell_type_key_list, is_scpoli = True) + + for atlas, cell_type_key_list in atlas_dict.items(): + + directory = "models_uncert/" + atlas + "/" + if not os.path.exists(directory): + os.makedirs(directory) + + adata_ref = sc.read(f"data/{atlas}.h5ad") + main(atlas, adata_ref, cell_type_key_list, is_scpoli = False) \ No newline at end of file diff --git a/mapping/process/processing.py b/mapping/process/processing.py index 6d57d67..f598371 100644 --- a/mapping/process/processing.py +++ b/mapping/process/processing.py @@ -398,7 +398,7 @@ def get_keys(atlas, target_adata, configuration): cell_type_key = "cell_type_scarches" batch_key = "batch_donor_asset" elif atlas == "hnoca": - cell_type_key = "snapseed_pca_rss_level_123" + cell_type_key ="annot_level_2" #"snapseed_pca_rss_level_123" cell_type_key_list = ['annot_level_1', 'annot_level_2', 'annot_level_3_rev2', diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index ae8e0ed..8195f95 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -147,7 +147,7 @@ def _acquire_data(self): ratio = inter_len / len(ref_vars) print(ratio) - # utils.notify_backend(self._webhook, {"ratio":ratio}) + utils.notify_backend(self._webhook, {"ratio":ratio}) # save only necessary data for mapping to new adata @@ -169,14 +169,13 @@ def _eval_mapping(self): reference_latent.obs = self._reference_adata.obs #Calculate mapping uncertainty and write into .obs - self.knn_ref_trainer= classification_uncert_euclidean(self._configuration, reference_latent, query_latent, self._query_adata, "X", self._cell_type_key_list, False) - classification_uncert_mahalanobis(self._configuration, reference_latent, query_latent, self._query_adata, "X", self._cell_type_key_list, False) + self.knn_ref_trainer= classification_uncert_euclidean(self._configuration, reference_latent, query_latent, self._query_adata, "X", self._cell_type_key_list, True) + classification_uncert_mahalanobis(self._configuration, reference_latent, query_latent, self._query_adata, self._cell_type_key_list, True) #stress score if self._atlas=="hnoca": print("calculating stress score") stress_score(self._query_adata) - print(self._query_adata.obs["Hallmark_Glycolysis_Score"]) def _transfer_labels(self): if not self._clf_native and not self._clf_knn and not self._clf_xgb: @@ -189,7 +188,7 @@ def _transfer_labels(self): if self._clf_native: clf = Classifiers(self._clf_xgb, self._clf_knn, self._model, self._model.__class__) - for cell_type_key in self._cell_type_key: + for cell_type_key in self._cell_type_key_list: self.percent_unknown = clf.predict_labels(self._query_adata, query_latent, self._temp_clf_model_path, self._temp_clf_encoding_path, cell_type_key) @@ -198,24 +197,37 @@ def _transfer_labels(self): clf = Classifiers(self._clf_xgb, self._clf_knn, None, self._model.__class__) self._clf_path = get_from_config(configuration=self._configuration, 
key=parameters.CLASSIFIER_PATH) - for cell_type_key in self._cell_type_key: - - self._clf_encoding_path = self._clf_path + cell_type_key + "/classifier_encoding.pickle" + for cell_type_key in self._cell_type_key_list: + + if len(self._cell_type_key_list) > 1: + self._clf_encoding_path = self._clf_path + cell_type_key + "/classifier_encoding.pickle" + else: + self._clf_encoding_path = self._clf_path + "classifier_encoding.pickle" #Download classifiers and encoding from GCP if kNN or XGBoost if self._clf_xgb: + + if len(self._cell_type_key_list) > 1: + self._clf_model_path = self._clf_path + cell_type_key + "/classifier_xgb.ubj" + else: + self._clf_model_path = self._clf_path + "classifier_xgb.ubj" + self._temp_clf_encoding_path = tempfile.mktemp(suffix=".pickle") fetch_file_from_s3(self._clf_encoding_path, self._temp_clf_encoding_path) self._temp_clf_model_path = tempfile.mktemp(suffix=".ubj") - self._clf_model_path = self._clf_path + cell_type_key + "/classifier_xgb.ubj" fetch_file_from_s3(self._clf_model_path, self._temp_clf_model_path) elif self._clf_knn: + + if len(self._cell_type_key_list) > 1: + self._clf_model_path = self._clf_path + cell_type_key + "/classifier_knn.pickle" + else: + self._clf_model_path = self._clf_path + "classifier_knn.pickle" + self._temp_clf_encoding_path = tempfile.mktemp(suffix=".pickle") fetch_file_from_s3(self._clf_encoding_path, self._temp_clf_encoding_path) - self._clf_model_path = self._clf_path + cell_type_key + "/classifier_knn.pickle" self._temp_clf_model_path = tempfile.mktemp(suffix=".pickle") fetch_file_from_s3(self._clf_model_path, self._temp_clf_model_path) @@ -280,6 +292,8 @@ def _concat_data(self): self._reference_adata.obs[cell_type_key + 'prediction_knn'] = pandas.Series(dtype="category") self._reference_adata.obs[cell_type_key + "_prediction_scanvi"] = pandas.Series(dtype="category") + self._query_adata.obs[cell_type_key] = pandas.Series(dtype="category") + #Create temp files on disk temp_reference = tempfile.NamedTemporaryFile(suffix=".h5ad") temp_query = tempfile.NamedTemporaryFile(suffix=".h5ad") @@ -329,15 +343,10 @@ def _compute_latent_representation(self, explicit_representation): def _save_data(self): # add .X to self._combined_adata - print(self._combined_adata.obs.columns) - print("adding X from cloud") self.add_X_from_cloud() - print(self._combined_adata.obs.columns) - combined_downsample = self.downsample_adata() - print(self._combined_adata.obs.columns) # Calculate presence score @@ -374,7 +383,7 @@ def _save_data(self): self.query_with_anchor=percent_query_with_anchor(adjs["r2q"], adjs["q2r"]) print(f"query_with_anchor: {self.query_with_anchor}") - # utils.notify_backend(self._webhook_metrics, {"clust_pres_score":self.clust_pres_score, "query_with_anchor":self.query_with_anchor, "percentage_unknown": self.percent_unknown}) + utils.notify_backend(self._webhook_metrics, {"clust_pres_score":self.clust_pres_score, "query_with_anchor":self.query_with_anchor, "percentage_unknown": self.percent_unknown}) #Save output Postprocess.output(None, combined_downsample, self._configuration) @@ -383,7 +392,7 @@ def add_X_from_cloud(self): if True or get_from_config(self._configuration, parameters.WEBHOOK) is not None and len( get_from_config(self._configuration, parameters.WEBHOOK)) > 0: - # utils.notify_backend(get_from_config(self._configuration, parameters.WEBHOOK), self._configuration) + utils.notify_backend(get_from_config(self._configuration, parameters.WEBHOOK), self._configuration) if not self._reference_adata_path.endswith("data.h5ad"): 
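The branching above encodes a simple storage convention: multi-label atlases keep one classifier directory per annotation level, single-label atlases keep the files at the top level. Factored out as a helper (the name clf_paths is illustrative, not part of the repo):

def clf_paths(clf_path, cell_type_key, n_labels, backend):
    # backend is "xgb" or "knn"; file names follow the hunks above
    prefix = clf_path + (cell_type_key + "/" if n_labels > 1 else "")
    model_file = {"xgb": "classifier_xgb.ubj", "knn": "classifier_knn.pickle"}[backend]
    return prefix + model_file, prefix + "classifier_encoding.pickle"

Each returned cloud path is then mirrored to a tempfile via fetch_file_from_s3, exactly as in the hunks above.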
raise ValueError("The reference data should be named data.h5ad") else: @@ -462,7 +471,9 @@ def downsample_adata(self, query_ratio=5): total_ref_cells_to_sample = len(query_adata_index) * query_ratio # Get unique cell types - celltypes = np.unique(self._combined_adata.obs[self._cell_type_key]) + # celltypes = np.unique(self._combined_adata.obs[self._cell_type_key]) + celltypes = self._combined_adata.obs[self._cell_type_key].unique() + print(celltypes) # Calculate the proportion of each cell type in the reference data celltype_proportions = {celltype: np.sum(ref_adata.obs[self._cell_type_key] == celltype) / len(ref_adata) for celltype in celltypes} @@ -481,7 +492,8 @@ def downsample_adata(self, query_ratio=5): sampled_cell_index.extend(sampled_cells) else: # Old approach: Sample 10% from each cell type in the reference data - celltypes = np.unique(self._combined_adata.obs[self._cell_type_key]) + celltypes = self._combined_adata.obs[self._cell_type_key].unique() + # celltypes = np.unique(self._combined_adata.obs[self._cell_type_key]) percentage = 0.02 if ref_adata.n_obs> 3000000 else 0.1 # max 1 sampled_cell_index = np.concatenate([np.random.choice(np.where(ref_adata.obs[self._cell_type_key] == celltype)[0], size=int(len(np.where(ref_adata.obs[self._cell_type_key] == celltype)[0]) * percentage), replace=False) for celltype in celltypes]) diff --git a/mapping/scarches_api/uncert/uncert_metric.py b/mapping/scarches_api/uncert/uncert_metric.py index 40b6c45..d371504 100644 --- a/mapping/scarches_api/uncert/uncert_metric.py +++ b/mapping/scarches_api/uncert/uncert_metric.py @@ -83,8 +83,7 @@ def classification_uncert_mahalanobis( configuration, adata_ref_latent, adata_query_latent, - adata_query_raw, - embedding_name, + adata_query_raw, cell_type_key_list, pretrained=True ): @@ -93,18 +92,19 @@ def classification_uncert_mahalanobis( if pretrained: atlas = get_from_config(configuration, utils.parameters.ATLAS) - cloud_model_path = get_from_config(configuration, utils.parameters.PRETRAINED_MODEL_PATH)[:-len("model.pt")] + model_id = get_from_config(configuration, utils.parameters.MODEL_ID) + cloud_model_path = "models/" + model_id + "/uncertainty/" + cell_type_key + "_mahalanobis_distance.pickle" + uncert_model_path = "./" +atlas + "_mahalanobis_distance.pickle" - uncert_model_path = "/models/" + atlas + "/uncertainty/" + cell_type_key + "_mahalanobis_distance.pickle" fetch_file_from_s3(cloud_model_path, uncert_model_path) with open(uncert_model_path, "rb") as file: gmm = pickle.load(file) else: - gmm = train_mahalanobis(None, adata_ref_latent, cell_type_key, pretrained) + gmm = train_mahalanobis(None, adata_ref_latent, cell_type_key, pretrained=pretrained) centroids = gmm.means_ - cluster_membership = gmm.predict_proba(adata_query_latent[embedding_name]) + cluster_membership = gmm.predict_proba(adata_query_latent.X) uncertainties = pd.DataFrame(columns=["uncertainty"], index=adata_query_latent.obs_names) for query_cell_index, query_cell in enumerate(adata_query_latent.X): @@ -146,15 +146,16 @@ def classification_uncert_euclidean( #Load model if pretrained: atlas = get_from_config(configuration, utils.parameters.ATLAS) - cloud_model_path = get_from_config(configuration, utils.parameters.PRETRAINED_MODEL_PATH)[:-len("model.pt")] - uncert_model_path = "/models/" + atlas + "/euclidian_distance.pickle" + model_id = get_from_config(configuration, utils.parameters.MODEL_ID) + cloud_model_path = "models/" + model_id + "/uncertainty/euclidian_distance.pickle" + uncert_model_path = "./" +atlas + 
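The per-cell loop body in classification_uncert_mahalanobis is unchanged by this patch and therefore elided from the diff. One plausible shape for it, assuming the uncertainty is a responsibility-weighted distance to the GMM centroids scaled into [0, 1]:

import numpy as np

def mahalanobis_uncertainty(latent, gmm):
    # Responsibilities of each query cell under the reference mixture
    memberships = gmm.predict_proba(latent)                      # (cells, components)
    # Euclidean distance of every cell to every component mean
    dists = np.linalg.norm(latent[:, None, :] - gmm.means_[None, :, :], axis=2)
    weighted = (memberships * dists).sum(axis=1)
    # Scale to [0, 1] so thresholds are comparable across atlases
    return weighted / max(weighted.max(), 1e-12)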
"_euclidian_distance.pickle" fetch_file_from_s3(cloud_model_path, uncert_model_path) with open(uncert_model_path, "rb") as file: trainer = pickle.load(file) else: - trainer = train_euclidian(None, adata_ref_latent, embedding_name, pretrained) + trainer = train_euclidian(None, adata_ref_latent, embedding_name, pretrained=pretrained) #Make sure cell_type is all strings (comparison with NaN and string doesnt work) for cell_type_key in cell_type_key_list: @@ -375,13 +376,10 @@ def train_euclidian(atlas, adata_ref_latent, embedding_name, pretrained, n_neigh else: return trainer -def train_mahalanobis(atlas, adata_ref_latent, embedding_name, cell_type_key, pretrained): +def train_mahalanobis(atlas, adata_ref_latent, cell_type_key, pretrained): num_clusters = adata_ref_latent.obs[cell_type_key].nunique() - if embedding_name == "X": - train_emb = adata_ref_latent.X - elif embedding_name in adata_ref_latent.obsm.keys(): - train_emb = adata_ref_latent.obsm[embedding_name] + train_emb = adata_ref_latent.X #Required too much RAM # gmm = GaussianMixture(n_components=num_clusters) diff --git a/mapping/scarches_api/utils/metrics.py b/mapping/scarches_api/utils/metrics.py index 3fb9d5c..ecef963 100644 --- a/mapping/scarches_api/utils/metrics.py +++ b/mapping/scarches_api/utils/metrics.py @@ -338,15 +338,15 @@ def entropy_of_labels(indices): stat = np.clip(stat, 0.00, 5.00) return stat -def percentage_unknown(query, prediction_label, uncertainty_threshold=0.5): +def percentage_unknown(query, cell_type_key, prediction_label, uncertainty_threshold=0.5): query.obs[f"{prediction_label}_filtered_by_uncert>0.5"] = query.obs[ prediction_label ].mask( - query.obs["uncertainty_mahalanobis"] > uncertainty_threshold, + query.obs[f"{cell_type_key}_uncertainty_euclidean"] > uncertainty_threshold, "Unknown", ) - number_unknown = (query.obs["uncertainty_mahalanobis"] > uncertainty_threshold).sum() + number_unknown = (query.obs[f"{cell_type_key}_uncertainty_euclidean"] > uncertainty_threshold).sum() return number_unknown/len(query)*100 diff --git a/mapping/scarches_api/utils/parameters.py b/mapping/scarches_api/utils/parameters.py index f93e593..b7ee254 100644 --- a/mapping/scarches_api/utils/parameters.py +++ b/mapping/scarches_api/utils/parameters.py @@ -106,3 +106,4 @@ # USE_GPU = 'use_gpu' # sets the used atlas ATLAS = 'atlas' +MODEL_ID = "model_id" From 79e6027914b518f3d4ebf0270245392ca7460e2e Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Tue, 25 Jun 2024 17:17:07 +0200 Subject: [PATCH 48/63] print model id --- mapping/scarches_api/models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index 8195f95..bebaf8b 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -39,6 +39,7 @@ def __init__(self, configuration) -> None: self._atlas = get_from_config(configuration=configuration, key=parameters.ATLAS) self._model_type = get_from_config(configuration=configuration, key=parameters.MODEL) + self._model_id = get_from_config(configuration=configuration, key=parameters.MODEL_ID) self._model_path = get_from_config(configuration=configuration, key=parameters.PRETRAINED_MODEL_PATH) self._scpoli_attr = get_from_config(configuration=configuration, key=parameters.SCPOLI_ATTR) self._scpoli_model_params = get_from_config(configuration=configuration, key=parameters.SCPOLI_MODEL_PARAMS) @@ -49,6 +50,7 @@ def __init__(self, configuration) -> None: self._webhook_metrics = utils.get_from_config(configuration, 
parameters.WEBHOOK_METRICS) # self._use_gpu = get_from_config(configuration=configuration, key=parameters.USE_GPU) + print(f"model_id: {self._model_id}") #Has to be empty for the load_query_data function to work properly (looking for "model.pt") self._temp_model_path = "" self._model = None From afdd270a70d9edcdb42389ee77085cdceee3a1f6 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Tue, 25 Jun 2024 18:50:06 +0200 Subject: [PATCH 49/63] fix bug --- mapping/scarches_api/models.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index bebaf8b..88d81c2 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -68,6 +68,8 @@ def __init__(self, configuration) -> None: self._cell_type_key = None self._batch_key = None self._unlabeled_key = None + self.cell_type_key_input = "cell_type" + self.batch_key_input = "batch" # self._cell_type_key, self._batch_key, self._unlabeled_key = Preprocess.get_keys(self._atlas, self._query_adata) self._cell_type_key, self._cell_type_key_list, self._batch_key, self._unlabeled_key = Preprocess.get_keys(self._atlas, self._query_adata, configuration) @@ -75,6 +77,12 @@ def __init__(self, configuration) -> None: if self._cell_type_key_list is None: self._cell_type_key_list = [self._cell_type_key] + # self._query_adata.obs[self._batch_key] = self._query_adata.obs[self.batch_key_input] + + # if self._cell_type_key_input in self._query_adata.obs.columns: + # self._query_adata.obs[self._cell_type_key] = self._query_adata.obs[self.cell_type_key_input] + # self._query_adata.obs[self._cell_type_key] = [self._unlabeled_key]*len(self._query_adata) + self._clf_native = get_from_config(configuration=self._configuration, key=parameters.CLASSIFIER_TYPE).pop("Native") self._clf_xgb = get_from_config(configuration=self._configuration, key=parameters.CLASSIFIER_TYPE).pop("XGBoost") @@ -100,7 +108,7 @@ def _map_query(self): # threshold = 10000 if self._atlas == "fetal_brain": lr=0.1 - # self._max_epochs = 40 + self._max_epochs = 40 else: lr=0.001 @@ -149,7 +157,7 @@ def _acquire_data(self): ratio = inter_len / len(ref_vars) print(ratio) - utils.notify_backend(self._webhook, {"ratio":ratio}) + # utils.notify_backend(self._webhook, {"ratio":ratio}) # save only necessary data for mapping to new adata @@ -385,7 +393,7 @@ def _save_data(self): self.query_with_anchor=percent_query_with_anchor(adjs["r2q"], adjs["q2r"]) print(f"query_with_anchor: {self.query_with_anchor}") - utils.notify_backend(self._webhook_metrics, {"clust_pres_score":self.clust_pres_score, "query_with_anchor":self.query_with_anchor, "percentage_unknown": self.percent_unknown}) + # utils.notify_backend(self._webhook_metrics, {"clust_pres_score":self.clust_pres_score, "query_with_anchor":self.query_with_anchor, "percentage_unknown": self.percent_unknown}) #Save output Postprocess.output(None, combined_downsample, self._configuration) @@ -394,7 +402,7 @@ def add_X_from_cloud(self): if True or get_from_config(self._configuration, parameters.WEBHOOK) is not None and len( get_from_config(self._configuration, parameters.WEBHOOK)) > 0: - utils.notify_backend(get_from_config(self._configuration, parameters.WEBHOOK), self._configuration) + # utils.notify_backend(get_from_config(self._configuration, parameters.WEBHOOK), self._configuration) if not self._reference_adata_path.endswith("data.h5ad"): raise ValueError("The reference data should be named data.h5ad") else: @@ -555,7 +563,7 @@ def 
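For reference, the reworked percentage_unknown above reduces to: mask any prediction whose euclidean uncertainty exceeds the threshold as "Unknown", then report the masked fraction. An equivalent standalone form:

def percentage_unknown(query, cell_type_key, prediction_label, uncertainty_threshold=0.5):
    uncert = query.obs[f"{cell_type_key}_uncertainty_euclidean"]
    too_uncertain = uncert > uncertainty_threshold
    # Keep the prediction but flag low-confidence cells as "Unknown"
    query.obs[f"{prediction_label}_filtered_by_uncert>0.5"] = (
        query.obs[prediction_label].mask(too_uncertain, "Unknown")
    )
    return too_uncertain.sum() / len(query) * 100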
_compute_latent_representation(self, explicit_representation): class ScANVI(ArchmapBaseModel): def _map_query(self, supervised=False): - #Align genes and gene order to model + #Align genes and gene order to model if self._cell_type_key in self._query_adata.obs.columns: self._query_adata.obs[f"{self._cell_type_key}_user_input"] = self._query_adata.obs[self._cell_type_key] self._query_adata.obs[self._cell_type_key] = [self._unlabeled_key]*len(self._query_adata) From f3ab38dfdc259ea08493eccc352feb4ae9dd2ae3 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Tue, 25 Jun 2024 19:38:29 +0200 Subject: [PATCH 50/63] fix bug --- mapping/scarches_api/uncert/uncert_metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mapping/scarches_api/uncert/uncert_metric.py b/mapping/scarches_api/uncert/uncert_metric.py index d371504..a04b1c0 100644 --- a/mapping/scarches_api/uncert/uncert_metric.py +++ b/mapping/scarches_api/uncert/uncert_metric.py @@ -163,7 +163,7 @@ def classification_uncert_euclidean( for col in cell_type_cols: adata_ref_latent.obs[col] = adata_ref_latent.obs[col].astype(str) - _, uncertainties = sca.utils.weighted_knn_transfer( + _, uncertainties, _ = sca.utils.weighted_knn_transfer( adata_query_latent, embedding_name, adata_ref_latent.obs, From a223478b6d0d94d07ac32b7dda6ba91b5d731c2a Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Tue, 25 Jun 2024 19:41:11 +0200 Subject: [PATCH 51/63] uncomment --- mapping/scarches_api/models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index 88d81c2..64514d5 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -157,7 +157,7 @@ def _acquire_data(self): ratio = inter_len / len(ref_vars) print(ratio) - # utils.notify_backend(self._webhook, {"ratio":ratio}) + utils.notify_backend(self._webhook, {"ratio":ratio}) # save only necessary data for mapping to new adata @@ -393,7 +393,7 @@ def _save_data(self): self.query_with_anchor=percent_query_with_anchor(adjs["r2q"], adjs["q2r"]) print(f"query_with_anchor: {self.query_with_anchor}") - # utils.notify_backend(self._webhook_metrics, {"clust_pres_score":self.clust_pres_score, "query_with_anchor":self.query_with_anchor, "percentage_unknown": self.percent_unknown}) + utils.notify_backend(self._webhook_metrics, {"clust_pres_score":self.clust_pres_score, "query_with_anchor":self.query_with_anchor, "percentage_unknown": self.percent_unknown}) #Save output Postprocess.output(None, combined_downsample, self._configuration) @@ -402,7 +402,7 @@ def add_X_from_cloud(self): if True or get_from_config(self._configuration, parameters.WEBHOOK) is not None and len( get_from_config(self._configuration, parameters.WEBHOOK)) > 0: - # utils.notify_backend(get_from_config(self._configuration, parameters.WEBHOOK), self._configuration) + utils.notify_backend(get_from_config(self._configuration, parameters.WEBHOOK), self._configuration) if not self._reference_adata_path.endswith("data.h5ad"): raise ValueError("The reference data should be named data.h5ad") else: From 0effefc92fc878be55f8a513a71098ae98693b81 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Thu, 27 Jun 2024 09:58:42 +0200 Subject: [PATCH 52/63] change batch key --- mapping/process/processing.py | 12 +++++++++--- mapping/scarches_api/models.py | 19 +++++++++++-------- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/mapping/process/processing.py b/mapping/process/processing.py index f598371..49fee06 
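The scANVI mapping path above depends on two invariants before surgery: every query cell carries the model's unlabeled token, and the query genes are aligned to the model. As a standalone sketch, with the load_query_data call assumed from the surrounding code:

import scarches

def prepare_scanvi_query(query_adata, model_dir, cell_type_key, unlabeled_key):
    # Cells tagged with the unlabeled token are targets for label transfer,
    # not supervision
    query_adata.obs[cell_type_key] = [unlabeled_key] * len(query_adata)
    query_adata.var_names_make_unique()
    # Align genes and gene order to the pretrained model
    scarches.models.SCANVI.prepare_query_anndata(query_adata, model_dir)
    return scarches.models.SCANVI.load_query_data(query_adata, model_dir)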
100644 --- a/mapping/process/processing.py +++ b/mapping/process/processing.py @@ -412,6 +412,9 @@ def get_keys(atlas, target_adata, configuration): elif atlas == "fetal_brain": cell_type_key = "subregion_class" batch_key = "batch" + elif atlas == "hnoca_extended": + cell_type_key = "cell_type" + batch_key = "batch" model_type = utils.get_from_config(configuration, parameters.MODEL) @@ -440,9 +443,12 @@ def get_keys(atlas, target_adata, configuration): # condition_key_model = attr_dict["condition_keys_"][-1] unlabeled_key_model = "unlabeled" - #Check if provided query contains respective labels - if cell_type_key not in target_adata.obs.columns or batch_key not in target_adata.obs.columns: - raise ValueError("Please double check if cell_type and batch keys in query match the requirements stated on the website") + batch_key_input = "batch" + + # Check if provided query contains batch labels + if batch_key_input not in target_adata.obs.columns: + raise ValueError("Batch key information not specified. Please make sure your batch key is labelled 'batch' in your query data.") + return cell_type_key, cell_type_key_list, batch_key, unlabeled_key_model diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index 64514d5..ce7ccd5 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -77,11 +77,8 @@ def __init__(self, configuration) -> None: if self._cell_type_key_list is None: self._cell_type_key_list = [self._cell_type_key] - # self._query_adata.obs[self._batch_key] = self._query_adata.obs[self.batch_key_input] - - # if self._cell_type_key_input in self._query_adata.obs.columns: - # self._query_adata.obs[self._cell_type_key] = self._query_adata.obs[self.cell_type_key_input] - # self._query_adata.obs[self._cell_type_key] = [self._unlabeled_key]*len(self._query_adata) + self._query_adata.obs[self._batch_key] = self._query_adata.obs[self.batch_key_input].copy() + self._query_adata.obs[self._cell_type_key] = [self._unlabeled_key]*len(self._query_adata) self._clf_native = get_from_config(configuration=self._configuration, key=parameters.CLASSIFIER_TYPE).pop("Native") @@ -356,6 +353,10 @@ def _save_data(self): print("adding X from cloud") self.add_X_from_cloud() + del self._combined_adata.obs[self.batch_key_input] + + self._combined_adata.obs = self._combined_adata.obs.rename(columns={self._batch_key : self.batch_key_input}) + combined_downsample = self.downsample_adata() # Calculate presence score @@ -394,10 +395,12 @@ def _save_data(self): print(f"query_with_anchor: {self.query_with_anchor}") utils.notify_backend(self._webhook_metrics, {"clust_pres_score":self.clust_pres_score, "query_with_anchor":self.query_with_anchor, "percentage_unknown": self.percent_unknown}) - + #Save output Postprocess.output(None, combined_downsample, self._configuration) + + def add_X_from_cloud(self): if True or get_from_config(self._configuration, parameters.WEBHOOK) is not None and len( get_from_config(self._configuration, parameters.WEBHOOK)) > 0: @@ -564,8 +567,8 @@ def _compute_latent_representation(self, explicit_representation): class ScANVI(ArchmapBaseModel): def _map_query(self, supervised=False): #Align genes and gene order to model - if self._cell_type_key in self._query_adata.obs.columns: - self._query_adata.obs[f"{self._cell_type_key}_user_input"] = self._query_adata.obs[self._cell_type_key] + # if self._cell_type_key in self._query_adata.obs.columns: + # self._query_adata.obs[f"{self._cell_type_key}_user_input"] = self._query_adata.obs[self._cell_type_key] 
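Ordering is the trap in the batch-key hunk above: deleting batch_key_input before copying from it raises a KeyError, which is exactly what PATCH 59/63 later corrects. The safe rename, in isolation:

def rename_batch_column(query_adata, batch_key_input, batch_key):
    # Copy first, then drop the user-facing column; reversed order raises KeyError
    if batch_key_input != batch_key:
        query_adata.obs[batch_key] = query_adata.obs[batch_key_input].copy()
        del query_adata.obs[batch_key_input]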
self._query_adata.obs[self._cell_type_key] = [self._unlabeled_key]*len(self._query_adata) self._query_adata.var_names_make_unique() From 3c71ab6f05a061dfc47070e2a99507b644dad905 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Thu, 27 Jun 2024 11:57:42 +0200 Subject: [PATCH 53/63] fix batch column --- mapping/scarches_api/models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index ce7ccd5..4beb3e6 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -80,6 +80,8 @@ def __init__(self, configuration) -> None: self._query_adata.obs[self._batch_key] = self._query_adata.obs[self.batch_key_input].copy() self._query_adata.obs[self._cell_type_key] = [self._unlabeled_key]*len(self._query_adata) + del self._query_adata.obs[self.batch_key_input] + self._clf_native = get_from_config(configuration=self._configuration, key=parameters.CLASSIFIER_TYPE).pop("Native") self._clf_xgb = get_from_config(configuration=self._configuration, key=parameters.CLASSIFIER_TYPE).pop("XGBoost") @@ -350,13 +352,11 @@ def _compute_latent_representation(self, explicit_representation): def _save_data(self): # add .X to self._combined_adata + self._combined_adata.obs = self._combined_adata.obs.rename(columns={self._batch_key : self.batch_key_input}) + print("adding X from cloud") self.add_X_from_cloud() - del self._combined_adata.obs[self.batch_key_input] - - self._combined_adata.obs = self._combined_adata.obs.rename(columns={self._batch_key : self.batch_key_input}) - combined_downsample = self.downsample_adata() # Calculate presence score From d51958517439823d0f5dbebff354c76eba4935d5 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Thu, 27 Jun 2024 14:11:38 +0200 Subject: [PATCH 54/63] fix batch key issue --- mapping/scarches_api/models.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index 4beb3e6..e5555bf 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -77,10 +77,12 @@ def __init__(self, configuration) -> None: if self._cell_type_key_list is None: self._cell_type_key_list = [self._cell_type_key] - self._query_adata.obs[self._batch_key] = self._query_adata.obs[self.batch_key_input].copy() + self._query_adata.obs[self._cell_type_key] = [self._unlabeled_key]*len(self._query_adata) - del self._query_adata.obs[self.batch_key_input] + if self.batch_key_input != self._batch_key: + del self._query_adata.obs[self.batch_key_input] + self._query_adata.obs[self._batch_key] = self._query_adata.obs[self.batch_key_input].copy() self._clf_native = get_from_config(configuration=self._configuration, key=parameters.CLASSIFIER_TYPE).pop("Native") From ce1bece13498a392f965b931ee9db8bd6c5f4bf7 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Thu, 27 Jun 2024 15:22:28 +0200 Subject: [PATCH 55/63] add hnoca extended name --- mapping/scarches_api/utils/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mapping/scarches_api/utils/utils.py b/mapping/scarches_api/utils/utils.py index 24c43a6..1abe024 100644 --- a/mapping/scarches_api/utils/utils.py +++ b/mapping/scarches_api/utils/utils.py @@ -593,6 +593,8 @@ def translate_atlas_to_directory(configuration): return "heoca" elif atlas == "Fetal Brain": return "fetal_brain" + elif atlas == "HNOCA Extended": + return "hnoca_extended" def set_keys(configuration): From 4f783d90237964a5586faf7343ddfc866878ff64 Mon Sep 17 00:00:00 2001 From: 
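translate_atlas_to_directory above grows by one elif per supported atlas. A table-driven equivalent keeps each addition to a single line; only the entries visible in this diff are shown:

ATLAS_DIRECTORIES = {
    "HEOCA": "heoca",
    "Fetal Brain": "fetal_brain",
    "HNOCA Extended": "hnoca_extended",
}

def translate_atlas_to_directory(atlas_name):
    # KeyError on an unknown atlas is noisier, and safer, than returning None
    return ATLAS_DIRECTORIES[atlas_name]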
ChelseaBright96 Date: Fri, 28 Jun 2024 10:19:20 +0200 Subject: [PATCH 56/63] change cell type key hnoca extended --- mapping/process/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mapping/process/processing.py b/mapping/process/processing.py index 49fee06..69ed6ca 100644 --- a/mapping/process/processing.py +++ b/mapping/process/processing.py @@ -413,7 +413,7 @@ def get_keys(atlas, target_adata, configuration): cell_type_key = "subregion_class" batch_key = "batch" elif atlas == "hnoca_extended": - cell_type_key = "cell_type" + cell_type_key = "annot_level_2_extended" batch_key = "batch" From 80206d0491584720d737b439de248f0240d0afe6 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Mon, 1 Jul 2024 14:34:58 +0200 Subject: [PATCH 57/63] fix cell type key bug --- mapping/process/processing.py | 12 +++++++----- mapping/scarches_api/models.py | 27 ++++++++++++++++----------- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/mapping/process/processing.py b/mapping/process/processing.py index 69ed6ca..93302a8 100644 --- a/mapping/process/processing.py +++ b/mapping/process/processing.py @@ -363,6 +363,7 @@ def get_keys(atlas, target_adata, configuration): condition_key_model = None unlabeled_key_model = None cell_type_key_list = None + cell_type_key_classifier = None if atlas == 'pbmc': cell_type_key = 'cell_type_for_integration' @@ -398,8 +399,8 @@ def get_keys(atlas, target_adata, configuration): cell_type_key = "cell_type_scarches" batch_key = "batch_donor_asset" elif atlas == "hnoca": - cell_type_key ="annot_level_2" #"snapseed_pca_rss_level_123" - cell_type_key_list = ['annot_level_1', + cell_type_key =["snapseed_pca_rss_level_1","snapseed_pca_rss_level_12","snapseed_pca_rss_level_123"] + cell_type_key_classifier = ['annot_level_1', 'annot_level_2', 'annot_level_3_rev2', 'annot_level_4_rev2', @@ -413,10 +414,11 @@ def get_keys(atlas, target_adata, configuration): cell_type_key = "subregion_class" batch_key = "batch" elif atlas == "hnoca_extended": - cell_type_key = "annot_level_2_extended" + cell_type_key =["snapseed_pca_rss_level_1","snapseed_pca_rss_level_12","snapseed_pca_rss_level_123"] + cell_type_key_classifier = "annot_level_2_extended" batch_key = "batch" - + model_type = utils.get_from_config(configuration, parameters.MODEL) if model_type in ["scANVI","scVI"]: @@ -450,7 +452,7 @@ def get_keys(atlas, target_adata, configuration): raise ValueError("Batch key information not specified. 
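From this patch on, get_keys can return either a single key or a list for the classifier labels, and the callers branch on isinstance in several places. A small normalizer (the name as_key_list is illustrative) would centralize that:

def as_key_list(keys):
    # Accept None, a single key, or a list of keys; always return a list
    if keys is None:
        return []
    if isinstance(keys, str):
        return [keys]
    return list(keys)

For example, cell_type_key_list = as_key_list(cell_type_key_classifier) or as_key_list(cell_type_key) mirrors the fallback logic added to models.py below.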
Please make sure your batch key is labelled 'batch' in your query data.") - return cell_type_key, cell_type_key_list, batch_key, unlabeled_key_model + return cell_type_key, cell_type_key_classifier, cell_type_key_list, batch_key, unlabeled_key_model def __get_keys_user(configuration): #Get parameters from user input diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index de64e83..9e75d4f 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -68,17 +68,26 @@ def __init__(self, configuration) -> None: self._cell_type_key = None self._batch_key = None self._unlabeled_key = None - self.cell_type_key_input = "cell_type" + self.cell_type_key_input = "user_cell_type" self.batch_key_input = "batch" # self._cell_type_key, self._batch_key, self._unlabeled_key = Preprocess.get_keys(self._atlas, self._query_adata) - self._cell_type_key, self._cell_type_key_list, self._batch_key, self._unlabeled_key = Preprocess.get_keys(self._atlas, self._query_adata, configuration) + self._cell_type_key, self._cell_type_key_classifier, self._cell_type_key_list, self._batch_key, self._unlabeled_key = Preprocess.get_keys(self._atlas, self._query_adata, configuration) - if self._cell_type_key_list is None: - self._cell_type_key_list = [self._cell_type_key] + if isinstance(self._cell_type_key,list): + for key in self._cell_type_key: + self._query_adata.obs[key] = [self._unlabeled_key]*len(self._query_adata) + else: + self._query_adata.obs[self._cell_type_key] = [self._unlabeled_key]*len(self._query_adata) - - self._query_adata.obs[self._cell_type_key] = [self._unlabeled_key]*len(self._query_adata) + if self._cell_type_key_classifier is None: + self._cell_type_key_classifier = self._cell_type_key + + if self._cell_type_key_list is None: + if isinstance(self._cell_type_key_classifier,list): + self._cell_type_key_list = self._cell_type_key_classifier + else: + self._cell_type_key_list = [self._cell_type_key_classifier] if self.batch_key_input != self._batch_key: del self._query_adata.obs[self.batch_key_input] @@ -569,11 +578,7 @@ def _compute_latent_representation(self, explicit_representation): class ScANVI(ArchmapBaseModel): def _map_query(self, supervised=False): - #Align genes and gene order to model - # if self._cell_type_key in self._query_adata.obs.columns: - # self._query_adata.obs[f"{self._cell_type_key}_user_input"] = self._query_adata.obs[self._cell_type_key] - self._query_adata.obs[self._cell_type_key] = [self._unlabeled_key]*len(self._query_adata) - + self._query_adata.var_names_make_unique() scarches.models.SCANVI.prepare_query_anndata(self._query_adata, self._temp_model_path) From 205e6e0efdb66381468a04d497b35f3b41a4363d Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Mon, 1 Jul 2024 15:08:53 +0200 Subject: [PATCH 58/63] fix bug --- mapping/scarches_api/models.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index 9e75d4f..bb5a0fb 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -488,6 +488,11 @@ def downsample_adata(self, query_ratio=5): ref_adata = self._combined_adata[self._combined_adata.obs["query"] == "0"] query_adata_index = np.where(self._combined_adata.obs["query"] == "1")[0] + if isinstance(self._cell_type_key,list): + celltype_key = self._cell_type_key[0] + else: + celltype_key = self._cell_type_key + # Check if 10% of reference is less than query size times the ratio if len(ref_adata) * 0.1 < 
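The new downsampling branch above keeps the reference-to-query ratio near query_ratio while preserving the reference's cell type composition. Its core arithmetic, extracted as a sketch:

import numpy as np

def proportional_sample_indices(ref_labels, total_to_sample):
    # ref_labels: pandas Series of cell types over the reference cells
    indices = []
    for celltype, count in ref_labels.value_counts().items():
        pool = np.where(ref_labels == celltype)[0]
        # Each type contributes according to its share of the reference
        size = min(int(total_to_sample * count / len(ref_labels)), len(pool))
        indices.extend(np.random.choice(pool, size=size, replace=False))
    return np.asarray(indices)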
len(query_adata_index) * query_ratio: # New approach: Proportional sampling based on cell type proportions @@ -496,16 +501,17 @@ def downsample_adata(self, query_ratio=5): # Get unique cell types # celltypes = np.unique(self._combined_adata.obs[self._cell_type_key]) - celltypes = self._combined_adata.obs[self._cell_type_key].unique() - print(celltypes) + + + celltypes = self._combined_adata.obs[celltype_key].unique() # Calculate the proportion of each cell type in the reference data - celltype_proportions = {celltype: np.sum(ref_adata.obs[self._cell_type_key] == celltype) / len(ref_adata) for celltype in celltypes} + celltype_proportions = {celltype: np.sum(ref_adata.obs[celltype_key] == celltype) / len(ref_adata) for celltype in celltypes} # Sample cells from each cell type according to its proportion sampled_cell_index = [] for celltype, proportion in celltype_proportions.items(): - cell_indices = np.where(ref_adata.obs[self._cell_type_key] == celltype)[0] + cell_indices = np.where(ref_adata.obs[celltype_key] == celltype)[0] sample_size = int(total_ref_cells_to_sample * proportion) # Adjust sample size if it exceeds the number of available cells @@ -516,11 +522,11 @@ def downsample_adata(self, query_ratio=5): sampled_cell_index.extend(sampled_cells) else: # Old approach: Sample 10% from each cell type in the reference data - celltypes = self._combined_adata.obs[self._cell_type_key].unique() + celltypes = self._combined_adata.obs[celltype_key].unique() # celltypes = np.unique(self._combined_adata.obs[self._cell_type_key]) percentage = 0.02 if ref_adata.n_obs> 3000000 else 0.1 # max 1 - sampled_cell_index = np.concatenate([np.random.choice(np.where(ref_adata.obs[self._cell_type_key] == celltype)[0], size=int(len(np.where(ref_adata.obs[self._cell_type_key] == celltype)[0]) * percentage), replace=False) for celltype in celltypes]) + sampled_cell_index = np.concatenate([np.random.choice(np.where(ref_adata.obs[celltype_key] == celltype)[0], size=int(len(np.where(ref_adata.obs[celltype_key] == celltype)[0]) * percentage), replace=False) for celltype in celltypes]) # Combine sampled reference cells with query cells sampled_cell_index = np.concatenate([sampled_cell_index, query_adata_index]) From f42d57273d80e0779966cb14a1797c5d5c0bd701 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Mon, 1 Jul 2024 19:59:02 +0200 Subject: [PATCH 59/63] fix batch error --- mapping/scarches_api/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index bb5a0fb..a83b633 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -90,8 +90,8 @@ def __init__(self, configuration) -> None: self._cell_type_key_list = [self._cell_type_key_classifier] if self.batch_key_input != self._batch_key: - del self._query_adata.obs[self.batch_key_input] self._query_adata.obs[self._batch_key] = self._query_adata.obs[self.batch_key_input].copy() + del self._query_adata.obs[self.batch_key_input] self._clf_native = get_from_config(configuration=self._configuration, key=parameters.CLASSIFIER_TYPE).pop("Native") From d0b469fdbf6a37e365038f2b51a51a2a08606f5e Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Mon, 1 Jul 2024 23:59:23 +0200 Subject: [PATCH 60/63] fix hnoca issue --- mapping/scarches_api/models.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index a83b633..89ae782 100644 --- a/mapping/scarches_api/models.py +++ 
b/mapping/scarches_api/models.py @@ -363,7 +363,8 @@ def _compute_latent_representation(self, explicit_representation): def _save_data(self): # add .X to self._combined_adata - self._combined_adata.obs = self._combined_adata.obs.rename(columns={self._batch_key : self.batch_key_input}) + if self.batch_key_input != self._batch_key: + self._combined_adata.obs = self._combined_adata.obs.rename(columns={self._batch_key : self.batch_key_input}) print("adding X from cloud") self.add_X_from_cloud() @@ -488,10 +489,10 @@ def downsample_adata(self, query_ratio=5): ref_adata = self._combined_adata[self._combined_adata.obs["query"] == "0"] query_adata_index = np.where(self._combined_adata.obs["query"] == "1")[0] - if isinstance(self._cell_type_key,list): - celltype_key = self._cell_type_key[0] + if isinstance(self._cell_type_key_classifier,list): + celltype_key = self._cell_type_key_classifier[0] else: - celltype_key = self._cell_type_key + celltype_key = self._cell_type_key_classifier # Check if 10% of reference is less than query size times the ratio if len(ref_adata) * 0.1 < len(query_adata_index) * query_ratio: From 64ef413e6e6e3cf128d04a7a992e97fd87743ad3 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Thu, 4 Jul 2024 09:10:23 +0200 Subject: [PATCH 61/63] downsample pbmc for download --- mapping/scarches_api/models.py | 23 +++++++++++++++++++---- mapping/scarches_api/utils/utils.py | 18 +++++++++++++++--- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/mapping/scarches_api/models.py b/mapping/scarches_api/models.py index 89ae782..e52a36b 100644 --- a/mapping/scarches_api/models.py +++ b/mapping/scarches_api/models.py @@ -367,9 +367,12 @@ def _save_data(self): self._combined_adata.obs = self._combined_adata.obs.rename(columns={self._batch_key : self.batch_key_input}) print("adding X from cloud") - self.add_X_from_cloud() + count_matrix_size_gb = self.add_X_from_cloud() - combined_downsample = self.downsample_adata() + if count_matrix_size_gb<40: + combined_downsample = self.downsample_adata() + else: + combined_downsample = self._combined_adata.copy() # Calculate presence score @@ -460,16 +463,28 @@ def add_X_from_cloud(self): combined_adata.X = combined_data_X.X sc.write(self.temp_output_combined, combined_adata) - else: + elif count_matrix_size_gb>=10 and count_matrix_size_gb<40: print("Count matrix size larger than 10 gb.") temp_query = tempfile.NamedTemporaryFile(suffix=".h5ad") self.adata_query_X.write_h5ad(temp_query.name) del self.adata_query_X gc.collect() self.temp_output_combined =replace_X_on_disk(combined_adata,self.temp_output_combined, temp_query.name, count_matrix_path) - + combined_adata = sc.read(self.temp_output_combined) + else: + temp_query = tempfile.NamedTemporaryFile(suffix=".h5ad") + self.adata_query_X.write_h5ad(temp_query.name) + del self.adata_query_X + gc.collect() + count_matrix_downsample_path = self._reference_adata_path[:-len("data.h5ad")] + "data_count_downsample.h5ad" + # download downsampled counts from cloud and concat + self.temp_output_combined =replace_X_on_disk(combined_adata,self.temp_output_combined, temp_query.name, count_matrix_downsample_path, use_downsample=True) + combined_adata = sc.read(self.temp_output_combined) + self._combined_adata = combined_adata + return count_matrix_size_gb + def downsample_adata(self, query_ratio=5): """ diff --git a/mapping/scarches_api/utils/utils.py b/mapping/scarches_api/utils/utils.py index 1abe024..01f8ff0 100644 --- a/mapping/scarches_api/utils/utils.py +++ b/mapping/scarches_api/utils/utils.py @@ 
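PATCH 61/63 above introduces a three-tier policy keyed on the count matrix size. Summarized as a dispatch, with thresholds taken from the hunks above and tier names that are descriptive only:

def pick_x_strategy(count_matrix_size_gb):
    if count_matrix_size_gb < 10:
        return "in_memory"            # load counts and assign .X directly
    if count_matrix_size_gb < 40:
        return "on_disk_concat"       # replace_X_on_disk with the full counts
    return "downsampled_counts"       # fetch data_count_downsample.h5ad instead

_save_data then calls downsample_adata only below 40 GB; in the largest tier the fetched counts are already downsampled, so the combined object is kept as-is.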
-671,7 +671,7 @@ def fetch_file_to_temp_path_from_s3(key): return filename -def replace_X_on_disk(combined_adata,temp_output, query_X_file, ref_count_matrix_path): +def replace_X_on_disk(combined_adata,temp_output, query_X_file, ref_count_matrix_path, use_downsample=False): """ Writes combined_adata to disk, fetches another .h5ad file specified by ref_count_matrix_path. Concatenates the .X of the fetched file with query_X_file. @@ -685,11 +685,23 @@ def replace_X_on_disk(combined_adata,temp_output, query_X_file, ref_count_matrix Returns: File path to saved adata with concatenated metadata and .X """ + + temp_ref_count_matrix_path = fetch_file_to_temp_path_from_s3(ref_count_matrix_path) + # Fetch the new file and get its path + if use_downsample: + # get obs index + with h5py.File(temp_ref_count_matrix_path, "r") as f: + obs_df=read_elem(f["obs"]) + obs_index = obs_df.index + + # add query index + query_index = combined_adata.obs[combined_adata.obs["query"] == "1"].index + obs_index_combined =obs_index.append(query_index) + combined_adata = combined_adata[obs_index_combined] + # Write combined_adata to disk combined_adata.write(temp_output) print(f"combined_adata written to {temp_output}") - # Fetch the new file and get its path - temp_ref_count_matrix_path = fetch_file_to_temp_path_from_s3(ref_count_matrix_path) if temp_ref_count_matrix_path is None: print("No file fetched. Exiting.") return From f870cf988ce1993ccdae824fb9c3ee53a43e8212 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Mon, 15 Jul 2024 09:47:02 +0200 Subject: [PATCH 62/63] add finer cell types hlca --- mapping/process/processing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mapping/process/processing.py b/mapping/process/processing.py index 93302a8..ae33d63 100644 --- a/mapping/process/processing.py +++ b/mapping/process/processing.py @@ -373,6 +373,7 @@ def get_keys(atlas, target_adata, configuration): batch_key = 'donor' elif atlas == 'hlca': cell_type_key = 'scanvi_label' + cell_type_key_classifier = ["ann_level_3","ann_level_4"] batch_key = 'dataset' elif atlas == 'hlca_retrained': cell_type_key = 'ann_finest_level' From ee3cea2bf0615f0f8d70ee28ab48427ea2475e16 Mon Sep 17 00:00:00 2001 From: ChelseaBright96 Date: Mon, 15 Jul 2024 22:28:50 +0200 Subject: [PATCH 63/63] add ann_level_5 --- mapping/process/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mapping/process/processing.py b/mapping/process/processing.py index ae33d63..91873e5 100644 --- a/mapping/process/processing.py +++ b/mapping/process/processing.py @@ -373,7 +373,7 @@ def get_keys(atlas, target_adata, configuration): batch_key = 'donor' elif atlas == 'hlca': cell_type_key = 'scanvi_label' - cell_type_key_classifier = ["ann_level_3","ann_level_4"] + cell_type_key_classifier = ["ann_level_3","ann_level_4","ann_level_5"] batch_key = 'dataset' elif atlas == 'hlca_retrained': cell_type_key = 'ann_finest_level'
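The use_downsample branch above needs only the obs index of the cloud file, so it reads just that group instead of loading .X. The same trick in isolation, assuming read_elem is importable from anndata's experimental module as in the rest of this repo:

import h5py
from anndata.experimental import read_elem

def obs_index_from_h5ad(path):
    # Read only the obs DataFrame from an .h5ad; the count matrix stays on disk
    with h5py.File(path, "r") as f:
        return read_elem(f["obs"]).index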