Skip to content

Commit

Permalink
Update _filter_file.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ravibandaru-lab authored Jan 31, 2025
1 parent 1f1aabb commit 20480e9
Showing 1 changed file with 154 additions and 185 deletions.
339 changes: 154 additions & 185 deletions src/finaletoolkit/utils/_filter_file.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,69 @@
from __future__ import annotations
import tempfile as tf
import subprocess
import traceback
import logging
import warnings
import gzip
import pysam

def validate_deprecated_args(old_arg, new_arg, old_name, new_name):
    """Resolve a deprecated alias argument against its replacement.

    Returns *new_arg* when the deprecated argument is unset. Otherwise a
    DeprecationWarning is emitted and *old_arg* is returned — unless both
    arguments were supplied, in which case a ValueError is raised.
    """
    if old_arg is None:
        return new_arg
    # The caller used the deprecated name: warn at the caller's frame.
    warnings.warn(f"{old_name} is deprecated. Use {new_name} instead.",
                  category=DeprecationWarning,
                  stacklevel=2)
    if new_arg is not None:
        raise ValueError(f'{old_name} and {new_name} cannot both be specified.')
    return old_arg

def validate_input_file(input_file):
    """Check that *input_file* has a supported suffix and return it.

    Parameters
    ----------
    input_file : str
        Path or name of the input file.

    Returns
    -------
    str
        The matching suffix: ``".gz"``, ``".bam"`` or ``".cram"``.

    Raises
    ------
    ValueError
        If the file name ends with none of the supported suffixes.
    """
    # Use an ordered tuple (not a set) so the error message lists the
    # suffixes deterministically, and scan the suffixes only once.
    valid_suffixes = (".gz", ".bam", ".cram")
    for suffix in valid_suffixes:
        if input_file.endswith(suffix):
            return suffix
    raise ValueError(
        f"Input file should have one of the following suffixes: "
        f"{', '.join(valid_suffixes)}")

def run_subprocess(cmd: str, error_msg: str = "Command failed", verbose: bool = False, logger=None):
    """Run a shell command, logging and re-raising on failure.

    Parameters
    ----------
    cmd : str
        Shell command line to execute (run with ``shell=True``).
    error_msg : str, optional
        Message prefix logged when the command fails.
    verbose : bool, optional
        If True, log the command line before running it.
    logger : logging.Logger, optional
        Logger to use; falls back to this module's logger when None.

    Raises
    ------
    subprocess.CalledProcessError
        If the command exits with a nonzero status.
    """
    # Fall back to a module logger: with the previous code a missing
    # `logger` argument turned any subprocess failure (or verbose=True)
    # into an AttributeError that masked the real error.
    log = logger if logger is not None else logging.getLogger(__name__)
    try:
        if verbose:
            log.info(f"Running: {cmd}")
        # NOTE(review): cmd is interpolated into a shell (shell=True);
        # callers must never pass untrusted input here.
        subprocess.run(cmd, shell=True, check=True)
    except subprocess.CalledProcessError as e:
        log.error(f"{error_msg}: {str(e)}")
        raise

def filter_bed_entries(infile, min_length=None, max_length=None, quality_threshold=30):
    """Yield lines of a BED-like file that pass length and MAPQ filters.

    The MAPQ column is detected lazily from the first usable data line:
    column index 3 if that field is numeric, otherwise index 4. Header
    lines (starting with ``#``), lines with fewer than 4 fields, and
    lines whose coordinates or score fail to parse are skipped silently.

    Parameters
    ----------
    infile : iterable of str
        Open text handle (or any iterable of tab-separated lines).
    min_length, max_length : int, optional
        Inclusive bounds on the interval length (end - start).
    quality_threshold : float, optional
        Minimum MAPQ score required to keep a line. Default 30.
    """
    score_idx = None
    for entry in infile:
        # Skip comment/header lines.
        if entry.startswith('#'):
            continue

        fields = entry.strip().split('\t')
        if len(fields) < 4:
            continue

        # Lazily detect which column carries the MAPQ score.
        if score_idx is None:
            if fields[3].isnumeric():
                score_idx = 3
            elif len(fields) >= 5 and fields[4].isnumeric():
                score_idx = 4
            else:
                continue

        try:
            span = int(fields[2]) - int(fields[1])
            mapq = float(fields[score_idx])
            long_enough = min_length is None or span >= min_length
            short_enough = max_length is None or span <= max_length
            if long_enough and short_enough and mapq >= quality_threshold:
                yield entry
        except (ValueError, IndexError):
            # Malformed coordinates or score: drop the line.
            continue

