diff --git a/setup.py b/setup.py index b3c5f5ec..f70ff940 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,8 @@ def get_version(rel_path: str) -> str: long_description_content_type="text/markdown", url='https://github.com/STOmics/Stereopy', author='STOmics', - author_email='xujunhao@genomics.cn', + author_email='tanliwei@stomics.tech', + license='MIT License', python_requires='>=3.8,<3.9', install_requires=[ l.strip() for l in Path('requirements.txt').read_text('utf-8').splitlines() diff --git a/stereo/algorithm/cell_cell_communication/main.py b/stereo/algorithm/cell_cell_communication/main.py index 183c89ce..5ddfab7e 100644 --- a/stereo/algorithm/cell_cell_communication/main.py +++ b/stereo/algorithm/cell_cell_communication/main.py @@ -328,7 +328,10 @@ def _prepare_data(self, cluster_res_key): cluster.rename({'group': 'cell_type'}, axis=1, inplace=True) cluster.set_index('bins', drop=True, inplace=True) cluster.index.name = 'cell' - data = pd.DataFrame(self.stereo_exp_data.exp_matrix.T.toarray()) + if self.stereo_exp_data.issparse(): + data = pd.DataFrame(self.stereo_exp_data.exp_matrix.T.toarray()) + else: + data = pd.DataFrame(self.stereo_exp_data.exp_matrix.T) data.columns = self.stereo_exp_data.cell_names.astype(str) data.index = self.stereo_exp_data.gene_names return data, cluster diff --git a/stereo/algorithm/co_occurrence.py b/stereo/algorithm/co_occurrence.py index e92d311d..0752e6dd 100644 --- a/stereo/algorithm/co_occurrence.py +++ b/stereo/algorithm/co_occurrence.py @@ -11,6 +11,7 @@ from stereo.core.stereo_exp_data import AnnBasedStereoExpData # module in self project from stereo.core.stereo_exp_data import StereoExpData +from stereo.core.ms_data import MSData # ----------------------------------------------# @@ -286,7 +287,7 @@ def co_occurrence( return ret @staticmethod - def ms_co_occur_integrate(ms_data, scope, use_col, res_key='co_occurrence'): + def ms_co_occur_integrate(ms_data: MSData, scope, use_col, res_key='co_occurrence'): from collections import Counter if use_col not in ms_data.obs: tmp_list = [] @@ -296,11 +297,13 @@ def ms_co_occur_integrate(ms_data, scope, use_col, res_key='co_occurrence'): ms_data.obs[use_col] = ms_data.obs[use_col].astype('category') slice_groups = scope.split('|') + slice_index = [] if len(slice_groups) == 1: slices = slice_groups[0].split(",") ct_count = {} for x in slices: ct_count[x] = dict(Counter(ms_data[x].cells[use_col])) + slice_index.append(ms_data.names.index(x)) ct_count = pd.DataFrame(ct_count) ct_ratio = ct_count.div(ct_count.sum(axis=1), axis=0) @@ -318,6 +321,7 @@ def ms_co_occur_integrate(ms_data, scope, use_col, res_key='co_occurrence'): ct_count = {} for x in slices: ct_count[x] = dict(Counter(ms_data[x].cells[use_col])) + slice_index.append(ms_data.names.index(x)) ct_count = pd.DataFrame(ct_count) ct_ratio = ct_count.div(ct_count.sum(axis=1), axis=0) @@ -333,7 +337,10 @@ def ms_co_occur_integrate(ms_data, scope, use_col, res_key='co_occurrence'): merge_co_occur_ret = {ct: ret[0][ct] - ret[1][ct] for ct in merge_co_occur_ret} else: - print('co-occurrence only compare case and control on two groups') - merge_co_occur_ret = None + raise Exception('co-occurrence only compare case and control on two groups') + # merge_co_occur_ret = None - return merge_co_occur_ret + # return merge_co_occur_ret + slice_index = np.unique(slice_index) + scope_key = "scope_[" + ",".join([str(i) for i in slice_index]) + "]" + ms_data.tl.result[scope_key][res_key] = merge_co_occur_ret diff --git a/stereo/algorithm/regulatory_network_inference/main.py b/stereo/algorithm/regulatory_network_inference/main.py index 35d951f4..af115413 100644 --- a/stereo/algorithm/regulatory_network_inference/main.py +++ b/stereo/algorithm/regulatory_network_inference/main.py @@ -58,10 +58,14 @@ def main(self, seed: int = None, cache: bool = False, cache_res_key: str = 'regulatory_network_inference', - save: bool = True, + save_regulons: bool = True, + save_loom: bool = False, + fn_prefix: str = None, method: str = 'grnboost', ThreeD_slice: bool = False, - prune_kwargs: dict = {} + prune_kwargs: dict = {}, + hotspot_kwargs: dict = {}, + use_raw: bool = True ): """ Enables researchers to infer transcription factors (TFs) and gene regulatory networks. @@ -76,14 +80,27 @@ def main(self, :param res_key: the key for storage of inference regulatory network result. :param seed: optional random seed for the regressors. Default None. :param cache: whether to use cache files. Need to provide adj.csv, motifs.csv and auc.csv. - :param save: whether to save the result as a file. + :param save_regulons: whether to save regulons into a csv file. + :param save_loom: whether to save the result as a loom file. + :param fn_prefix: the prefix of file name for saving regulons or loom. :param method: the method to inference GRN, 'grnboost' or 'hotspot'. :param ThreeD_slice: whether to use 3D slice data. - :param prune_kwargs: dict, others parameters of pyscenic.prune.prune2df + :param prune_kwargs: dict, other parameters of pyscenic.prune.prune2df. + :param hotspot_kwargs: dict, other parameters for 'hotspot' method. :return: Computation result of inference regulatory network is stored in self.result where the result key is 'regulatory_network_inference'. """ # noqa - matrix = self.stereo_exp_data.to_df() - df = self.stereo_exp_data.to_df() + self.use_raw = use_raw + if use_raw and self.stereo_exp_data.raw is None: + raise Exception("The raw data is not found, you need to run 'raw_checkpoint()' first.") + + if use_raw: + logger.info('the raw expression matrix will be used.') + matrix = self.stereo_exp_data.raw.to_df() + else: + logger.info('if you have done some normalized processing, the normalized expression matrix will be used.') + matrix = self.stereo_exp_data.to_df() + # df = self.stereo_exp_data.to_df() + df = matrix.copy(deep=True) if num_workers is None: num_workers = cpu_count() @@ -114,8 +131,13 @@ def main(self, adjacencies = self.grn_inference(matrix, genes=target_genes, tf_names=tfsf, num_workers=num_workers, seed=seed, cache=cache, cache_res_key=cache_res_key) elif method == 'hotspot': + hotspot_kwargs_adjusted = {} + for key, value in hotspot_kwargs.items(): + if key in ('tf_list', 'jobs', 'cache', 'cache_res_key', 'ThreeD_slice'): + continue + hotspot_kwargs_adjusted[key] = value adjacencies = self.hotspot_matrix(tf_list=tfsf, jobs=num_workers, cache=cache, cache_res_key=cache_res_key, - ThreeD_slice=ThreeD_slice) + ThreeD_slice=ThreeD_slice, **hotspot_kwargs_adjusted) modules = self.get_modules(adjacencies, df) # 4. Regulons prediction aka cisTarget @@ -131,26 +153,27 @@ def main(self, 'regulons': self.regulon_dict, 'auc_matrix': auc_matrix, 'adjacencies': adjacencies, - 'motifs': motifs + # 'motifs': motifs } self.stereo_exp_data.tl.reset_key_record('regulatory_network_inference', res_key) - if save: - self.regulons_to_csv(regulons) + if save_regulons: + self.regulons_to_csv(regulons, fn_prefix=fn_prefix) # self.regulons_to_json(regulons) - self.to_loom(df, auc_matrix, regulons) + if save_loom: + self.to_loom(df, auc_matrix, regulons, fn_prefix=fn_prefix) # self.to_cytoscape(regulons, adjacencies, 'Zfp354c') @staticmethod - def input_hotspot(data): + def input_hotspot(counts, position): """ Extract needed information to construct a Hotspot instance from StereoExpData data :param data: :return: a dictionary """ # 3. use dataframe and position array, StereoExpData as well - counts = data.to_df().T # gene x cell - position = data.position + # counts = data.to_df().T # gene x cell + # position = data.position num_umi = counts.sum(axis=0) # total counts for each cell # Filter genes gene_counts = (counts > 0).sum(axis=1) @@ -207,12 +230,16 @@ def hotspot_matrix(self, logger.info('cached file not found, running hotspot now') global hs - data = self.stereo_exp_data - hotspot_data = RegulatoryNetworkInference.input_hotspot(data) + # data = self.stereo_exp_data + if self.use_raw: + counts = self.stereo_exp_data.raw.to_df().T + else: + counts = self.stereo_exp_data.to_df().T + hotspot_data = RegulatoryNetworkInference.input_hotspot(counts, self.stereo_exp_data.position) if ThreeD_slice: - arr2 = data.position_z - position_3D = np.concatenate((data.position, arr2), axis=1) + arr2 = self.stereo_exp_data.position_z + position_3D = np.concatenate((self.stereo_exp_data.position, arr2), axis=1) hotspot_data['position'] = position_3D hs = hotspot.Hotspot.legacy_init(hotspot_data['counts'], @@ -476,7 +503,7 @@ def regulons_to_json(self, regulon_list: list, fn='regulons.json'): with open(fn, 'w') as f: json.dump(regulon_dict, f, indent=4) - def regulons_to_csv(self, regulon_list: list, fn: str = 'regulon_list.csv'): + def regulons_to_csv(self, regulon_list: list, fn: str = 'regulon_list.csv', fn_prefix: str = None): """ Save regulon_list (df2regulons output) into a csv file. :param regulon_list: @@ -488,12 +515,21 @@ def regulons_to_csv(self, regulon_list: list, fn: str = 'regulon_list.csv'): for key in regulon_dict.keys(): regulon_dict[key] = ";".join(regulon_dict[key]) # Write to csv file + if fn_prefix is not None: + fn = f"{fn_prefix}_{fn}" with open(fn, 'w') as f: w = csv.writer(f) w.writerow(["Regulons", "Target_genes"]) w.writerows(regulon_dict.items()) - def to_loom(self, matrix: pd.DataFrame, auc_matrix: pd.DataFrame, regulons: list, loom_fn: str = 'grn_output.loom'): + def to_loom( + self, + matrix: pd.DataFrame, + auc_matrix: pd.DataFrame, + regulons: list, + loom_fn: str = 'grn_output.loom', + fn_prefix: str = None + ): """ Save GRN results in one loom file :param matrix: @@ -502,6 +538,8 @@ def to_loom(self, matrix: pd.DataFrame, auc_matrix: pd.DataFrame, regulons: list :param loom_fn: :return: """ + if fn_prefix is not None: + loom_fn = f"{fn_prefix}_{loom_fn}" export2loom( ex_mtx=matrix, auc_mtx=auc_matrix, diff --git a/stereo/algorithm/regulatory_network_inference/plot_grn.py b/stereo/algorithm/regulatory_network_inference/plot_grn.py index c5bf08cf..457a63d8 100644 --- a/stereo/algorithm/regulatory_network_inference/plot_grn.py +++ b/stereo/algorithm/regulatory_network_inference/plot_grn.py @@ -110,7 +110,7 @@ def grn_dotplot(self, if cluster_res_key in self.stereo_exp_data.cells._obs.columns: meta = pd.DataFrame({ - 'bin': self.stereo_exp_data.cells.cell_name, + 'bins': self.stereo_exp_data.cells.cell_name, 'group': self.stereo_exp_data.cells._obs[cluster_res_key].tolist() }) else: @@ -265,7 +265,7 @@ def auc_heatmap( if network_res_key not in self.pipeline_res: logger.info(f"The result specified by {network_res_key} is not exists.") - fig = sns.clustermap( + g = sns.clustermap( self.pipeline_res[network_res_key]['auc_matrix'], pivot_kws=pivot_kws, method=method, @@ -287,7 +287,7 @@ def auc_heatmap( cbar_pos=cbar_pos, ) - return fig + return g.figure @plot_scale @reorganize_coordinate @@ -552,13 +552,13 @@ def auc_heatmap_by_group( cbar_pos=cbar_pos, ) - return g + return g.figure def spatial_scatter_by_regulon_3D( self, network_res_key: str = 'regulatory_network_inference', reg_name: str = None, - fn: str = None, + # fn: str = None, view_vertical: int = 0, view_horizontal: int = 0, show_axis: bool = False, @@ -581,8 +581,8 @@ def spatial_scatter_by_regulon_3D( elif '(+)' not in reg_name: reg_name = reg_name + '(+)' - if fn is None: - fn = f'{reg_name.strip("(+)")}.pdf' + # if fn is None: + # fn = f'{reg_name.strip("(+)")}.pdf' # prepare plotting data arr2 = self.stereo_exp_data.position_z @@ -621,7 +621,8 @@ def spatial_scatter_by_regulon_3D( plt.box(False) plt.axis('off') plt.colorbar(sc, shrink=0.35) - plt.savefig(fn, format='pdf') + # plt.savefig(fn, format='pdf') + return fig def get_n_hls_colors(num): diff --git a/stereo/algorithm/time_series_analysis.py b/stereo/algorithm/time_series_analysis.py index d5ba3bd4..ebbd6fdd 100644 --- a/stereo/algorithm/time_series_analysis.py +++ b/stereo/algorithm/time_series_analysis.py @@ -37,7 +37,7 @@ def main( :param p_val_combination: p_value combination method to use, choosing from ['fisher', 'mean', 'FDR'] :param cluster_number: number of cluster - The parameters below are all key word arguments and only for `other` `run_method`. + All the parameters below are key word arguments and only for `other` `run_method`. :param spatial_weight: the weight to combine spatial feature :param n_spatial_feature: n top features to combine of spatial feature diff --git a/stereo/common.py b/stereo/common.py index 6ececa1b..7276af77 100644 --- a/stereo/common.py +++ b/stereo/common.py @@ -8,4 +8,4 @@ """ # version -version = '0.14.0b1' +version = '1.0.0' diff --git a/stereo/core/cell.py b/stereo/core/cell.py index 7f341b82..b142b759 100644 --- a/stereo/core/cell.py +++ b/stereo/core/cell.py @@ -232,7 +232,7 @@ def cell_name(self) -> np.ndarray: :return: cell name """ - return self.__based_ann_data.obs_names.values.astype(str) + return self.__based_ann_data.obs_names.values.astype('U') @cell_name.setter def cell_name(self, name: np.ndarray): @@ -244,7 +244,10 @@ def cell_name(self, name: np.ndarray): """ if not isinstance(name, np.ndarray): raise TypeError('cell name must be a np.ndarray object.') - self.__based_ann_data._inplace_subset_obs(name) + if name.size != self.__based_ann_data.n_obs: + raise ValueError(f'The length of cell names must be {self.__based_ann_data.n_obs}, but now is {name.size}') + self.__based_ann_data.obs_names = name + # self.__based_ann_data._inplace_subset_obs(name) @property def total_counts(self): diff --git a/stereo/core/gene.py b/stereo/core/gene.py index 608367cd..c3a64df2 100644 --- a/stereo/core/gene.py +++ b/stereo/core/gene.py @@ -164,7 +164,7 @@ def gene_name(self) -> np.ndarray: :return: genes name. """ - return self.__based_ann_data.var_names.values.astype(str) + return self.__based_ann_data.var_names.values.astype('U') @gene_name.setter def gene_name(self, name: np.ndarray): @@ -176,7 +176,10 @@ def gene_name(self, name: np.ndarray): """ if not isinstance(name, np.ndarray): raise TypeError('gene name must be a np.ndarray object.') - self.__based_ann_data._inplace_subset_var(name) + if name.size != self.__based_ann_data.n_vars: + raise ValueError(f'The length of gene names must be {self.__based_ann_data.n_vars}, but now is {name.size}') + self.__based_ann_data.var_names = name + # self.__based_ann_data._inplace_subset_var(name) def to_df(self): return self.__based_ann_data.var diff --git a/stereo/core/ms_pipeline.py b/stereo/core/ms_pipeline.py index 6c444f13..8df651d7 100644 --- a/stereo/core/ms_pipeline.py +++ b/stereo/core/ms_pipeline.py @@ -24,6 +24,8 @@ def __init__(self, _ms_data): super().__init__() self.ms_data = _ms_data self._result = dict() + self._key_record = dict() + # self._scope_data = dict() @property def result(self): @@ -32,6 +34,18 @@ def result(self): @result.setter def result(self, new_result): self._result = new_result + + @property + def key_record(self): + return self._key_record + + @key_record.setter + def key_record(self, key_record): + self._key_record = key_record + + # @property + # def scope_data(self): + # return self._scope_data def _use_integrate_method(self, item, *args, **kwargs): if "mode" in kwargs: @@ -51,6 +65,10 @@ def _use_integrate_method(self, item, *args, **kwargs): ms_data_view = self.ms_data[scope] if not ms_data_view.merged_data: ms_data_view.integrate(result=self.ms_data.tl.result) + + # key_name = "scope_[" + ",".join( + # [str(self.ms_data._names.index(name)) for name in ms_data_view._names]) + "]" + # self._scope_data[key_name] = self.ms_data._merged_data def callback_func(key, value): key_name = "scope_[" + ",".join( @@ -84,6 +102,15 @@ def contain_method(item): ms_data_view._merged_data.tl.result.contain_method = contain_method + def reset_key_record(key, res_key): + key_name = "scope_[" + ",".join( + [str(self.ms_data._names.index(name)) for name in ms_data_view._names]) + "]" + + ms_data_view._merged_data.tl._reset_key_record(key, res_key) + self._key_record[key_name] = ms_data_view._merged_data.tl.key_record + + ms_data_view._merged_data.tl.reset_key_record = reset_key_record + new_attr = self.__class__.BASE_CLASS.__dict__.get(item, None) if new_attr is None: if self.__class__.ATTR_NAME == "tl": diff --git a/stereo/core/result.py b/stereo/core/result.py index 2d2210eb..7243d7f2 100644 --- a/stereo/core/result.py +++ b/stereo/core/result.py @@ -10,7 +10,7 @@ class _BaseResult(object): 'phenograph_from_bins', 'annotation_from_bins', 'celltype', 'cell_type' } CONNECTIVITY_NAMES = {'neighbors'} - REDUCE_NAMES = {'umap', 'pca', 'tsne'} + REDUCE_NAMES = {'umap', 'pca', 'tsne', 'correct'} HVG_NAMES = {'highly_variable_genes', 'hvg', 'highly_variable'} MARKER_GENES_NAMES = {'marker_genes', 'rank_genes_groups'} @@ -297,8 +297,8 @@ def __getitem__(self, name): return self.__based_ann_data.uns[name] elif name.startswith('gene_exp_'): return self.__based_ann_data.uns[name] - elif name.startswith('regulatory_network_inference'): - return self.__based_ann_data.uns[name] + # elif name.startswith('regulatory_network_inference'): + # return self.__based_ann_data.uns[name] obsm_obj = self.__based_ann_data.obsm.get(f'X_{name}', None) if obsm_obj is not None: @@ -349,11 +349,11 @@ def __setitem__(self, key, value): if not key.startswith('gene_exp_') and like_name in key and self._real_set_item(name_type, key, value): return - if key == "regulatory_network_inference": - self.__based_ann_data.uns[f'{key}_regulons'] = value['regulons'] - self.__based_ann_data.uns[f'{key}_auc_matrix'] = value['auc_matrix'] - self.__based_ann_data.uns[f'{key}_adjacencies'] = value['adjacencies'] - return + # if key == "regulatory_network_inference": + # self.__based_ann_data.uns[f'{key}_regulons'] = value['regulons'] + # self.__based_ann_data.uns[f'{key}_auc_matrix'] = value['auc_matrix'] + # self.__based_ann_data.uns[f'{key}_adjacencies'] = value['adjacencies'] + # return if type(value) is pd.DataFrame: if 'bins' in value.columns.values and 'group' in value.columns.values: diff --git a/stereo/core/st_pipeline.py b/stereo/core/st_pipeline.py index d68f92bc..069e70af 100644 --- a/stereo/core/st_pipeline.py +++ b/stereo/core/st_pipeline.py @@ -59,6 +59,7 @@ def __init__(self, data: Union[StereoExpData, AnnBasedStereoExpData]): self.result = Result(data) self._raw: Union[StereoExpData, AnnBasedStereoExpData] = None self.key_record = {'hvg': [], 'pca': [], 'neighbors': [], 'umap': [], 'cluster': [], 'marker_genes': []} + self.reset_key_record = self._reset_key_record def __getattr__(self, item): dict_attr = self.__dict__.get(item, None) @@ -122,7 +123,7 @@ def raw_checkpoint(self): """ self.raw = self.data - def reset_key_record(self, key, res_key): + def _reset_key_record(self, key, res_key): """ reset key and coordinated res_key in key_record. :param key: @@ -1310,16 +1311,37 @@ def annotation( assert cluster_res_key in self.result, f'{cluster_res_key} is not in the result, please check and run the ' \ f'cluster func.' - df = copy.deepcopy(self.result[cluster_res_key]) - if isinstance(annotation_information, list): - df.group.cat.categories = annotation_information + # df = copy.deepcopy(self.result[cluster_res_key]) + # if isinstance(annotation_information, list): + # df.group.cat.categories = np.unique(annotation_information) + # elif isinstance(annotation_information, dict): + # new_annotation_list = [] + # for i in df.group.cat.categories: + # new_annotation_list.append(annotation_information[i]) + # df.group.cat.categories = new_annotation_list + + cluster_res: pd.DataFrame = self.result[cluster_res_key] + + if isinstance(annotation_information, (list, np.ndarray)) and len(annotation_information) != cluster_res['group'].cat.categories.size: + raise Exception(f"The length of annotation information is {len(annotation_information)}, \ + not equal to the categories of cluster result whoes lenght is {cluster_res['group'].cat.categories.size}.") + + if isinstance(annotation_information, (list, np.ndarray)): + new_categories = np.array(annotation_information, dtype='U') elif isinstance(annotation_information, dict): - new_annotation_list = [] - for i in df.group.cat.categories: - new_annotation_list.append(annotation_information[i]) - df.group.cat.categories = new_annotation_list - - self.result[res_key] = df + new_categories_list = [] + for i in cluster_res['group'].cat.categories: + new_categories_list.append(annotation_information[i]) + new_categories = np.array(new_categories_list, dtype='U') + else: + raise TypeError(f"The type of 'annotation_information' only supports list, ndarray or dict.") + + new_categories_values = new_categories[cluster_res['group'].cat.codes] + + self.result[res_key] = pd.DataFrame(data={ + 'bins': cluster_res['bins'], + 'group': pd.Series(new_categories_values, dtype='category') + }) key = 'cluster' self.reset_key_record(key, res_key) diff --git a/stereo/core/stereo_exp_data.py b/stereo/core/stereo_exp_data.py index d307deb7..6c71e083 100644 --- a/stereo/core/stereo_exp_data.py +++ b/stereo/core/stereo_exp_data.py @@ -601,6 +601,14 @@ def reset_position(self): self.position[idx] += self.position_min[bno] self.position_offset = None self.position_min = None + + def __add__(self, other): + from stereo.core.ms_data import MSData + if isinstance(other, StereoExpData): + ms_data = MSData([self, other]) + else: + raise TypeError + return ms_data class AnnBasedStereoExpData(StereoExpData): @@ -632,6 +640,7 @@ def __init__( if 'resolution' in self._ann_data.uns: self.attr = {'resolution': self._ann_data.uns['resolution']} + del self._ann_data.uns['resolution'] if bin_type is not None and 'bin_type' not in self._ann_data.uns: self._ann_data.uns['bin_type'] = bin_type @@ -758,10 +767,19 @@ def position_z(self, position_z: np.ndarray): @property def bin_type(self): return self._ann_data.uns.get('bin_type', 'bins') + + @bin_type.setter + def bin_type(self, bin_type): + self.bin_type_check(bin_type) + self._ann_data.uns['bin_type'] = bin_type @property def bin_size(self): return self._ann_data.uns.get('bin_size', 1) + + @bin_size.setter + def bin_size(self, bin_size): + self._ann_data.uns['bin_size'] = bin_size @property def sn(self): @@ -775,6 +793,18 @@ def sn(self): for _, row in sn_data.iterrows(): sn[row['batch']] = row['sn'] return sn + + @sn.setter + def sn(self, sn): + if isinstance(sn, str): + sn_list = [['-1', sn]] + elif isinstance(sn, dict): + sn_list = [] + for bno, sn in sn.items(): + sn_list.append([bno, sn]) + else: + raise TypeError(f'sn must be type of str or dict, but now is {type(sn)}') + self._ann_data.uns['sn'] = pd.DataFrame(sn_list, columns=['batch', 'sn']) def sub_by_index(self, cell_index=None, gene_index=None): if cell_index is not None: diff --git a/stereo/io/__init__.py b/stereo/io/__init__.py index cb5d8301..4c165416 100644 --- a/stereo/io/__init__.py +++ b/stereo/io/__init__.py @@ -16,7 +16,8 @@ stereo_to_anndata, read_gef_info, read_seurat_h5ad, - read_h5ad + read_h5ad, + read_h5ms ) from .writer import ( write, diff --git a/stereo/io/reader.py b/stereo/io/reader.py index d587c77d..2b580eb8 100644 --- a/stereo/io/reader.py +++ b/stereo/io/reader.py @@ -225,7 +225,7 @@ def read_stereo_h5ad( return data -def _read_stereo_h5ad_from_group(f, data: StereoExpData, use_raw, use_result, bin_type, bin_size): +def _read_stereo_h5ad_from_group(f, data: StereoExpData, use_raw, use_result, bin_type=None, bin_size=None): # read data data.bin_type = bin_type if bin_type is not None else 'bins' data.bin_size = bin_size if bin_size is not None else 1 @@ -305,7 +305,7 @@ def _read_stereo_h5_result(key_record: dict, data, f): # str to interval hvg_df['mean_bin'] = [to_interval(interval_string) for interval_string in hvg_df['mean_bin']] data.tl.result[res_key] = hvg_df - if analysis_key in ['pca', 'umap', 'totalVI']: + if analysis_key in ['pca', 'umap', 'totalVI', 'spatial_alignment_integration']: data.tl.result[res_key] = pd.DataFrame(h5ad.read_dataset(f[f'{res_key}@{analysis_key}'])) if analysis_key == 'neighbors': data.tl.result[res_key] = { @@ -379,7 +379,7 @@ def _read_stereo_h5_result(key_record: dict, data, f): data.tl.result[res_key][key] = ast.literal_eval(h5ad.read_dataset(f[full_key])) else: data.tl.result[res_key][key] = h5ad.read_group(f[full_key]) - if analysis_key in ['co_occurrence', 'res_totalVI']: + if analysis_key in ['co_occurrence']: data.tl.result[res_key] = {} for full_key in f.keys(): if not full_key.endswith(analysis_key): @@ -417,6 +417,7 @@ def read_h5ms(file_path, use_raw=True, use_result=True): elif k == 'mss': for key in f['mss'].keys(): data = StereoExpData() + data.tl.result = {} h5ad.read_key_record(f['mss'][key]['key_record'], data.tl.key_record) _read_stereo_h5_result(data.tl.key_record, data, f['mss'][key]) result[key] = data.tl.result @@ -841,7 +842,7 @@ def stereo_to_anndata( adata.uns['sct_top_features'] = list(one_index_data['top_features']) adata.uns['sct_cellname'] = list(one_index_data['umi_cells'].astype('str')) adata.uns['sct_genename'] = list(one_index_data['umi_genes']) - elif key in ['pca', 'umap', 'tsne', 'totalVI']: + elif key in ['pca', 'umap', 'tsne', 'totalVI', 'spatial_alignment_integration']: # pca :we do not keep variance and PCs(for varm which will be into feature.finding in pca of seurat.) res_key = data.tl.key_record[key][-1] sc_key = f'X_{key}' @@ -868,19 +869,21 @@ def stereo_to_anndata( adata.obs[res_key] = pd.DataFrame(data.tl.result[res_key]['group'].values, index=cell_name_index) elif key in ('gene_exp_cluster', 'cell_cell_communication'): for res_key in data.tl.key_record[key]: - logger.info(f"Adding data.tl.result['{res_key}'] into adata.uns['{key}@{res_key}']") - adata.uns[f"{key}@{res_key}"] = data.tl.result[res_key] - elif key == 'regulatory_network_inference': - for res_key in data.tl.key_record[key]: - logger.info(f"Adding data.tl.result['{res_key}'] into adata.uns['{res_key}'] .") - regulon_key = f'{res_key}_regulons' - res_key_data = data.tl.result[res_key] - adata.uns[regulon_key] = res_key_data['regulons'] - auc_matrix_key = f'{res_key}_auc_matrix' - adata.uns[auc_matrix_key] = res_key_data['auc_matrix'] - adjacencies_key = f'{res_key}_adjacencies' - adata.uns[adjacencies_key] = res_key_data['adjacencies'] - elif key == 'co_occurrence': + # logger.info(f"Adding data.tl.result['{res_key}'] into adata.uns['{key}@{res_key}']") + # adata.uns[f"{key}@{res_key}"] = data.tl.result[res_key] + logger.info(f"Adding data.tl.result['{res_key}'] into adata.uns['{res_key}']") + adata.uns[res_key] = data.tl.result[res_key] + # elif key == 'regulatory_network_inference': + # for res_key in data.tl.key_record[key]: + # logger.info(f"Adding data.tl.result['{res_key}'] into adata.uns['{res_key}'] .") + # regulon_key = f'{res_key}_regulons' + # res_key_data = data.tl.result[res_key] + # adata.uns[regulon_key] = res_key_data['regulons'] + # auc_matrix_key = f'{res_key}_auc_matrix' + # adata.uns[auc_matrix_key] = res_key_data['auc_matrix'] + # adjacencies_key = f'{res_key}_adjacencies' + # adata.uns[adjacencies_key] = res_key_data['adjacencies'] + elif key in ('co_occurrence', 'regulatory_network_inference'): for res_key in data.tl.key_record[key]: logger.info(f"Adding data.tl.result['{res_key}'] into adata.uns['{res_key}'] .") adata.uns[res_key] = data.tl.result[res_key] diff --git a/stereo/io/writer.py b/stereo/io/writer.py index fa92cc59..9e13b8ab 100644 --- a/stereo/io/writer.py +++ b/stereo/io/writer.py @@ -171,7 +171,7 @@ def _write_one_h5ad_result(data, f, key_record): if 'mean_bin' in hvg_df.columns: hvg_df.mean_bin = [str(interval) for interval in data.tl.result[res_key].mean_bin] h5ad.write(hvg_df, f, f'{res_key}@hvg') # -> dataframe - if analysis_key in ['pca', 'umap', 'totalVI']: + if analysis_key in ['pca', 'umap', 'totalVI', 'spatial_alignment_integration']: h5ad.write(data.tl.result[res_key].values, f, f'{res_key}@{analysis_key}') # -> array if analysis_key == 'neighbors': for neighbor_key, value in data.tl.result[res_key].items(): @@ -266,15 +266,10 @@ def write_h5ms(ms_data, output: str): h5ad.write(ms_data.relationship, f, 'relationship') if ms_data.tl.result: mss_f = f.create_group('mss') - for key, value in ms_data.tl.result.items(): + for key in ms_data.tl.result.keys(): data = StereoExpData() data.tl.result = ms_data.tl.result[key] - # TODO only supported default name temporarily - for r_key in data.tl.result.keys(): - o_key = r_key - if r_key in {'leiden', 'louvain', 'phenograph', 'annotation'}: - o_key = 'cluster' - data.tl.reset_key_record(o_key, r_key) + data.tl.key_record = ms_data.tl.key_record[key] mss_f.create_group(key) _write_one_h5ad_result(data, mss_f[key], data.tl.key_record) diff --git a/stereo/plots/plot_time_series.py b/stereo/plots/plot_time_series.py index af3a373f..ed7a0849 100644 --- a/stereo/plots/plot_time_series.py +++ b/stereo/plots/plot_time_series.py @@ -69,10 +69,13 @@ def boxplot_transit_gene(self, branch2exp = defaultdict(dict) stereo_exp_data = self.stereo_exp_data for x in branch: - cell_list = stereo_exp_data.cells.to_df().loc[stereo_exp_data.cells[use_col] == x, :].index - tmp_exp_data = stereo_exp_data.sub_by_name(cell_name=cell_list) + # cell_list = stereo_exp_data.cells.to_df().loc[stereo_exp_data.cells[use_col] == x, :].index + # tmp_exp_data = stereo_exp_data.sub_by_name(cell_name=cell_list) + cell_flag = (stereo_exp_data.cells[use_col] == x).to_numpy() + tmp_exp_data = stereo_exp_data.exp_matrix[cell_flag] for gene in genes: - branch2exp[gene][x] = tmp_exp_data.sub_by_name(gene_name=[gene]).exp_matrix.toarray().flatten() + # branch2exp[gene][x] = tmp_exp_data.sub_by_name(gene_name=[gene]).exp_matrix.toarray().flatten() + branch2exp[gene][x] = tmp_exp_data[:, stereo_exp_data.gene_names == gene].toarray().flatten() fig = plt.figure(figsize=(4 * len(genes), 6)) ax = fig.subplots(1, len(genes)) diff --git a/stereo/plots/vt3d_browser/example.py b/stereo/plots/vt3d_browser/example.py index e737b46f..0256da53 100644 --- a/stereo/plots/vt3d_browser/example.py +++ b/stereo/plots/vt3d_browser/example.py @@ -38,6 +38,16 @@ def start_vt3d_browser( th.setDaemon(True) th.start() + # kwargs={ + # 'meshes': meshes, + # 'cluster_label': cluster_res_key, + # 'paga_key': paga_res_key, + # 'ccc_key': ccc_res_key, + # 'grn_key': grn_res_key, + # 'port': port + # } + # launch(self.stereo_exp_data, **kwargs) + def display_3d_mesh(self, width=1400, height=1200, ip='127.0.0.1', port=7654): import IPython sleep(5)