Skip to content

Commit

Permalink
add ability to provide reference annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
mparker2 committed Oct 19, 2020
1 parent dbbd8f8 commit 74f5b62
Show file tree
Hide file tree
Showing 6 changed files with 250 additions and 65 deletions.
26 changes: 16 additions & 10 deletions lib2pass/bamparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def find_introns(aln, stranded=True):
get_junction_overhang_size(right_tag)
)
# info is intron motif
introns.append([left, right, junc_overhang, info])
introns.append([left, right, junc_overhang, ln, info])
intron_motifs.append(info)
pos = right

Expand All @@ -114,10 +114,11 @@ def find_introns(aln, stranded=True):
else:
strand = infer_strand_from_intron_motifs(intron_motifs, read_strand)

for start, end, overhang, motif in introns:
n_introns = len(introns)
for i, (start, end, overhang, length, motif) in enumerate(introns, 1):
if strand == '-':
motif = motif.translate(RC)[::-1]
yield chrom, start, end, strand, motif, overhang
yield chrom, start, end, strand, motif, overhang, length


def build_donor_acceptor_ncls(introns, intron_counts, intron_jads, dist=20):
Expand Down Expand Up @@ -188,6 +189,7 @@ def assign_primary(chrom, start, end, strand, inv_trees):

def fetch_introns_for_interval(bam_fn, chrom, start, end, stranded):
motifs = {}
lengths = {}
counts = Counter()
intron_jads = Counter()
with pysam.AlignmentFile(bam_fn) as bam:
Expand All @@ -196,12 +198,13 @@ def fetch_introns_for_interval(bam_fn, chrom, start, end, stranded):
# which start before beginning of specified interval
if aln.reference_start < start:
continue
for *i, m, junc_overhang in find_introns(aln, stranded):
for *i, m, ov, ln in find_introns(aln, stranded):
i = tuple(i)
motifs[i] = m
lengths[i] = ln
counts[i] += 1
intron_jads[i] = max(intron_jads[i], junc_overhang)
return motifs, counts, intron_jads
intron_jads[i] = max(intron_jads[i], ov)
return motifs, lengths, counts, intron_jads


def get_bam_intervals(bam_fn, batch_size):
Expand All @@ -215,14 +218,16 @@ def get_bam_intervals(bam_fn, batch_size):

def merge_intron_res(res):
motifs = {}
lengths = {}
counts = Counter()
intron_jads = Counter()
for m, c, j in res:
for m, l, c, j in res:
motifs.update(m)
lengths.update(l)
counts += c
for i, jad in j.items():
intron_jads[i] = max(intron_jads[i], jad)
return motifs, counts, intron_jads
return motifs, lengths, counts, intron_jads


def parse_introns(bam_fn, primary_splice_local_dist,
Expand All @@ -238,10 +243,11 @@ def parse_introns(bam_fn, primary_splice_local_dist,
bam_fn, *inv, stranded)
for inv in get_bam_intervals(bam_fn, batch_size)
)
motifs, counts, intron_jads = merge_intron_res(res)
motifs, lengths, counts, intron_jads = merge_intron_res(res)

introns = list(motifs.keys())
motifs = [motifs[i] for i in introns]
lengths = [lengths[i] for i in introns]
counts = [counts[i] for i in introns]
jad_label = [intron_jads[i] for i in introns]

