diff --git a/src/hest/HESTData.py b/src/hest/HESTData.py
index cd02e88..9eaa3be 100644
--- a/src/hest/HESTData.py
+++ b/src/hest/HESTData.py
@@ -681,6 +681,11 @@ def read_hest_wsi(wsi: WSI, width, height):
 
         return SpatialData(tables=new_table, images=images, shapes=shapes)
 
+    def ensembleID_to_gene(self):
+        """Convert the Ensembl gene IDs of this object to gene names, in place"""
+        ensembleID_to_gene(self, inplace=True)
+
+
 class VisiumHESTData(HESTData):
     def __init__(self,
         adata: sc.AnnData, # type: ignore
@@ -1239,11 +1244,47 @@ def unify_gene_names(adata: sc.AnnData, species="human", drop=False) -> sc.AnnDa
 
     if drop:
         adata = adata[:, ~remaining]
-    
+
     # TODO return dict map of renamed, and remaining
-
     return adata
 
+def ensembleID_to_gene(st: HESTData, inplace=False, filter_na = False) -> HESTData:
+    """
+    Converts the Ensembl gene IDs of a HESTData object to gene names using Biomart annotations
+
+    Args:
+        st (HESTData): HESTData object
+        inplace (bool): whether to perform the changes in place. Defaults to False.
+        filter_na (bool): whether to drop the genes that have no matching gene name; when False they keep their original ID as name. Defaults to False.
+
+    Returns:
+        HESTData: HESTData object with gene names instead of Ensembl gene IDs
+    """
+    import scanpy as sc
+
+    if not inplace:
+        st = st.copy()
+
+    species = st.meta['species']
+    # Any species other than "Homo sapiens" is looked up in the mouse annotations
+    org = "hsapiens" if species == "Homo sapiens" else "mmusculus"
+
+    annotations = sc.queries.biomart_annotations(org=org, attrs=['ensembl_gene_id', 'external_gene_name'], use_cache=True)
+    ensembl_to_gene_name = dict(zip(annotations['ensembl_gene_id'], annotations['external_gene_name']))
+
+    st.adata.var['gene_name'] = st.adata.var_names.map(ensembl_to_gene_name, na_action=None)
+    matched = st.adata.var['gene_name'].notna().to_numpy()
+
+    # Genes without a match fall back to their original ID (they are dropped below when filter_na is set)
+    st.adata.var['gene_name'] = st.adata.var['gene_name'].where(matched, st.adata.var_names)
+    st.adata.var_names = st.adata.var['gene_name'].to_numpy()
+
+    if filter_na:
+        st.adata = st.adata[:, matched]
+
+    return st
+
+
 
 def save_spatial_plot(adata: sc.AnnData, save_path: str, name: str='', key='total_counts', pl_kwargs={}):
     """Save the spatial plot from that sc.AnnData
diff --git a/src/hest/__init__.py b/src/hest/__init__.py
index 9af7b95..7f32c34 100644
--- a/src/hest/__init__.py
+++ b/src/hest/__init__.py
@@ -3,7 +3,7 @@
 from .utils import tiff_save, find_pixel_size_from_spot_coords, write_10X_h5, get_k_genes, SpotPacking
 from .autoalign import autoalign_visium
 from .readers import *
-from .HESTData import HESTData, read_HESTData, load_hest, iter_hest
+from .HESTData import HESTData, read_HESTData, load_hest, iter_hest, ensembleID_to_gene
 from .segmentation.cell_segmenters import segment_cellvit
 
 __all__ = [
@@ -20,5 +20,6 @@
     'autoalign_visium',
     'write_10X_h5',
     'HESTData',
-    'segment_cellvit'
+    'segment_cellvit',
+    'ensembleID_to_gene'
 ]
diff --git a/tests/hest_tests.py b/tests/hest_tests.py
index 14bb514..12bf79e 100644
--- a/tests/hest_tests.py
+++ b/tests/hest_tests.py
@@ -17,6 +17,7 @@
 
 from hest.autoalign import autoalign_visium
 from hest.readers import VisiumReader
+from hest.HESTData import ensembleID_to_gene
 from hest.utils import load_image
 
 
@@ -131,6 +132,12 @@ def setUpClass(self):
         else:
             self.sts = hest.load_hest('hest_data', id_list)
 
+
+    def test_conversion_ensembleID(self):
+        for idx, st in enumerate(self.sts):
+            with self.subTest(st_object=idx):
+                ensembleID_to_gene(st)
+
 
     def test_tissue_seg(self):
         for idx, st in enumerate(self.sts):