Skip to content

Commit

Permalink
update some feature and fix some bugs.
Browse files Browse the repository at this point in the history
  • Loading branch information
tanliwei-genomics-cn committed Oct 26, 2023
1 parent 67e5d42 commit 5df4bb3
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 8 deletions.
5 changes: 3 additions & 2 deletions stereo/algorithm/community_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def params_init(
elif 'spatial_stereoseq' in slice._ann_data.obsm:
slice._ann_data.obsm['spatial'] = np.array(slice._ann_data.obsm['spatial_stereoseq'].copy())
# annotation data must be of string type
slice._ann_data.obs[annotation] = slice._ann_data.obs[annotation].astype('str')
# slice._ann_data.obs[annotation] = slice._ann_data.obs[annotation].astype('str')
# create a set of existing cell types in all slices
self.cell_types = self.cell_types.union(set(slice._ann_data.obs[annotation].unique()))
# if any of the samples lacks the cell type palette, set the flag
Expand Down Expand Up @@ -232,6 +232,7 @@ def _main(self, slices, annotation="annotation", **kwargs):
self.slices[slice_id]._ann_data.obs.loc[
algo.adata.obs[f'tissue_{algo.method_key}'].index, 'cell_communities'] = algo.adata.obs[
f'tissue_{algo.method_key}']
self.slices[slice_id]._ann_data.obs['cell_communities'].fillna('unknown', inplace=True)

# save anndata objects for further use
if self.params['save_adata']:
Expand Down Expand Up @@ -306,7 +307,7 @@ def cluster(self, merged_tissue): # TODO, merged_tissue da bude AnnBasedStereoE
resolution=self.params['resolution'])
merged_tissue._ann_data.obs['leiden'] = merged_tissue._ann_data.obs['leiden'].astype('int')
merged_tissue._ann_data.obs['leiden'] -= 1
merged_tissue._ann_data.obs['leiden'] = merged_tissue._ann_data.obs['leiden'].astype('str')
merged_tissue._ann_data.obs['leiden'] = merged_tissue._ann_data.obs['leiden'].astype('U')
merged_tissue._ann_data.obs['leiden'] = merged_tissue._ann_data.obs['leiden'].astype('category')
elif self.params['cluster_algo'] == 'spectral':
merged_tissue._ann_data.obsm['X_pca_dummy'] = merged_tissue._ann_data.X
Expand Down
8 changes: 7 additions & 1 deletion stereo/algorithm/get_niche.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def main(
# adata_result = adata_full[(list(result_target_sender.index) + list(result_target_sender.columns)), :]
cell_list = list(result_target_sender.index) + list(result_target_sender.columns)
data_result = filter_cells(data_full, cell_list=cell_list, inplace=inplace)
if not inplace:
data_result.tl.result.set_item_callback = None

for res_key in data_result.tl.key_record['pca']:
data_result.tl.result[res_key] = data_result.tl.result[res_key][
np.isin(data_full.cell_names, cell_list)].copy()
Expand All @@ -68,5 +71,8 @@ def main(
index=data_result.cell_names,
dtype='category'
)


if not inplace:
data_result.tl.result.contain_method = None
data_result.tl.result.get_item_method = None
return data_result
4 changes: 2 additions & 2 deletions stereo/core/ms_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,8 +552,8 @@ def to_integrate(
res = stereo_exp_data.cells._obs[item[idx]]
sample_idx = self._names.index(scope_names[idx])
new_index = res.index.astype('str') + f'-{sample_idx}'
res.index = new_index
self.merged_data.cells._obs.loc[new_index, res_key] = res
# res.index = new_index
self.merged_data.cells._obs.loc[new_index, res_key] = res.to_numpy()
elif type == 'var':
raise NotImplementedError
else:
Expand Down
14 changes: 14 additions & 0 deletions stereo/core/stereo_exp_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def __init__(
self._position = position
self._position_z = position_z
self._position_offset = None
self._position_min = None
self._bin_type = bin_type
self._bin_size = bin_size
self._tl = None
Expand Down Expand Up @@ -414,6 +415,14 @@ def position_offset(self):
@position_offset.setter
def position_offset(self, position_offset):
self._position_offset = position_offset

@property
def position_min(self):
return self._position_min

@position_min.setter
def position_min(self, position_min):
self._position_min = position_min

@property
def offset_x(self):
Expand Down Expand Up @@ -589,7 +598,9 @@ def reset_position(self):
for bno in batches:
idx = np.where(self.cells.batch == bno)[0]
self.position[idx] -= self.position_offset[bno]
self.position[idx] += self.position_min[bno]
self.position_offset = None
self.position_min = None


class AnnBasedStereoExpData(StereoExpData):
Expand All @@ -600,6 +611,7 @@ def __init__(
based_ann_data: anndata.AnnData = None,
bin_type: str = None,
bin_size: int = None,
spatial_key: Union[str, list, np.ndarray] = 'spatial',
*args,
**kwargs
):
Expand Down Expand Up @@ -637,6 +649,8 @@ def __init__(

if self._ann_data.raw:
self._tl._raw = AnnBasedStereoExpData(based_ann_data=self._ann_data.raw.to_adata())

self._spatial_key = spatial_key

def __str__(self):
return str(self._ann_data)
Expand Down
4 changes: 2 additions & 2 deletions stereo/plots/plot_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,8 +844,8 @@ def cluster_scatter(
group_list = res['group'].to_numpy()
n = np.unique(group_list).size
palette = stereo_conf.get_colors(colors, n=n)
x = self.data.position[:, 0].astype(int)
y = self.data.position[:, 1].astype(int)
x = self.data.position[:, 0]
y = self.data.position[:, 1]
x_min, x_max = x.min(), x.max()
y_min, y_max = y.min(), y.max()
boundary = [x_min, x_max, y_min, y_max]
Expand Down
2 changes: 1 addition & 1 deletion stereo/plots/vt3d_browser/stereopy_3D_browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def get_anno(self):
"""
xyz = np.concatenate([self._data.position, self._data.position_z], axis=1)
df = pd.DataFrame(data=xyz, columns=['x', 'y', 'z'])
df = df.astype(int) # force convert to int to save space
# df = df.astype(int) # force convert to int to save space
if self._cluster_label in self._data.cells._obs.columns:
df['anno'] = self._data.cells._obs[self._cluster_label].to_numpy()
else:
Expand Down

0 comments on commit 5df4bb3

Please sign in to comment.