Expand All @@ -255,5 +261,5 @@ def parse_introns(bam_fn, primary_splice_local_dist,
is_primary_donor.append(d)
is_primary_acceptor.append(a)

return (introns, motifs, counts, jad_label,
return (introns, motifs, lengths, counts, jad_label,
is_primary_donor, is_primary_acceptor)
127 changes: 122 additions & 5 deletions lib2pass/decisiontree.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,76 @@
import logging
from operator import itemgetter
import re
import numpy as np

from sklearn.preprocessing import quantile_transform
from sklearn.tree import DecisionTreeClassifier, export_text
from sklearn.ensemble import ExtraTreesClassifier

def dt1_pred(intron_motif, jad_label, is_primary_donor, is_primary_acceptor,

log = logging.getLogger('2passtools')


DT1_DENOVO_FEATURES = [
'is_canonical_motif', 'jad',
'is_primary_donor', 'is_primary_acceptor',
'intron_length_quantile',
]

DT2_DENOVO_FEATURES = [
'jad', 'is_primary_donor', 'is_primary_acceptor',
'intron_length_quantile',
'donor_lr_score', 'acceptor_lr_score',
]


def format_feature_importances(feature_names, feature_importances, width=10):
max_size = max(feature_importances)
point_size = max_size / width
pad_to = max([len(x) for x in feature_names])
feature_importances = {fn: fi for fn, fi in zip(feature_names, feature_importances)}
feature_importances = sorted(feature_importances.items(), key=itemgetter(1), reverse=True)
fmt = ''
for fn, fi in feature_importances:
rpad = ' ' * (pad_to - len(fn))
fn += rpad
bar = '*' * int(round(fi / point_size))
fmt += f'{fn} {bar} {fi:.1f}\n'
return fmt


def _de_novo_pred(X, y, feature_names, classifier='decision_tree'):
if classifier == 'random_forest':
log.info('Using extremely random forest')
clf = ExtraTreesClassifier(n_estimators=250, bootstrap=True, oob_score=True)
clf.fit(X, y)
log.debug('Feature importance:')
log.debug(format_feature_importances(feature_names, clf.feature_importances_))
pred = clf.oob_decision_function_[:, 1]
# in the unlikely event dt1_pred contains NaNs
# (can happen when n_estimators is not big enough)
pred[np.isnan(pred)] = 0
pred = pred >= 0.5

else:
clf = DecisionTreeClassifier(
max_depth=5,
min_samples_split=100,
min_impurity_decrease=0.005,
)
clf.fit(X, y)
log.debug('Tree structure:')
log.debug(export_text(clf, feature_names=feature_names))
pred = clf.predict(X)
return pred.astype(int)


def dt1_pred(intron_motif, jad_labels, is_primary_donor, is_primary_acceptor,
motif_regex='GTAG|GCAG|ATAG', jad_size_threshold=4):
motif_regex = re.compile(motif_regex)
is_canon = np.asarray([bool(motif_regex.match(m)) for m in intron_motif])

jad_label = np.asarray(jad_label) >= jad_size_threshold
jad_labels = np.asarray(jad_labels) >= jad_size_threshold

is_primary_donor = np.asarray(is_primary_donor, dtype=bool)
is_primary_acceptor = np.asarray(is_primary_acceptor, dtype=bool)
Expand All @@ -16,7 +79,36 @@ def dt1_pred(intron_motif, jad_label, is_primary_donor, is_primary_acceptor,
return (jad_label & is_canon) | (is_primary & is_canon)


def dt2_pred(jad_label,
def dt1_de_novo_pred(intron_motif, intron_lengths,
jad_labels, is_primary_donor, is_primary_acceptor,
is_annot, motif_regex='GTAG|GCAG|ATAG',
classifier='decision_tree'):
motif_regex = re.compile(motif_regex)
is_canon = np.asarray([int(bool(motif_regex.match(m))) for m in intron_motif])

jad_labels = np.asarray(jad_labels)

is_primary_donor = np.asarray(is_primary_donor)
is_primary_acceptor = np.asarray(is_primary_acceptor)

intron_length_quantile = quantile_transform(
np.asarray(intron_lengths).reshape(-1, 1)
).ravel()

X = np.stack(
[
is_canon, jad_labels,
is_primary_donor, is_primary_acceptor,
intron_length_quantile
],
axis=1
)
y = np.asarray(is_annot, dtype=np.int)
pred = _de_novo_pred(X, y, DT1_DENOVO_FEATURES, classifier=classifier)
return pred


def dt2_pred(jad_labels,
is_primary_donor,
is_primary_acceptor,
donor_lr_score,
Expand All @@ -25,7 +117,7 @@ def dt2_pred(jad_label,
high_conf_thresh=0.6,
jad_size_threshold=4):

jad_label = np.asarray(jad_label) >= jad_size_threshold
jad_labels = np.asarray(jad_labels) >= jad_size_threshold
is_primary_donor = np.asarray(is_primary_donor, dtype=bool)
is_primary_acceptor = np.asarray(is_primary_acceptor, dtype=bool)
donor_lr_score = np.asarray(donor_lr_score, dtype=np.float64)
Expand All @@ -38,4 +130,29 @@ def dt2_pred(jad_label,
seq_high_conf = ((donor_lr_score >= high_conf_thresh) &
(acceptor_lr_score >= high_conf_thresh))

return (jad_label & seq_low_conf) | (is_primary & seq_high_conf)
return (jad_labels & seq_low_conf) | (is_primary & seq_high_conf)


def dt2_de_novo_pred(intron_lengths, jad_labels,
is_primary_donor, is_primary_acceptor,
donor_lr_score, acceptor_lr_score,
is_annot, classifier='decision_tree'):
jad_labels = np.asarray(jad_labels)

is_primary_donor = np.asarray(is_primary_donor)
is_primary_acceptor = np.asarray(is_primary_acceptor)

intron_length_quantile = quantile_transform(
np.asarray(intron_lengths).reshape(-1, 1)
).ravel()

X = np.stack(
[
jad_labels, is_primary_donor, is_primary_acceptor,
intron_length_quantile, donor_lr_score, acceptor_lr_score,
],
axis=1
)
y = np.asarray(is_annot, dtype=np.int)
pred = _de_novo_pred(X, y, DT2_DENOVO_FEATURES, classifier=classifier)
return pred
Loading

0 comments on commit 74f5b62

Please sign in to comment.