diff --git a/src/genomic_features/ensembl/ensembldb.py b/src/genomic_features/ensembl/ensembldb.py index 93bc8ae..94cc9c8 100644 --- a/src/genomic_features/ensembl/ensembldb.py +++ b/src/genomic_features/ensembl/ensembldb.py @@ -1,12 +1,15 @@ from __future__ import annotations import warnings +from collections.abc import Sequence from functools import cached_property from itertools import product from pathlib import Path from typing import Final, Literal import ibis +import ibis.expr.types as ir +import ibis.selectors as s import requests from ibis import deferred from ibis.expr.types import Table as IbisTable @@ -19,9 +22,7 @@ PKG_CACHE_DIR = "genomic-features" -BIOC_ANNOTATION_HUB_URL = ( - "https://bioconductorhubs.blob.core.windows.net/annotationhub" -) +BIOC_ANNOTATION_HUB_URL = "https://bioconductorhubs.blob.core.windows.net/annotationhub" ANNOTATION_HUB_URL = ( "https://annotationhub.bioconductor.org/metadata/annotationhub.sqlite3" ) @@ -53,7 +54,7 @@ def annotation( """ try: sqlite_file_path = retrieve_annotation( - f'{BIOC_ANNOTATION_HUB_URL}/AHEnsDbs/v{version}/EnsDb.{species}.v{version}.sqlite' + f"{BIOC_ANNOTATION_HUB_URL}/AHEnsDbs/v{version}/EnsDb.{species}.v{version}.sqlite" ) if backend == "sqlite": @@ -163,6 +164,7 @@ def genes( cols: list[str] | None = None, filter: AbstractFilterExpr = filters.EmptyFilter(), join_type: Literal["inner", "left"] = "inner", + order_by: Sequence[str] | str | None = None, ) -> DataFrame: """Get gene annotations. @@ -175,6 +177,8 @@ def genes( Filters to apply to the query. join_type How to perform joins during the query (if cols or filters requires them). + order_by + Columns to order the results by. Usage @@ -190,7 +194,7 @@ def genes( if "gene_id" not in cols: # genes always needs gene_id cols.append("gene_id") - query = self._build_query(table, cols, filter, join_type) + query = self._build_query(table, cols, filter, join_type, order_by) return self._execute_query(query) def transcripts( @@ -198,6 +202,7 @@ def transcripts( cols: list[str] | None = None, filter: AbstractFilterExpr = filters.EmptyFilter(), join_type: Literal["inner", "left"] = "inner", + order_by: Sequence[str] | str | None = None, ) -> DataFrame: """Get transcript annotations. @@ -206,10 +211,12 @@ def transcripts( cols Which columns to retrieve from the database. Can be from other tables. Returns all transcript columns if None. - filters + filter Filters to apply to the query. join_type How to perform joins during the query (if cols or filters requires them). + order_by + Columns to order the results by. Usage @@ -221,6 +228,7 @@ def transcripts( cols = self.list_columns(table) # get all columns cols = cols.copy() + # Require primary key in output if "tx_id" not in cols: cols.append("tx_id") @@ -228,7 +236,7 @@ def transcripts( if ("tx_seq_start" in cols or "tx_seq_end" in cols) and "seq_name" not in cols: cols.append("seq_name") - query = self._build_query(table, cols, filter, join_type) + query = self._build_query(table, cols, filter, join_type, order_by) return self._execute_query(query) def exons( @@ -236,6 +244,7 @@ def exons( cols: list[str] | None = None, filter: AbstractFilterExpr = filters.EmptyFilter(), join_type: Literal["inner", "left"] = "inner", + order_by: Sequence[str] | str | None = None, ) -> DataFrame: """Get exons table. @@ -268,7 +277,7 @@ def exons( ) and "seq_name" not in cols: cols.append("seq_name") - query = self._build_query(table, cols, filter, join_type) + query = self._build_query(table, cols, filter, join_type, order_by) return self._execute_query(query) def _execute_query(self, query: IbisTable) -> DataFrame: @@ -291,6 +300,13 @@ def _build_query( cols: list[str], filter: AbstractFilterExpr, join_type: Literal["inner", "left"] = "inner", + order_by: str + | ir.Column + | s.Selector + | Sequence[str] + | Sequence[ir.Column] + | Sequence[s.Selector] + | None = None, ) -> IbisTable: """Build a query for the genomic features table.""" # Finalize cols @@ -310,8 +326,29 @@ def _build_query( query = self._join_query(tables, start_with=table, join_type=join_type) else: query = self.db.table(table) + # add filter - query = query.filter(filter.convert()).select(cols).order_by(cols) + query = query.filter(filter.convert()) + query = query.select(cols) + + if order_by is not None: + # Custom ordering is provided + query = query.order_by(order_by) + else: + # Default ordering + order_by = [] + if "seq_name" in cols: + order_by = ["seq_name"] + if "gene_seq_start" in cols: + order_by.extend(["gene_seq_start"]) + if "tx_seq_start" in cols: + order_by.extend(["tx_seq_start"]) + if "exon_seq_start" in cols: + order_by.extend(["exon_seq_start"]) + + order_by.extend([c for c in cols if "id" in c]) + query = query.order_by(order_by) + return query def _join_query( diff --git a/tests/test_basic.py b/tests/test_basic.py index c127400..04ca769 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -3,35 +3,51 @@ import genomic_features as gf +ENSEMBL_RELEASE = 108 + def test_package_has_version(): assert gf.__version__ is not None -def test_genes(): - genes = gf.ensembl.annotation("Hsapiens", 108).genes() +@pytest.mark.parametrize("backend", ["sqlite", "duckdb"]) +def test_genes(backend): + genes = gf.ensembl.annotation("Hsapiens", ENSEMBL_RELEASE, backend=backend).genes() assert isinstance(genes, pd.DataFrame) + # Test sort order + genes_resorted = genes.sort_values( + ["seq_name", "gene_seq_start", "gene_id"] + ).reset_index(drop=True) + pd.testing.assert_frame_equal(genes, genes_resorted) + def test_missing_version(): with pytest.raises(ValueError): gf.ensembl.annotation("Hsapiens", 86) -def test_repr(): - result = repr(gf.ensembl.annotation("Hsapiens", 108)) - expected = "EnsemblDB(organism='Homo sapiens', ensembl_release='108', genome_build='GRCh38')" +def test_invalid_backend(): + with pytest.raises(ValueError): + gf.ensembl.annotation("Hsapiens", ENSEMBL_RELEASE, backend="bad_idea") + +def test_repr(): + result = repr(gf.ensembl.annotation("Hsapiens", ENSEMBL_RELEASE)) + expected = f"EnsemblDB(organism='Homo sapiens', ensembl_release='{ENSEMBL_RELEASE}', genome_build='GRCh38')" assert result == expected def test_invalid_join(): with pytest.raises(ValueError, match=r"Invalid join type: flarb"): - gf.ensembl.annotation("Hsapiens", 108).genes(cols=["tx_id"], join_type="flarb") + gf.ensembl.annotation("Hsapiens", ENSEMBL_RELEASE).genes( + cols=["tx_id"], join_type="flarb" + ) -def test_exons(): - ensdb = gf.ensembl.annotation("Hsapiens", 108) +@pytest.mark.parametrize("backend", ["sqlite", "duckdb"]) +def test_exons(backend): + ensdb = gf.ensembl.annotation("Hsapiens", ENSEMBL_RELEASE, backend=backend) exons = ensdb.exons() pd.testing.assert_index_equal( @@ -44,3 +60,50 @@ def test_exons(): pd.testing.assert_index_equal(exons_id.columns, pd.Index(["exon_id"])) assert exons_id.shape[0] == exons.shape[0] + + # Test sort order + exons_resorted = exons.sort_values( + ["seq_name", "exon_seq_start", "exon_id"] + ).reset_index(drop=True) + pd.testing.assert_frame_equal(exons, exons_resorted) + + +@pytest.mark.parametrize("backend", ["sqlite", "duckdb"]) +def test_join_sort_ordering(backend): + ensdb = gf.ensembl.annotation("Hsapiens", ENSEMBL_RELEASE, backend=backend) + df = ensdb.genes( + [ + "seq_name", + "gene_seq_start", + "gene_seq_end", + "exon_id", + "exon_seq_start", + "exon_seq_end", + ] + ) + + # Test sort order + df_resorted = df.sort_values( + ["seq_name", "gene_seq_start", "exon_seq_start", "exon_id", "gene_id"] + ).reset_index(drop=True) + pd.testing.assert_frame_equal(df, df_resorted) + + +@pytest.mark.parametrize("backend", ["sqlite", "duckdb"]) +def test_custom_ordering(backend): + ensdb = gf.ensembl.annotation("Hsapiens", ENSEMBL_RELEASE, backend=backend) + df = ensdb.genes( + [ + "seq_name", + "gene_seq_start", + "gene_seq_end", + "exon_id", + "exon_seq_start", + "exon_seq_end", + ], + order_by="exon_id", + ) + + # Test sort order + df_resorted = df.sort_values(["exon_id"]).reset_index(drop=True) + pd.testing.assert_frame_equal(df, df_resorted) diff --git a/tests/test_filters.py b/tests/test_filters.py index 1daff6b..2b09a5a 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -194,10 +194,7 @@ def test_negation(hsapiens108): assert result.shape[0] == 22894 -@pytest.mark.parametrize("backend", ["sqlite", "duckdb"]) -def test_seqs_as_int(backend): - hsapiens108 = gf.ensembl.annotation("Hsapiens", 108, backend=backend) - +def test_seqs_as_int(hsapiens108): result_w_int = hsapiens108.genes(filter=filters.SeqNameFilter(1)) result_w_str = hsapiens108.genes(filter=filters.SeqNameFilter("1")) pd.testing.assert_frame_equal(