Skip to content

Commit

Permalink
do some updates
Browse files Browse the repository at this point in the history
  • Loading branch information
tanliwei-coder committed Dec 9, 2024
1 parent 78b14e8 commit 0f3a322
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 47 deletions.
10 changes: 5 additions & 5 deletions stereo/algorithm/get_niche.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,11 @@ def main(
else:
raise InvalidNicheMethod(method)

data_result = filter_cells(data_full, cell_list=cell_list, inplace=inplace)
if filter_raw and data_result.raw is not None:
filter_cells(data_result.raw, cell_list=cell_list, inplace=True)
if isinstance(data_result, AnnBasedStereoExpData):
data_result.adata.raw = data_result.raw.adata
data_result = filter_cells(data_full, cell_list=cell_list, inplace=inplace, filter_raw=filter_raw)
# if filter_raw and data_result.raw is not None:
# filter_cells(data_result.raw, cell_list=cell_list, inplace=True)
# if isinstance(data_result, AnnBasedStereoExpData):
# data_result.adata.raw = data_result.raw.adata


return data_result
6 changes: 4 additions & 2 deletions stereo/plots/plot_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,10 @@ def spatial_scatter(
from .scatter import multi_scatter
x = self.data.position[:, 0]
y = self.data.position[:, 1]
x_min, x_max = int(x.min()), int(x.max())
y_min, y_max = int(y.min()), int(y.max())
# x_min, x_max = int(x.min()), int(x.max())
# y_min, y_max = int(y.min()), int(y.max())
x_min, x_max = x.min(), x.max()
y_min, y_max = y.min(), y.max()
boundary = [x_min, x_max, y_min, y_max]
marker = 's'

Expand Down
18 changes: 12 additions & 6 deletions stereo/plots/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ def _plot_scale(
boundary: list
):
if boundary is None:
min_x, max_x = np.min(x).astype(int), np.max(x).astype(int)
min_y, max_y = np.min(y).astype(int), np.max(y).astype(int)
# min_x, max_x = np.min(x).astype(int), np.max(x).astype(int)
# min_y, max_y = np.min(y).astype(int), np.max(y).astype(int)
min_x, max_x = np.min(x), np.max(x)
min_y, max_y = np.min(y), np.max(y)
else:
min_x, max_x, min_y, max_y = boundary

Expand All @@ -49,7 +51,8 @@ def _plot_scale(
if plotting_scale_width is None:
data_width = max_x - min_x + 1
data_height = max_y - min_y + 1
plotting_scale_width = max(np.ceil(min(data_width, data_height) / 5), 10)
# plotting_scale_width = max(np.ceil(min(data_width, data_height) / 5), 10)
plotting_scale_width = np.ceil(min(data_width, data_height) / 5)
highest_num = plotting_scale_width // (10 ** np.log10(plotting_scale_width).astype(int))
plotting_scale_width = highest_num * (10 ** np.log10(plotting_scale_width).astype(int))

Expand All @@ -59,7 +62,8 @@ def _plot_scale(
bin_count = plotting_scale_width // data_bin_offset

# horizontal_end_x = horizontal_start_x + (bin_count - 1) * data_bin_offset
horizontal_end_x = horizontal_start_x + plotting_scale_width - 1
# horizontal_end_x = horizontal_start_x + plotting_scale_width - 1
horizontal_end_x = horizontal_start_x + plotting_scale_width
horizontal_text_location_x = horizontal_start_x + plotting_scale_width / 2

vertical_x_location = min_x - plotting_scale_height * 2
Expand All @@ -70,7 +74,8 @@ def _plot_scale(
horizontal_y_location = min_y - plotting_scale_height * 2
vertical_start_y = min_y
# vertical_end_y = vertical_start_y + (bin_count - 1) * data_bin_offset
vertical_end_y = vertical_start_y + plotting_scale_width - 1
# vertical_end_y = vertical_start_y + plotting_scale_width - 1
vertical_end_y = vertical_start_y + plotting_scale_width
vertical_text_location_y = vertical_start_y + plotting_scale_width / 2
vertices = [
(horizontal_start_x, horizontal_y_location - plotting_scale_height),
Expand All @@ -85,7 +90,8 @@ def _plot_scale(
horizontal_y_location = max_y + plotting_scale_height * 2
vertical_start_y = max_y
# vertical_end_y = vertical_start_y - (bin_count - 1) * data_bin_offset
vertical_end_y = vertical_start_y - plotting_scale_width + 1
# vertical_end_y = vertical_start_y - plotting_scale_width + 1
vertical_end_y = vertical_start_y - plotting_scale_width
vertical_text_location_y = vertical_start_y - plotting_scale_width / 2
vertices = [
(horizontal_start_x, horizontal_y_location + plotting_scale_height),
Expand Down
38 changes: 4 additions & 34 deletions stereo/preprocess/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def filter_cells(
use_raw=True,
layer=None,
inplace=True
):
) -> StereoExpData:
"""
filter cells based on numbers of genes expressed.
Expand Down Expand Up @@ -99,7 +99,7 @@ def filter_genes(
use_raw=True,
layer=None,
inplace=True
):
) -> StereoExpData:
"""
filter genes based on the numbers of cells.
Expand Down Expand Up @@ -156,7 +156,7 @@ def filter_coordinates(
max_y=None,
filter_raw=True,
inplace=True
):
) -> StereoExpData:
"""
filter cells based on the coordinates of cells.
Expand Down Expand Up @@ -187,44 +187,14 @@ def filter_coordinates(
data.sub_by_index(cell_index=obs_subset, filter_raw=filter_raw)
return data


# def filter_by_clusters(
# data: StereoExpData,
# cluster_res: pd.DataFrame,
# groups: Union[str, np.ndarray, List[str]],
# excluded: bool = False,
# inplace: bool = True
# ) -> Tuple[StereoExpData, pd.DataFrame]:
# """_summary_

# :param data: StereoExpData object.
# :param cluster_res: clustering result.
# :param groups: the groups in clustering result which will be filtered.
# :param inplace: whether inplace the original data or return a new data.
# :param excluded: bool type.
# :return: StereoExpData object
# """
# data = data if inplace else copy.deepcopy(data)
# all_groups = cluster_res['group']
# if isinstance(groups, str):
# groups = [groups]
# is_in_bool = all_groups.isin(groups).to_numpy()
# if excluded:
# is_in_bool = ~is_in_bool
# data.sub_by_index(cell_index=is_in_bool)
# cluster_res = cluster_res[is_in_bool].copy()
# cluster_res['group'] = cluster_res['group'].to_numpy()
# cluster_res['group'] = cluster_res['group'].astype('category')
# return data, cluster_res

def filter_by_clusters(
data: StereoExpData,
cluster_res_key: str,
groups: Union[str, np.ndarray, List[str]],
excluded: bool = False,
filter_raw: bool = True,
inplace: bool = False
) -> Tuple[StereoExpData, pd.DataFrame]:
) -> StereoExpData:
"""_summary_
:param data: StereoExpData object.
Expand Down

0 comments on commit 0f3a322

Please sign in to comment.