Skip to content

Commit

Permalink
Replace string-based file extension checks with pathlib methods
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbonet committed Nov 14, 2024
1 parent 6f6cd8d commit 959de3c
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 39 deletions.
5 changes: 3 additions & 2 deletions benchmark/read_vcf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import numpy as np
from pathlib import Path
from .utils import create_benchmark_test


Expand Down Expand Up @@ -68,7 +69,7 @@ def read_vcf_pysam(path):
@pytest.mark.parametrize("reader,name", READERS)
def test_vcf_readers(benchmark, reader, name, path, memory_profile):
"""Benchmark readers and verify output"""
if not path.endswith(".vcf.gz"):
path = path + ".vcf.gz"
if path.suffixes[-2:] != ['.vcf', '.gz']:
path = path.with_suffix('.vcf.gz')
ref_array = read_vcf_snputils(path)
create_benchmark_test(benchmark, reader, path, name, ref_array, memory_profile)
12 changes: 9 additions & 3 deletions snputils/ancestry/io/local/write/adm_mapping_vcf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import numpy as np
from typing import Dict
from pathlib import Path

from snputils.ancestry.genobj.local import LocalAncestryObject
from snputils.snp.genobj.snpobj import SNPObject
Expand Down Expand Up @@ -56,10 +57,15 @@ def write(self):

# Define the output file format, ensuring it has the correct ancestry-specific suffix
file_extension = (".vcf", ".bcf")
if not self.__file.endswith(file_extension):
output_file = self.__file + f"_{anc_string}.vcf"
file_path = Path(self.__file)

# Check if file has one of the specified extensions
if file_path.suffix not in file_extension:
# If file does not have the correct extension, default to ".vcf"
output_file = file_path.with_name(f"{file_path.stem}_{anc_string}.vcf")
else:
output_file = f"{self.__file[:-4]}_{anc_string}{self.__file[-4:]}"
# If file has the correct extension, insert the ancestry string before the extension
output_file = file_path.with_name(f"{file_path.stem}_{anc_string}{file_path.suffix}")

# Format start and end positions for the VCF file
if self.__laiobj.physical_pos is not None:
Expand Down
12 changes: 6 additions & 6 deletions snputils/ancestry/io/local/write/msp.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ def write(self):
log.info(f"LAI object contains: {self.laiobj.n_samples} samples, {self.laiobj.n_ancestries} ancestries.")

# Define the required file extension for the output
file_extension = (".msp")
file_extension = ".msp"

# Append '.msp' extension to filename if not already present
if not self.__filename.endswith(file_extension):
self.__filename += file_extension
# Append '.msp' extension to __file if not already present
if not self.__file.name.endswith(file_extension):
self.__file = self.__file.with_name(self.__file.name + file_extension)

# Prepare columns for the DataFrame
columns = ["spos", "epos", "sgpos", "egpos", "n snps"]
Expand Down Expand Up @@ -101,7 +101,7 @@ def write(self):
log.info(f"Writing MSP file to '{self.file}'...")

# Save the DataFrame to the .msp file in tab-separated format
lai_df.to_csv(self.filename, sep="\t", index=False, header=False)
lai_df.to_csv(self.__file, sep="\t", index=False, header=False)

# Construct the second line for the output file containing the column headers
second_line = "#chm" + "\t" + "\t".join(columns)
Expand All @@ -117,7 +117,7 @@ def write(self):
)

# Open the file for reading and prepend the first line
with open(self.filename, "r+") as f:
with open(self.__file, "r+") as f:
content = f.read()
f.seek(0,0)
f.write(first_line.rstrip('\r\n') + '\n' + second_line + '\n' + content)
Expand Down
4 changes: 2 additions & 2 deletions snputils/snp/io/read/__test__/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ def data_path():
if not file_path.exists():
print(f"Downloading {file_name} to {data_path}. This may take a while...")
urllib.request.urlretrieve(url, file_path)
if file_name.endswith(".zip"):
if file_path.suffix == ".zip":
with zipfile.ZipFile(file_path, "r") as zip_ref:
zip_ref.extractall(data_path)
subprocess.run(
["chmod", "+x", data_path / "plink2"], cwd=str(data_path)
)

