diff --git a/src/genomic_features/ensembl/ensembldb.py b/src/genomic_features/ensembl/ensembldb.py index a731470..729dba8 100644 --- a/src/genomic_features/ensembl/ensembldb.py +++ b/src/genomic_features/ensembl/ensembldb.py @@ -7,6 +7,7 @@ from typing import Final, Literal import ibis +import numpy as np import requests from ibis import deferred from ibis.expr.types import Table as IbisTable @@ -269,6 +270,56 @@ def chromosomes(self) -> DataFrame: """ return self.db.table("chromosome").execute() + def promoters( + self, + filter: _filters.AbstractFilterExpr = filters.EmptyFilter(), + upstream: int = 2000, + downstream: int = 200, + canonical_transcripts: bool = False, + ) -> DataFrame: + """Get promoter annotations. + + Parameters + ---------- + filter + Filter expression to apply to the genes table. + upstream + Number of base pairs upstream of the TSS (default: 2000). + downstream + Number of base pairs downstream of the TSS (default: 200). + canonical_transcripts + If True, return only canonical transcript for each gene (default: False). + + Returns + ------- + DataFrame + A table of promoter annotations. + """ + # TODO: change to get transcript table with gene level columns + # something like: + # tx_table = self.transcripts(cols = set(cols + ['seq_strand', 'seq_name', 'tx_is_canonical']), filter = filter) + tx_table = self.genes(filter=filter) + + # Get promoter region based on strand + # strand = 1 |>>>>>>>>>>>>>>| + # strand = -1 |<<<<<<<<<<<<<<| + # Tx SS: * * + # Promoter ------ ------ + tx_ss = np.where( + tx_table["seq_strand"] == 1, + tx_table["gene_seq_start"], + tx_table["gene_seq_end"], + ) + tx_table["promoter_seq_start"] = np.where( + tx_table["seq_strand"] == 1, tx_ss - upstream, tx_ss - downstream + ) + tx_table["promoter_seq_end"] = np.where( + tx_table["seq_strand"] == 1, tx_ss + downstream, tx_ss + upstream + ) + # if canonical_transcripts: + # tx_table = tx_table[tx_table["tx_is_canonical"] == 1] + return tx_table + def _build_query( self, table: Literal["gene", "tx", "exon"], diff --git a/tests/test_filters.py b/tests/test_filters.py index 33d5756..4bbee4d 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -205,3 +205,22 @@ def test_seqs_as_int(hsapiens108): pd.testing.assert_frame_equal(result_w_ints, result_w_strs) pd.testing.assert_frame_equal(result_w_ints, result_w_mixed) + + +def test_promoters(hsapiens108): + promoters = hsapiens108.promoters() + assert isinstance(promoters, pd.DataFrame) + promoters = hsapiens108.promoters(upstream=100, downstream=100) + assert ((promoters.promoter_seq_end - promoters.promoter_seq_start) == 200).all() + promoters = hsapiens108.promoters(upstream=1000, downstream=100) + assert ((promoters.promoter_seq_end - promoters.promoter_seq_start) == 1100).all() + # Test strandedness + promoters = hsapiens108.promoters(upstream=1000, downstream=100) + assert ( + promoters[promoters.seq_strand == -1].promoter_seq_start + == promoters[promoters.seq_strand == -1].gene_seq_end - 100 + ).all() + assert ( + promoters[promoters.seq_strand == 1].promoter_seq_start + == promoters[promoters.seq_strand == 1].gene_seq_start - 1000 + ).all()