Skip to content

Commit

Permalink
update for SpaTrack
Browse files Browse the repository at this point in the history
  • Loading branch information
tanliwei-coder committed Nov 7, 2024
1 parent 6de0dc5 commit 211aa88
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 29 deletions.
47 changes: 20 additions & 27 deletions stereo/algorithm/ms_spa_track.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,21 @@ def main(

def transfer_matrix(
self,
data_indices: List[Union[str, int]] = None,
data1_index: Union[str, int] = None,
data2_index: Union[str, int] = None,
spatial_key: str = 'spatial',
alpha: float = 0.1,
epsilon = 0.01,
rho = np.inf,
G_1 = None,
G_2 = None,
epsilon: float = 0.01,
rho: float = np.inf,
G_1: float = None,
G_2: float = None,
**kwargs
):
"""
Squentially calculate transfer matrix between each two time specified by data_indices.
Calculates transfer matrix between two time.
:param data_indices: A list of indices or names in the ms_data of the data to calculate transfer matrix, defaults to None
:param data1_index: The index in the ms_data of the first data.
:param data2_index: The index in the ms_data of the second data.
:param spatial_key: The key to get position information of cells, defaults to 'spatial'
:param alpha: Alignment tuning parameter. Note:0 <= alpha <= 1.
When ``alpha = 0`` only the gene expression data is taken into account,
Expand All @@ -59,27 +61,18 @@ def transfer_matrix(
:param G_2: Distance matrix within spatial data 2 (spots, spots), defaults to None
"""
assert spatial_key is not None, 'spatial_key must be provided'
if data_indices is None:
data_indices = self.ms_data.names
data_list_to_calculate: List[StereoExpData] = [
self.ms_data[di] for di in data_indices
]
data_names = [
self.ms_data.names[di] if isinstance(di, int) else di for di in data_indices
]
sp_existed = [spatial_key in data.cells_matrix for data in data_list_to_calculate]
assert all(sp_existed), 'spatial_key must be existed in all data'

transfer_matrices = {}
for i in range(len(data_indices) - 1):
data1 = data_list_to_calculate[i]
data2 = data_list_to_calculate[i + 1]
transfer_matrices[(data_names[i], data_names[i + 1])] = transfer_matrix(
data1, data2, layer=None, spatial_key=spatial_key, alpha=alpha, epsilon=epsilon,
rho=rho, G_1=G_1, G_2=G_2, **kwargs
)
data1 = self.ms_data[data1_index]
data2 = self.ms_data[data2_index]
assert spatial_key in data1.cells_matrix and spatial_key in data2.cells_matrix, 'spatial_key must be existed in both data'
data1_name = self.ms_data.names[data1_index] if isinstance(data1_index, int) else data1_index
data2_name = self.ms_data.names[data2_index] if isinstance(data2_index, int) else data2_index
if 'transfer_matrices' not in self.pipeline_res['spa_track']:
self.pipeline_res['spa_track']['transfer_matrices'] = {}
tm = transfer_matrix(data1, data2, layer=None, spatial_key=spatial_key, alpha=alpha, epsilon=epsilon,
rho=rho, G_1=G_1, G_2=G_2, **kwargs)

self.pipeline_res['spa_track']['transfer_spatial_key'] = spatial_key
self.pipeline_res['spa_track']['transfer_matrices'] = transfer_matrices
self.pipeline_res['spa_track']['transfer_matrices'][(data1_name, data2_name)] = tm

def generate_animate_input(
self,
Expand Down
4 changes: 2 additions & 2 deletions stereo/algorithm/spt/multiple_time/transfer_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def transfer_matrix(
exp_martix_2 = data2.get_exp_matrix(use_raw=False, layer=layer, to_array=True)

## Calculate spatial coordinates diatances
spa_dist_1 = spatial_dist(data1)
spa_dist_2 = spatial_dist(data2)
spa_dist_1 = spatial_dist(data1, spatial_key=spatial_key)
spa_dist_2 = spatial_dist(data2, spatial_key=spatial_key)

## Calculate gene expression dissimilarity
M = gene_dist(exp_martix_1, exp_martix_2)
Expand Down

0 comments on commit 211aa88

Please sign in to comment.