if file_name.endswith(".vcf.gz"):
if file_path.suffixes[-2:] == ['.vcf', '.gz']:
# Subset sample files
subset_file = data_path / "subset.txt"
if not subset_file.exists():
Expand Down
28 changes: 16 additions & 12 deletions snputils/snp/io/read/vcf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from typing import Optional, List
from pathlib import Path
import gzip
import allel
import numpy as np
Expand Down Expand Up @@ -102,26 +103,29 @@ def _get_vcf_names(vcf_path: str):
Parameters
----------
vcf_path: str
vcf_path: str or Path
The path to the VCF file.
Returns
-------
List[str]
List of column names.
"""
if vcf_path.endswith('.vcf.gz'):
with gzip.open(vcf_path, "rt") as ifile:
for line in ifile:
if line.startswith("#CHROM"):
vcf_names = [x.strip() for x in line.split('\t')]
break
vcf_path = Path(vcf_path)
if vcf_path.suffixes[-2:] == ['.vcf', '.gz']:
open_func = gzip.open
mode = 'rt'
elif vcf_path.suffix == '.vcf':
open_func = open
mode = 'r'
else:
with open(vcf_path, "r") as ifile:
for line in ifile:
if line.startswith("#CHROM"):
vcf_names = [x.strip() for x in line.split('\t')]
break
raise ValueError(f"Unsupported file extension: {vcf_path.suffixes}")

with open_func(vcf_path, mode) as ifile:
for line in ifile:
if line.startswith("#CHROM"):
vcf_names = [x.strip() for x in line.split('\t')]
break

return vcf_names

Expand Down
11 changes: 6 additions & 5 deletions snputils/snp/io/write/bed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pandas as pd
import numpy as np
import pgenlib as pg
from pathlib import Path

from snputils.snp.genobj import SNPObject

Expand All @@ -19,14 +20,14 @@ class BEDWriter:

def __init__(self, snpobj: SNPObject, filename: str):
self.__snpobj = snpobj
self.__filename = filename
self.__filename = Path(filename)

def write(self):
"""Writes the SNPObject to bed/bim/fam formats."""

# Save .bed file
if not self.__filename.endswith((".bed")):
self.__filename += ".bed"
if self.__filename.suffix != '.bed':
self.__filename = self.__filename.with_name(self.__filename.name + '.bed')

log.info(f"Writing .bed file: {self.__filename}")

Expand Down Expand Up @@ -57,8 +58,8 @@ def write(self):
log.info(f"Finished writing .bed file: {self.__filename}")

# Remove .bed from the file name
if self.__filename.endswith(("bed")):
self.__filename = self.__filename[:-4]
if self.__filename.suffix == '.bed':
self.__filename = self.__filename.with_suffix('')

# Save .fam file
log.info(f"Writing .fam file: {self.__filename}")
Expand Down
9 changes: 5 additions & 4 deletions snputils/snp/io/write/pgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import polars as pl
import pgenlib as pg
from pathlib import Path

from snputils.snp.genobj.snpobj import SNPObject

Expand All @@ -26,15 +27,15 @@ def __init__(self, snpobj: SNPObject, filename: str):
TODO: add support for parallel writing by chromosome.
"""
self.__snpobj = snpobj
self.__filename = filename
self.__filename = Path(self.__filename)

def write(self):
"""
Writes the SNPObject data to .pgen, .psam, and .pvar files.
"""
file_extension = (".pgen", ".psam", ".pvar")
if self.__filename.endswith(file_extension):
self.__filename = self.__filename[:-5]
file_extensions = (".pgen", ".psam", ".pvar")
if self.__filename.suffix in file_extensions:
self.__filename = self.__filename.with_suffix('')
self.__file_extension = ".pgen"

self.write_pvar()
Expand Down
11 changes: 6 additions & 5 deletions snputils/snp/io/write/vcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pandas as pd
import numpy as np
import joblib
from pathlib import Path

from snputils.snp.genobj import SNPObject

Expand All @@ -28,7 +29,7 @@ def __init__(self, snpobj: SNPObject, filename: str, n_jobs: int = -1, phased: b
"maternal/paternal" format.
"""
self.__snpobj = snpobj
self.__filename = filename
self.__filename = Path(self.__filename)
self.__n_jobs = n_jobs
self.__phased = phased

Expand All @@ -42,10 +43,10 @@ def write(self, chrom_partition: bool = False):
"""
self.__chrom_partition = chrom_partition

file_extension = (".vcf", ".bcf")
if self.__filename.endswith(file_extension):
self.__file_extension = self.__filename[-4:]
self.__filename = self.__filename[:-4]
file_extensions = (".vcf", ".bcf")
if self.__filename.suffix in file_extensions:
self.__file_extension = self.__filename.suffix
self.__filename = self.__filename.with_suffix('')
else:
self.__file_extension = ".vcf"

Expand Down

0 comments on commit 959de3c

Please sign in to comment.