
Commit

do some optimization and updates.
tanliwei-genomics-cn committed Nov 20, 2023
1 parent 2d40ccc commit 2f4dcd5
Showing 18 changed files with 234 additions and 87 deletions.
3 changes: 2 additions & 1 deletion setup.py
@@ -40,7 +40,8 @@ def get_version(rel_path: str) -> str:
     long_description_content_type="text/markdown",
     url='https://github.com/STOmics/Stereopy',
     author='STOmics',
-    author_email='xujunhao@genomics.cn',
+    author_email='tanliwei@stomics.tech',
     license='MIT License',
     python_requires='>=3.8,<3.9',
     install_requires=[
         l.strip() for l in Path('requirements.txt').read_text('utf-8').splitlines()
5 changes: 4 additions & 1 deletion stereo/algorithm/cell_cell_communication/main.py
@@ -328,7 +328,10 @@ def _prepare_data(self, cluster_res_key):
         cluster.rename({'group': 'cell_type'}, axis=1, inplace=True)
         cluster.set_index('bins', drop=True, inplace=True)
         cluster.index.name = 'cell'
-        data = pd.DataFrame(self.stereo_exp_data.exp_matrix.T.toarray())
+        if self.stereo_exp_data.issparse():
+            data = pd.DataFrame(self.stereo_exp_data.exp_matrix.T.toarray())
+        else:
+            data = pd.DataFrame(self.stereo_exp_data.exp_matrix.T)
         data.columns = self.stereo_exp_data.cell_names.astype(str)
         data.index = self.stereo_exp_data.gene_names
         return data, cluster
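
This fix guards the dense path: pandas does not densify a scipy sparse matrix automatically, so toarray() is applied only when the expression matrix is actually sparse. A minimal standalone sketch of the same pattern (the helper name is ours, not part of this commit):

    import pandas as pd
    from scipy import sparse

    def exp_matrix_to_df(exp_matrix):
        # Gene x cell frame; densify only when the matrix is sparse.
        if sparse.issparse(exp_matrix):
            return pd.DataFrame(exp_matrix.T.toarray())
        return pd.DataFrame(exp_matrix.T)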
15 changes: 11 additions & 4 deletions stereo/algorithm/co_occurrence.py
@@ -11,6 +11,7 @@
 from stereo.core.stereo_exp_data import AnnBasedStereoExpData
 # module in self project
 from stereo.core.stereo_exp_data import StereoExpData
+from stereo.core.ms_data import MSData


 # ----------------------------------------------#
@@ -286,7 +287,7 @@ def co_occurrence(
         return ret

     @staticmethod
-    def ms_co_occur_integrate(ms_data, scope, use_col, res_key='co_occurrence'):
+    def ms_co_occur_integrate(ms_data: MSData, scope, use_col, res_key='co_occurrence'):
         from collections import Counter
         if use_col not in ms_data.obs:
             tmp_list = []
@@ -296,11 +297,13 @@ def ms_co_occur_integrate(ms_data: MSData, scope, use_col, res_key='co_occurrence'):
             ms_data.obs[use_col] = ms_data.obs[use_col].astype('category')

         slice_groups = scope.split('|')
+        slice_index = []
         if len(slice_groups) == 1:
             slices = slice_groups[0].split(",")
             ct_count = {}
             for x in slices:
                 ct_count[x] = dict(Counter(ms_data[x].cells[use_col]))
+                slice_index.append(ms_data.names.index(x))

             ct_count = pd.DataFrame(ct_count)
             ct_ratio = ct_count.div(ct_count.sum(axis=1), axis=0)
@@ -318,6 +321,7 @@ def ms_co_occur_integrate(ms_data: MSData, scope, use_col, res_key='co_occurrence'):
                 ct_count = {}
                 for x in slices:
                     ct_count[x] = dict(Counter(ms_data[x].cells[use_col]))
+                    slice_index.append(ms_data.names.index(x))

                 ct_count = pd.DataFrame(ct_count)
                 ct_ratio = ct_count.div(ct_count.sum(axis=1), axis=0)
@@ -333,7 +337,10 @@ def ms_co_occur_integrate(ms_data: MSData, scope, use_col, res_key='co_occurrence'):
             merge_co_occur_ret = {ct: ret[0][ct] - ret[1][ct] for ct in merge_co_occur_ret}

         else:
-            print('co-occurrence only compare case and control on two groups')
-            merge_co_occur_ret = None
+            raise Exception('co-occurrence only compare case and control on two groups')
+            # merge_co_occur_ret = None

-        return merge_co_occur_ret
+        # return merge_co_occur_ret
+        slice_index = np.unique(slice_index)
+        scope_key = "scope_[" + ",".join([str(i) for i in slice_index]) + "]"
+        ms_data.tl.result[scope_key][res_key] = merge_co_occur_ret
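
The integrated result is now stored back onto ms_data under a scope-specific key instead of being returned, and a bad scope now raises instead of silently printing. A small sketch of how that key is derived (the slice names are illustrative):

    import numpy as np

    ms_names = ['slice0', 'slice1', 'slice2']    # hypothetical ms_data.names
    slices = ['slice2', 'slice0']                # parsed from scope, e.g. "slice2,slice0"
    slice_index = np.unique([ms_names.index(x) for x in slices])
    scope_key = "scope_[" + ",".join(str(i) for i in slice_index) + "]"
    # scope_key == 'scope_[0,2]'; the result lands in
    # ms_data.tl.result[scope_key]['co_occurrence']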
78 changes: 58 additions & 20 deletions stereo/algorithm/regulatory_network_inference/main.py
@@ -58,10 +58,14 @@ def main(self,
              seed: int = None,
              cache: bool = False,
              cache_res_key: str = 'regulatory_network_inference',
-             save: bool = True,
+             save_regulons: bool = True,
+             save_loom: bool = False,
+             fn_prefix: str = None,
              method: str = 'grnboost',
              ThreeD_slice: bool = False,
-             prune_kwargs: dict = {}
+             prune_kwargs: dict = {},
+             hotspot_kwargs: dict = {},
+             use_raw: bool = True
              ):
         """
         Enables researchers to infer transcription factors (TFs) and gene regulatory networks.
@@ -76,14 +80,27 @@ def main(self,
         :param res_key: the key for storage of inference regulatory network result.
         :param seed: optional random seed for the regressors. Default None.
         :param cache: whether to use cache files. Need to provide adj.csv, motifs.csv and auc.csv.
-        :param save: whether to save the result as a file.
+        :param save_regulons: whether to save regulons into a csv file.
+        :param save_loom: whether to save the result as a loom file.
+        :param fn_prefix: the prefix of file name for saving regulons or loom.
         :param method: the method to inference GRN, 'grnboost' or 'hotspot'.
         :param ThreeD_slice: whether to use 3D slice data.
-        :param prune_kwargs: dict, others parameters of pyscenic.prune.prune2df
+        :param prune_kwargs: dict, other parameters of pyscenic.prune.prune2df.
+        :param hotspot_kwargs: dict, other parameters for 'hotspot' method.
         :return: Computation result of inference regulatory network is stored in self.result where the result key is 'regulatory_network_inference'.
         """  # noqa
-        matrix = self.stereo_exp_data.to_df()
-        df = self.stereo_exp_data.to_df()
+        self.use_raw = use_raw
+        if use_raw and self.stereo_exp_data.raw is None:
+            raise Exception("The raw data is not found, you need to run 'raw_checkpoint()' first.")
+
+        if use_raw:
+            logger.info('the raw expression matrix will be used.')
+            matrix = self.stereo_exp_data.raw.to_df()
+        else:
+            logger.info('if you have done some normalized processing, the normalized expression matrix will be used.')
+            matrix = self.stereo_exp_data.to_df()
+        # df = self.stereo_exp_data.to_df()
+        df = matrix.copy(deep=True)

         if num_workers is None:
             num_workers = cpu_count()
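
Because use_raw now defaults to True, the raw counts checkpoint must exist before inference runs. A hedged usage sketch (the file paths are placeholders, the positional database/motif/TF arguments are assumed from the surrounding API, and the algorithm is assumed to be exposed on data.tl as other Stereopy algorithms are):

    data.tl.raw_checkpoint()  # snapshot raw counts before any normalization
    data.tl.regulatory_network_inference(
        'database.feather',   # placeholder cisTarget ranking database
        'motifs.tbl',         # placeholder motif annotation file
        'tfs.txt',            # placeholder transcription factor list
        method='grnboost',
        use_raw=True,         # raises if raw_checkpoint() was never called
        save_regulons=True,
        save_loom=False,
        fn_prefix='sample1',
    )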
@@ -114,8 +131,13 @@ def main(self,
             adjacencies = self.grn_inference(matrix, genes=target_genes, tf_names=tfsf, num_workers=num_workers,
                                              seed=seed, cache=cache, cache_res_key=cache_res_key)
         elif method == 'hotspot':
+            hotspot_kwargs_adjusted = {}
+            for key, value in hotspot_kwargs.items():
+                if key in ('tf_list', 'jobs', 'cache', 'cache_res_key', 'ThreeD_slice'):
+                    continue
+                hotspot_kwargs_adjusted[key] = value
             adjacencies = self.hotspot_matrix(tf_list=tfsf, jobs=num_workers, cache=cache, cache_res_key=cache_res_key,
-                                              ThreeD_slice=ThreeD_slice)
+                                              ThreeD_slice=ThreeD_slice, **hotspot_kwargs_adjusted)

         modules = self.get_modules(adjacencies, df)
         # 4. Regulons prediction aka cisTarget
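
Reserved keys are stripped before forwarding, so hotspot_kwargs cannot override arguments the pipeline sets itself. A short sketch of the filtering (the 'model' option is an illustrative Hotspot parameter, not one this commit introduces):

    hotspot_kwargs = {'model': 'danb', 'jobs': 8}  # 'jobs' is reserved, silently dropped
    reserved = ('tf_list', 'jobs', 'cache', 'cache_res_key', 'ThreeD_slice')
    adjusted = {k: v for k, v in hotspot_kwargs.items() if k not in reserved}
    # adjusted == {'model': 'danb'}; parallelism is still governed by num_workers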
@@ -131,26 +153,27 @@ def main(self,
             'regulons': self.regulon_dict,
             'auc_matrix': auc_matrix,
             'adjacencies': adjacencies,
-            'motifs': motifs
+            # 'motifs': motifs
         }
         self.stereo_exp_data.tl.reset_key_record('regulatory_network_inference', res_key)

-        if save:
-            self.regulons_to_csv(regulons)
+        if save_regulons:
+            self.regulons_to_csv(regulons, fn_prefix=fn_prefix)
             # self.regulons_to_json(regulons)
-            self.to_loom(df, auc_matrix, regulons)
+        if save_loom:
+            self.to_loom(df, auc_matrix, regulons, fn_prefix=fn_prefix)
             # self.to_cytoscape(regulons, adjacencies, 'Zfp354c')

     @staticmethod
-    def input_hotspot(data):
+    def input_hotspot(counts, position):
         """
         Extract needed information to construct a Hotspot instance from StereoExpData data
         :param data:
         :return: a dictionary
         """
         # 3. use dataframe and position array, StereoExpData as well
-        counts = data.to_df().T  # gene x cell
-        position = data.position
+        # counts = data.to_df().T  # gene x cell
+        # position = data.position
         num_umi = counts.sum(axis=0)  # total counts for each cell
         # Filter genes
         gene_counts = (counts > 0).sum(axis=1)
@@ -207,12 +230,16 @@ def hotspot_matrix(self,
             logger.info('cached file not found, running hotspot now')

         global hs
-        data = self.stereo_exp_data
-        hotspot_data = RegulatoryNetworkInference.input_hotspot(data)
+        # data = self.stereo_exp_data
+        if self.use_raw:
+            counts = self.stereo_exp_data.raw.to_df().T
+        else:
+            counts = self.stereo_exp_data.to_df().T
+        hotspot_data = RegulatoryNetworkInference.input_hotspot(counts, self.stereo_exp_data.position)

         if ThreeD_slice:
-            arr2 = data.position_z
-            position_3D = np.concatenate((data.position, arr2), axis=1)
+            arr2 = self.stereo_exp_data.position_z
+            position_3D = np.concatenate((self.stereo_exp_data.position, arr2), axis=1)
             hotspot_data['position'] = position_3D

         hs = hotspot.Hotspot.legacy_init(hotspot_data['counts'],
@@ -476,7 +503,7 @@ def regulons_to_json(self, regulon_list: list, fn='regulons.json'):
         with open(fn, 'w') as f:
             json.dump(regulon_dict, f, indent=4)

-    def regulons_to_csv(self, regulon_list: list, fn: str = 'regulon_list.csv'):
+    def regulons_to_csv(self, regulon_list: list, fn: str = 'regulon_list.csv', fn_prefix: str = None):
         """
         Save regulon_list (df2regulons output) into a csv file.
         :param regulon_list:
@@ -488,12 +515,21 @@ def regulons_to_csv(self, regulon_list: list, fn: str = 'regulon_list.csv'):
         for key in regulon_dict.keys():
             regulon_dict[key] = ";".join(regulon_dict[key])
         # Write to csv file
+        if fn_prefix is not None:
+            fn = f"{fn_prefix}_{fn}"
         with open(fn, 'w') as f:
             w = csv.writer(f)
             w.writerow(["Regulons", "Target_genes"])
             w.writerows(regulon_dict.items())

-    def to_loom(self, matrix: pd.DataFrame, auc_matrix: pd.DataFrame, regulons: list, loom_fn: str = 'grn_output.loom'):
+    def to_loom(
+            self,
+            matrix: pd.DataFrame,
+            auc_matrix: pd.DataFrame,
+            regulons: list,
+            loom_fn: str = 'grn_output.loom',
+            fn_prefix: str = None
+    ):
         """
         Save GRN results in one loom file
         :param matrix:
@@ -502,6 +538,8 @@ def to_loom(self, matrix: pd.DataFrame, auc_matrix: pd.DataFrame, regulons: list
         :param loom_fn:
         :return:
         """
+        if fn_prefix is not None:
+            loom_fn = f"{fn_prefix}_{loom_fn}"
         export2loom(
             ex_mtx=matrix,
             auc_mtx=auc_matrix,
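
With the save flags split, the CSV and loom exports are independent, and fn_prefix keeps repeated runs from clobbering each other's outputs. A sketch of the resulting file names (the prefix is illustrative):

    prefix = 'sample1'                      # illustrative fn_prefix
    csv_fn = f"{prefix}_regulon_list.csv"   # written when save_regulons=True
    loom_fn = f"{prefix}_grn_output.loom"   # written only when save_loom=True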
17 changes: 9 additions & 8 deletions stereo/algorithm/regulatory_network_inference/plot_grn.py
@@ -110,7 +110,7 @@ def grn_dotplot(self,

         if cluster_res_key in self.stereo_exp_data.cells._obs.columns:
             meta = pd.DataFrame({
-                'bin': self.stereo_exp_data.cells.cell_name,
+                'bins': self.stereo_exp_data.cells.cell_name,
                 'group': self.stereo_exp_data.cells._obs[cluster_res_key].tolist()
             })
         else:
@@ -265,7 +265,7 @@ def auc_heatmap(
         if network_res_key not in self.pipeline_res:
             logger.info(f"The result specified by {network_res_key} is not exists.")

-        fig = sns.clustermap(
+        g = sns.clustermap(
             self.pipeline_res[network_res_key]['auc_matrix'],
             pivot_kws=pivot_kws,
             method=method,
@@ -287,7 +287,7 @@ def auc_heatmap(
             cbar_pos=cbar_pos,
         )

-        return fig
+        return g.figure

     @plot_scale
     @reorganize_coordinate
@@ -552,13 +552,13 @@ def auc_heatmap_by_group(
             cbar_pos=cbar_pos,
         )

-        return g
+        return g.figure

     def spatial_scatter_by_regulon_3D(
             self,
             network_res_key: str = 'regulatory_network_inference',
             reg_name: str = None,
-            fn: str = None,
+            # fn: str = None,
             view_vertical: int = 0,
             view_horizontal: int = 0,
             show_axis: bool = False,
@@ -581,8 +581,8 @@ def spatial_scatter_by_regulon_3D(
         elif '(+)' not in reg_name:
             reg_name = reg_name + '(+)'

-        if fn is None:
-            fn = f'{reg_name.strip("(+)")}.pdf'
+        # if fn is None:
+        #     fn = f'{reg_name.strip("(+)")}.pdf'

         # prepare plotting data
         arr2 = self.stereo_exp_data.position_z
@@ -621,7 +621,8 @@ def spatial_scatter_by_regulon_3D(
         plt.box(False)
         plt.axis('off')
         plt.colorbar(sc, shrink=0.35)
-        plt.savefig(fn, format='pdf')
+        # plt.savefig(fn, format='pdf')
+        return fig


 def get_n_hls_colors(num):
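
These plots now hand the matplotlib Figure back to the caller instead of writing a PDF themselves (sns.clustermap returns a ClusterGrid, whose .figure attribute is the underlying Figure). A hedged usage sketch (assuming the functions are exposed on data.plt as other Stereopy plotting methods are; the regulon name is illustrative):

    fig = data.plt.auc_heatmap()
    fig.savefig('auc_heatmap.pdf', format='pdf', bbox_inches='tight')

    fig3d = data.plt.spatial_scatter_by_regulon_3D(reg_name='Zfp354c(+)')
    fig3d.savefig('Zfp354c_regulon_3d.pdf', format='pdf')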
2 changes: 1 addition & 1 deletion stereo/algorithm/time_series_analysis.py
@@ -37,7 +37,7 @@ def main(
         :param p_val_combination: p_value combination method to use, choosing from ['fisher', 'mean', 'FDR']
         :param cluster_number: number of cluster
-        The parameters below are all key word arguments and only for `other` `run_method`.
+        All the parameters below are key word arguments and only for `other` `run_method`.
         :param spatial_weight: the weight to combine spatial feature
         :param n_spatial_feature: n top features to combine of spatial feature
2 changes: 1 addition & 1 deletion stereo/common.py
@@ -8,4 +8,4 @@
 """

 # version
-version = '0.14.0b1'
+version = '1.0.0'
7 changes: 5 additions & 2 deletions stereo/core/cell.py
@@ -232,7 +232,7 @@ def cell_name(self) -> np.ndarray:
         :return: cell name
         """
-        return self.__based_ann_data.obs_names.values.astype(str)
+        return self.__based_ann_data.obs_names.values.astype('U')

     @cell_name.setter
     def cell_name(self, name: np.ndarray):
@@ -244,7 +244,10 @@ def cell_name(self, name: np.ndarray):
         """
         if not isinstance(name, np.ndarray):
             raise TypeError('cell name must be a np.ndarray object.')
-        self.__based_ann_data._inplace_subset_obs(name)
+        if name.size != self.__based_ann_data.n_obs:
+            raise ValueError(f'The length of cell names must be {self.__based_ann_data.n_obs}, but now is {name.size}')
+        self.__based_ann_data.obs_names = name
+        # self.__based_ann_data._inplace_subset_obs(name)

     @property
     def total_counts(self):
7 changes: 5 additions & 2 deletions stereo/core/gene.py
@@ -164,7 +164,7 @@ def gene_name(self) -> np.ndarray:
         :return: genes name.
         """
-        return self.__based_ann_data.var_names.values.astype(str)
+        return self.__based_ann_data.var_names.values.astype('U')

     @gene_name.setter
     def gene_name(self, name: np.ndarray):
@@ -176,7 +176,10 @@ def gene_name(self, name: np.ndarray):
         """
         if not isinstance(name, np.ndarray):
             raise TypeError('gene name must be a np.ndarray object.')
-        self.__based_ann_data._inplace_subset_var(name)
+        if name.size != self.__based_ann_data.n_vars:
+            raise ValueError(f'The length of gene names must be {self.__based_ann_data.n_vars}, but now is {name.size}')
+        self.__based_ann_data.var_names = name
+        # self.__based_ann_data._inplace_subset_var(name)

     def to_df(self):
         return self.__based_ann_data.var
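
On the AnnData-backed classes, the name setters now rename in place rather than subsetting: the old _inplace_subset_obs/_inplace_subset_var treated the array as a selection, while the new code assigns it as the new obs_names/var_names and enforces a length match. A minimal sketch of the new behavior (the wrapper construction and its keyword are assumptions):

    import anndata as ad
    import numpy as np

    from stereo.core.stereo_exp_data import AnnBasedStereoExpData

    adata = ad.AnnData(np.zeros((3, 2)))                 # 3 cells x 2 genes, toy data
    data = AnnBasedStereoExpData(based_ann_data=adata)   # assumed constructor keyword
    data.cells.cell_name = np.array(['c1', 'c2', 'c3'])  # OK: length matches n_obs
    data.genes.gene_name = np.array(['g1'])              # ValueError: length must be 2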