diff --git a/stereo/algorithm/get_niche.py b/stereo/algorithm/get_niche.py index 05fa7a7e..81188422 100644 --- a/stereo/algorithm/get_niche.py +++ b/stereo/algorithm/get_niche.py @@ -117,11 +117,11 @@ def main( else: raise InvalidNicheMethod(method) - data_result = filter_cells(data_full, cell_list=cell_list, inplace=inplace) - if filter_raw and data_result.raw is not None: - filter_cells(data_result.raw, cell_list=cell_list, inplace=True) - if isinstance(data_result, AnnBasedStereoExpData): - data_result.adata.raw = data_result.raw.adata + data_result = filter_cells(data_full, cell_list=cell_list, inplace=inplace, filter_raw=filter_raw) + # if filter_raw and data_result.raw is not None: + # filter_cells(data_result.raw, cell_list=cell_list, inplace=True) + # if isinstance(data_result, AnnBasedStereoExpData): + # data_result.adata.raw = data_result.raw.adata return data_result diff --git a/stereo/plots/plot_collection.py b/stereo/plots/plot_collection.py index d72bf993..82ba1223 100644 --- a/stereo/plots/plot_collection.py +++ b/stereo/plots/plot_collection.py @@ -422,8 +422,10 @@ def spatial_scatter( from .scatter import multi_scatter x = self.data.position[:, 0] y = self.data.position[:, 1] - x_min, x_max = int(x.min()), int(x.max()) - y_min, y_max = int(y.min()), int(y.max()) + # x_min, x_max = int(x.min()), int(x.max()) + # y_min, y_max = int(y.min()), int(y.max()) + x_min, x_max = x.min(), x.max() + y_min, y_max = y.min(), y.max() boundary = [x_min, x_max, y_min, y_max] marker = 's' diff --git a/stereo/plots/scatter.py b/stereo/plots/scatter.py index 1c3e0517..9ca78807 100644 --- a/stereo/plots/scatter.py +++ b/stereo/plots/scatter.py @@ -38,8 +38,10 @@ def _plot_scale( boundary: list ): if boundary is None: - min_x, max_x = np.min(x).astype(int), np.max(x).astype(int) - min_y, max_y = np.min(y).astype(int), np.max(y).astype(int) + # min_x, max_x = np.min(x).astype(int), np.max(x).astype(int) + # min_y, max_y = np.min(y).astype(int), np.max(y).astype(int) + min_x, max_x = np.min(x), np.max(x) + min_y, max_y = np.min(y), np.max(y) else: min_x, max_x, min_y, max_y = boundary @@ -49,7 +51,8 @@ def _plot_scale( if plotting_scale_width is None: data_width = max_x - min_x + 1 data_height = max_y - min_y + 1 - plotting_scale_width = max(np.ceil(min(data_width, data_height) / 5), 10) + # plotting_scale_width = max(np.ceil(min(data_width, data_height) / 5), 10) + plotting_scale_width = np.ceil(min(data_width, data_height) / 5) highest_num = plotting_scale_width // (10 ** np.log10(plotting_scale_width).astype(int)) plotting_scale_width = highest_num * (10 ** np.log10(plotting_scale_width).astype(int)) @@ -59,7 +62,8 @@ def _plot_scale( bin_count = plotting_scale_width // data_bin_offset # horizontal_end_x = horizontal_start_x + (bin_count - 1) * data_bin_offset - horizontal_end_x = horizontal_start_x + plotting_scale_width - 1 + # horizontal_end_x = horizontal_start_x + plotting_scale_width - 1 + horizontal_end_x = horizontal_start_x + plotting_scale_width horizontal_text_location_x = horizontal_start_x + plotting_scale_width / 2 vertical_x_location = min_x - plotting_scale_height * 2 @@ -70,7 +74,8 @@ def _plot_scale( horizontal_y_location = min_y - plotting_scale_height * 2 vertical_start_y = min_y # vertical_end_y = vertical_start_y + (bin_count - 1) * data_bin_offset - vertical_end_y = vertical_start_y + plotting_scale_width - 1 + # vertical_end_y = vertical_start_y + plotting_scale_width - 1 + vertical_end_y = vertical_start_y + plotting_scale_width vertical_text_location_y = vertical_start_y + plotting_scale_width / 2 vertices = [ (horizontal_start_x, horizontal_y_location - plotting_scale_height), @@ -85,7 +90,8 @@ def _plot_scale( horizontal_y_location = max_y + plotting_scale_height * 2 vertical_start_y = max_y # vertical_end_y = vertical_start_y - (bin_count - 1) * data_bin_offset - vertical_end_y = vertical_start_y - plotting_scale_width + 1 + # vertical_end_y = vertical_start_y - plotting_scale_width + 1 + vertical_end_y = vertical_start_y - plotting_scale_width vertical_text_location_y = vertical_start_y - plotting_scale_width / 2 vertices = [ (horizontal_start_x, horizontal_y_location + plotting_scale_height), diff --git a/stereo/preprocess/filter.py b/stereo/preprocess/filter.py index 43042f2f..286dfd8c 100644 --- a/stereo/preprocess/filter.py +++ b/stereo/preprocess/filter.py @@ -40,7 +40,7 @@ def filter_cells( use_raw=True, layer=None, inplace=True - ): +) -> StereoExpData: """ filter cells based on numbers of genes expressed. @@ -99,7 +99,7 @@ def filter_genes( use_raw=True, layer=None, inplace=True -): +) -> StereoExpData: """ filter genes based on the numbers of cells. @@ -156,7 +156,7 @@ def filter_coordinates( max_y=None, filter_raw=True, inplace=True -): +) -> StereoExpData: """ filter cells based on the coordinates of cells. @@ -187,36 +187,6 @@ def filter_coordinates( data.sub_by_index(cell_index=obs_subset, filter_raw=filter_raw) return data - -# def filter_by_clusters( -# data: StereoExpData, -# cluster_res: pd.DataFrame, -# groups: Union[str, np.ndarray, List[str]], -# excluded: bool = False, -# inplace: bool = True -# ) -> Tuple[StereoExpData, pd.DataFrame]: -# """_summary_ - -# :param data: StereoExpData object. -# :param cluster_res: clustering result. -# :param groups: the groups in clustering result which will be filtered. -# :param inplace: whether inplace the original data or return a new data. -# :param excluded: bool type. -# :return: StereoExpData object -# """ -# data = data if inplace else copy.deepcopy(data) -# all_groups = cluster_res['group'] -# if isinstance(groups, str): -# groups = [groups] -# is_in_bool = all_groups.isin(groups).to_numpy() -# if excluded: -# is_in_bool = ~is_in_bool -# data.sub_by_index(cell_index=is_in_bool) -# cluster_res = cluster_res[is_in_bool].copy() -# cluster_res['group'] = cluster_res['group'].to_numpy() -# cluster_res['group'] = cluster_res['group'].astype('category') -# return data, cluster_res - def filter_by_clusters( data: StereoExpData, cluster_res_key: str, @@ -224,7 +194,7 @@ def filter_by_clusters( excluded: bool = False, filter_raw: bool = True, inplace: bool = False -) -> Tuple[StereoExpData, pd.DataFrame]: +) -> StereoExpData: """_summary_ :param data: StereoExpData object.