def filter_file(
input_file: str,
whitelist_file: str | None = None,
Expand Down Expand Up @@ -43,7 +101,7 @@ def filter_file(
Maximum length for reads/intervals
intersect_policy: str, optional
Specifies how to determine whether fragments are in interval for
whitelisting and blacklisting functionality.'midpoint' (default)
whitelisting and blacklisting functionality. 'midpoint' (default)
calculates the central coordinate of each fragment and only
selects the fragment if the midpoint is in the interval.
'any' includes fragments with any overlap with the interval.
Expand All @@ -63,7 +121,7 @@ def filter_file(
output_file : str
Path to the filtered output file.
"""

logger = logging.getLogger(__name__)
if verbose:
print(
f"""
Expand All @@ -79,207 +137,118 @@ def filter_file(
verbose: {verbose}
"""
)

# Pass aliases and check for conflicts
if fraction_low is not None and min_length is None:
min_length = fraction_low
warnings.warn("fraction_low is deprecated. Use min_length instead.",
category=DeprecationWarning,
stacklevel=2)
elif fraction_low is not None and min_length is not None:
warnings.warn("fraction_low is deprecated. Use min_length instead.",
category=DeprecationWarning,
stacklevel=2)
raise ValueError(
'fraction_low and min_length cannot both be specified')

if fraction_high is not None and max_length is None:
max_length = fraction_high
warnings.warn("fraction_high is deprecated. Use max_length instead.",
category=DeprecationWarning,
stacklevel=2)
elif fraction_high is not None and max_length is not None:
warnings.warn("fraction_high is deprecated. Use max_length instead.",
category=DeprecationWarning,
stacklevel=2)
raise ValueError(
'fraction_high and max_length cannot both be specified.')
logging.basicConfig(level=logging.INFO)

if input_file.endswith(".gz"):
suffix = ".gz"
elif input_file.endswith(".bgz"):
suffix = ".bgz"
elif input_file.endswith(".bam"):
suffix = ".bam"
elif input_file.endswith(".cram"):
suffix = ".cram"
# Pass aliases and check for conflicts
min_length = validate_deprecated_args(fraction_low, min_length, "fraction_low", "min_length")
max_length = validate_deprecated_args(fraction_high, max_length, "fraction_high", "max_length")
print(f"min_length: {min_length}")
print(f"max_length: {max_length}")
suffix = validate_input_file(input_file)

if intersect_policy == "midpoint":
intersect_param = "-f 0.500"
elif intersect_policy == "any":
intersect_param = ""
else:
raise ValueError('Input file should have suffix .bam, .cram, .bgz, or .gz')

# create tempfile to contain filtered output
if output_file is None:
_, output_file = tf.mkstemp(suffix=suffix)
elif not output_file.endswith(suffix) and output_file != '-':
raise ValueError('Output file should share same suffix as input file.')

intersect = "-f 0.500" if intersect_policy == "midpoint" else ""
raise ValueError("intersect_policy must be 'midpoint' or 'any'")

pysam.set_verbosity(pysam.set_verbosity(0))

with tf.TemporaryDirectory() as temp_dir:
temp_1 = f"{temp_dir}/output1{suffix}"
temp_2 = f"{temp_dir}/output2{suffix}"
temp_3 = f"{temp_dir}/output3{suffix}"
if input_file.endswith(('.bam', '.cram')):
# create temp dir to store intermediate sorted file
if whitelist_file is not None:
try:
subprocess.run(
f"bedtools intersect -abam {input_file} -b {whitelist_file} {intersect} > {temp_1} && samtools index {temp_1}",
shell=True,
check=True)
except Exception as e:
print(e)
traceback.print_exc()
exit(1)
if whitelist_file:
run_subprocess(
f"bedtools intersect -abam {input_file} -b {whitelist_file} {intersect_param} > {temp_1} && "
f"samtools index {temp_1}",
error_msg="Whitelist filtering failed",
verbose=verbose,
logger=logger
)
else:
subprocess.run(
f"cp {input_file} {temp_1}", shell=True, check=True)
if blacklist_file is not None:
try:
subprocess.run(
f"bedtools intersect -abam {temp_1} -b {blacklist_file} -v {intersect} > {temp_2} && samtools index {temp_2}",
shell=True,
check=True)
except Exception:
traceback.print_exc()
exit(1)
run_subprocess(f"cp {input_file} {temp_1}", verbose=verbose, logger=logger)
if blacklist_file:
intersect_param = "-f 0.500" if intersect_policy == "midpoint" else ""
run_subprocess(
f"bedtools intersect -abam {temp_1} -b {blacklist_file} -v {intersect_param} > {temp_2} && "
f"samtools index {temp_2}",
error_msg="Blacklist filtering failed",
verbose=verbose,
logger=logger
)
else:
subprocess.run(
f"mv {temp_1} {temp_2}", shell=True, check=True)

try:
subprocess.run(
f"samtools view {temp_2} -F 3852 -f 3 -b -h -o {temp_3} -q {quality_threshold} -@ {workers}",
shell=True,
check=True)
except Exception:
traceback.print_exc()
exit(1)
run_subprocess(f"mv {temp_1} {temp_2}", verbose=verbose, logger=logger)

run_subprocess(
f"samtools view {temp_2} -F 3852 -f 3 -b -h -o {temp_3} -q {quality_threshold} -@ {workers}",
error_msg="Quality filtering failed",
verbose=verbose,
logger=logger
)

# filter for reads on different reference and length
with pysam.AlignmentFile(temp_3, 'rb',threads=workers//3) as in_file:
with pysam.AlignmentFile(
output_file, 'wb', template=in_file, threads=workers-workers//3) as out_file:
# Length filtering and final output
pysam.set_verbosity(0)
with pysam.AlignmentFile(temp_3, 'rb', threads=workers//3) as in_file:
with pysam.AlignmentFile(output_file, 'wb', template=in_file, threads=workers-workers//3) as out_file:
for read in in_file:
if (
read.reference_name == read.next_reference_name
and (max_length is None
or read.template_length <= max_length)
and (min_length is None
or read.template_length >= min_length)
):
if (read.reference_name == read.next_reference_name and
(max_length is None or read.template_length <= max_length) and
(min_length is None or read.template_length >= min_length)):
out_file.write(read)
outfile.flush()

if output_file != '-':
# generate index for output_file
try:
subprocess.run(
f'samtools index {output_file} {output_file}.bai',
shell=True,
check=True
run_subprocess(
f'samtools index {output_file}',
error_msg="Index creation failed",
verbose=verbose,
logger=logger
)

elif input_file.endswith('.gz'):
with gzip.open(input_file, 'rt') as infile, open(temp_1, 'w') as outfile:
for line in filter_bed_entries(infile, min_length, max_length, quality_threshold):
outfile.write(line)
outfile.flush()

if whitelist_file:
intersect_param = "-f 0.500" if intersect_policy == "midpoint" else ""
run_subprocess(
f"bedtools intersect -a {temp_1} -b {whitelist_file} {intersect_param} > {temp_2}",
error_msg="Whitelist filtering failed",
verbose=verbose,
logger=logger
)
except Exception:
traceback.print_exc()
exit(1)

elif input_file.endswith('.gz') or input_file.endswith('.bgz'):
with gzip.open(input_file, 'r') as infile, open(temp_1, 'w') as outfile:
mapq_column = 0 # 1-index for sanity when comparing with len()
for line in infile:
line = line.decode('utf-8')
parts = line.strip().split('\t')
if len(parts) < max(mapq_column,4) or line.startswith('#'):
continue

if mapq_column == 0:
if parts[4-1].isnumeric():
mapq_column = 4
elif len(parts) >= 5 and parts[5-1].isnumeric():
mapq_column = 5
else:
continue
try:
start = int(parts[1])
end = int(parts[2])
length = end - start
score = None
try:
score = float(parts[mapq_column-1])
except ValueError:
pass

passes_length_restriction = True

if min_length is not None and length < min_length:
passes_length_restriction = False

if max_length is not None and length > max_length:
passes_length_restriction = False

passes_quality_restriction = True
if score is None or score < quality_threshold:
passes_quality_restriction = False

if passes_length_restriction and passes_quality_restriction:
outfile.write(line)
except ValueError:
continue
if whitelist_file is not None:
try:
subprocess.run(
f"bedtools intersect -a {temp_1} -b {whitelist_file} {intersect} > {temp_2}",
shell=True,
check=True
)
except Exception:
traceback.print_exc()
exit(1)
else:
subprocess.run(f"mv {temp_1} {temp_2}", shell=True, check=True)

if blacklist_file is not None:
try:
subprocess.run(
f"bedtools intersect -v -a {temp_2} -b {blacklist_file} {intersect} > {temp_3}",
shell=True,
check=True
)
except Exception:
traceback.print_exc()
exit(1)
run_subprocess(f"mv {temp_1} {temp_2}", verbose=verbose, logger=logger)

if blacklist_file:
intersect_param = "-f 0.500" if intersect_policy == "midpoint" else ""
run_subprocess(
f"bedtools intersect -v -a {temp_2} -b {blacklist_file} {intersect_param} > {temp_3}",
error_msg="Blacklist filtering failed",
verbose=verbose,
logger=logger
)
else:
subprocess.run(f"mv {temp_2} {temp_3}", shell=True, check=True)
try:
subprocess.run(
f"bgzip -@ {workers} -c {temp_3} > {output_file}",
shell=True,
check=True
)
except Exception:
traceback.print_exc()
exit(1)
run_subprocess(f"mv {temp_2} {temp_3}", verbose=verbose, logger=logger)

# Compression and indexing
run_subprocess(
f"bgzip -@ {workers} -c {temp_3} > {output_file}",
error_msg="Compression failed",
verbose=verbose,
logger=logger
)

if output_file != '-':
# generate index for output_file
try:
subprocess.run(
f'tabix -p bed {output_file}',
shell=True,
check=True
)
except Exception:
traceback.print_exc()
exit(1)
else:
raise ValueError("Input file must be a BAM, CRAM, or bgzipped BED file.")
run_subprocess(
f'tabix -p bed {output_file}',
error_msg="Index creation failed",
verbose=verbose,
logger=logger
)
return output_file

0 comments on commit 20480e9

Please sign in to comment.