feat: Support multiallelic sites
This PR implements the multiallelic calculation of `pi`.

---------

feat: Multiallelic calculation of `dxy` (ksamuk#88)

This PR implements the multiallelic calculation of `dxy`.

test: Add test coverage over multiallelism in FST calculation (ksamuk#90)

AFAICT the scikit-allel functions for FST calculation support
multiallelism out of the box. This PR adds unit test coverage over
multiallelic sites, but no updates to the pixy calculation functions
appeared necessary.
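
As a quick illustration of that claim, here is a hedged sketch (invented genotypes and sample grouping, not the PR's actual test data) of an FST calculation on a triallelic site with scikit-allel; it assumes scikit-allel and numpy are installed.

```python
import allel
import numpy as np

# One triallelic site; two populations of two diploid samples each.
g = allel.GenotypeArray([[[0, 1], [0, 2], [1, 1], [2, 2]]])
subpops = [[0, 1], [2, 3]]

# Weir & Cockerham variance components handle sites with >2 alleles directly.
a, b, c = allel.weir_cockerham_fst(g, subpops)
fst = np.sum(a) / np.sum(a + b + c)
print(fst)
```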

Co-authored-by: Matt Stone <matt@fulcrumgenomics.com>
Co-authored-by: Erin McAuley <erin@fulcrumgenomics.com>

fix: test
msto committed Feb 4, 2025
1 parent 6298c68 commit e4b438b
Showing 8 changed files with 277 additions and 39 deletions.
7 changes: 7 additions & 0 deletions pixy/__main__.py
@@ -210,6 +210,13 @@ def main() -> None: # noqa: C901
),
required=False,
)
optional.add_argument(
"--include_multiallelic_snps",
choices=["yes", "no"],
default="no",
help=("Multiallelic SNPs within the VCF will be included during calculation.(default=no)."),
required=False,
)
optional.add_argument(
"--bypass_invariant_check",
choices=["yes", "no"],
5 changes: 5 additions & 0 deletions pixy/args_validation.py
@@ -42,6 +42,7 @@ class PixyArgs:
populations_df: a pandas DataFrame derived from a user-specified path to a headerless
tab-separated populations file
num_cores: number of CPUs to utilize for parallel processing (default = 1)
include_multiallelic_snps: If True, include multiallelic sites in the analysis
bypass_invariant_check: whether to allow computation of stats without invariant sites
(this option is never recommended and defaults to False)
bed_df: a pandas DataFrame derived from a user-specified path to a headerless BED file.
@@ -77,6 +78,7 @@ class PixyArgs:
temp_file: Path
chromosomes: List[str]
bypass_invariant_check: bool
include_multiallelic_snps: bool
num_cores: int = 1
fst_type: FSTEstimator = FSTEstimator.WC
output_prefix: str = "pixy"
@@ -640,6 +642,8 @@ def check_and_validate_args( # noqa: C901
"defined in the population file."
)

include_multiallelic_snps: bool = args.include_multiallelic_snps == "yes"

logger.info("All initial checks passed!")
stats: List[PixyStat] = [PixyStat[stat.upper()] for stat in args.stats]
tmp_path: Path = _generate_tmp_path(output_dir=output_folder)
@@ -650,6 +654,7 @@ def check_and_validate_args( # noqa: C901
populations_df=populations_df,
num_cores=args.n_cores,
bypass_invariant_check=bypass_invariant_check,
include_multiallelic_snps=include_multiallelic_snps,
bed_df=bed_df,
output_dir=Path(output_folder),
output_prefix=output_prefix,
100 changes: 66 additions & 34 deletions pixy/calc.py
@@ -19,48 +19,69 @@
# these are reimplementations of the original functions


# helper function for calculation of pi
# for the given site (row of the count table) count number of differences, number of comparisons,
# and number missing.
# uses number of haploid samples (n_haps) to determine missing data
def count_diff_comp_missing(row: AlleleCountsArray, n_haps: int) -> Tuple[int, int, int]:
"""
Calculates site-specific statistics for `pi` calculation across populations.
Helper function for the calculation of pi.
For the given site (row of the count table), count the number of differences, number of
comparisons, and number of missing. The function uses the number of haploid samples (n_haps) to
compute the number of expected genotype comparisons and determine the count of missing.
Args:
row: counts of each of the two alleles at a given site
row: counts of each allele at a given site
n_haps: number of haploid samples in the population
Returns:
diffs: number of differences between the populations
comps: number of comparisons between the populations
missing: number of missing between the populations
A tuple `(diffs, comps, missing)`, where `diffs` is the number of differences within the
population, `comps` is the number of comparisons made within the population, and `missing`
is the difference between the actual number of comparisons and the total possible (based on
the number of haploid samples).
"""
diffs = row[1] * row[0]
gts = row[1] + row[0]
comps = int(special.comb(N=gts, k=2)) # calculate combinations, return an integer
missing = int(special.comb(N=n_haps, k=2)) - comps
n_gts: int = np.sum(row) # number of observed genotypes

# number of possible pairwise comparisons, if all samples are called
n_possible_comps: int = int(special.comb(N=n_haps, k=2))

if n_gts == 0:
# No observed genotypes in the row
return 0, 0, n_possible_comps

# Find the highest index of an observed allele, and assume it is the allelism of the site
# (If the variant technically has other alleles, they zero out anyways)
observed_alleles: NDArray = np.nonzero(row)[0]
allelism = np.max(observed_alleles) + 1

comps = int(special.comb(N=n_gts, k=2)) # calculate combinations, return an integer
missing = n_possible_comps - comps

# Use shortcut: the number of differences is the sum of all pairwise products of the observed
# allele counts
diffs = 0
for i in range(allelism - 1):
for j in range(i + 1, allelism):
diffs += row[i] * row[j]

return diffs, comps, missing
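
A small worked example may help here (toy counts, not taken from the PR's tests): with allele counts `[2, 1, 1]` at a site and 6 haploid samples in the population, there are 2*1 + 2*1 + 1*1 = 5 pairwise differences, C(4, 2) = 6 comparisons among the 4 observed alleles, and C(6, 2) - 6 = 9 missing comparisons. A minimal sketch, assuming pixy is importable:

```python
import numpy as np
from pixy.calc import count_diff_comp_missing

row = np.array([2, 1, 1])  # counts of alleles 0, 1, and 2 at one site
diffs, comps, missing = count_diff_comp_missing(row, n_haps=6)
print(diffs, comps, missing)  # expected: 5 6 9
```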


# function for vectorized calculation of pi from a pre-filtered scikit-allel genotype matrix
def calc_pi(gt_array: GenotypeArray) -> PiResult:
"""
Given a filtered genotype matrix, calculate `pi`.
This function implements vectorized calculation of pi from a pre-filtered scikit-allel genotype
matrix. This function does not support filtering of the input by population - it simply
calculates pi on all of the provided samples.
Args:
gt_array: the GenotypeArray representing the counts of each of the two alleles
gt_array: a GenotypeArray representing the calls of each variant at each filtered site in a
given population. This array must be pre-filtered to the population of interest.
Returns:
avg_pi: proportion of total differences across total comparisons. "NA" if no valid data.
total_diffs: sum of the number of differences between the populations
total_comps: sum of the number of comparisons between the populations
total_missing: sum of the number of missing between the populations
The average `pi` and total difference, comparison, and missing counts over all sites in the
input array.
"""
# counts of each of the two alleles at each site
allele_counts: AlleleCountsArray = gt_array.count_alleles(max_allele=1)
allele_counts: AlleleCountsArray = gt_array.count_alleles()

# the number of (haploid) samples in the population
n_haps = gt_array.n_samples * gt_array.ploidy
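
For context, a hedged usage sketch of `calc_pi` on a tiny matrix that includes a triallelic site (the genotypes are invented, and the example assumes pixy and scikit-allel are installed):

```python
import allel
from pixy.calc import calc_pi

gt = allel.GenotypeArray([
    [[0, 0], [0, 1], [1, 1]],  # biallelic site
    [[0, 1], [0, 2], [2, 2]],  # triallelic site
])
result = calc_pi(gt)

# avg_pi is total_diffs / total_comps when any valid comparisons exist.
print(result.avg_pi, result.total_diffs, result.total_comps, result.total_missing)
```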
@@ -111,28 +132,39 @@ def calc_dxy(pop1_gt_array: GenotypeArray, pop2_gt_array: GenotypeArray) -> DxyResult:
total_comps: sum of the number of comparisons between the populations
total_missing: sum of the number of missing between the populations
"""
if pop1_gt_array.n_variants != pop2_gt_array.n_variants:
raise ValueError("Input genotype matrices must have the same number of variants")

n_sites: int = pop1_gt_array.n_variants

# the counts of each of the two alleles in each population at each site
pop1_allele_counts: AlleleCountsArray = pop1_gt_array.count_alleles(max_allele=1)
pop2_allele_counts: AlleleCountsArray = pop2_gt_array.count_alleles(max_allele=1)
pop1_allele_counts: AlleleCountsArray = pop1_gt_array.count_alleles()
pop2_allele_counts: AlleleCountsArray = pop2_gt_array.count_alleles()

# the number of (haploid) samples in each population
pop1_n_haps: int = pop1_gt_array.n_samples * pop1_gt_array.ploidy
pop2_n_haps: int = pop2_gt_array.n_samples * pop2_gt_array.ploidy

# Find the highest index of an observed allele, and assume it is the allelism of every site
# (If a site has fewer alleles, they zero out)
allelism = np.max([pop1_allele_counts.max_allele(), pop2_allele_counts.max_allele()]) + 1

# the total number of differences between populations summed across all sites
total_diffs: int = (pop1_allele_counts[:, 0] * pop2_allele_counts[:, 1]) + (
pop1_allele_counts[:, 1] * pop2_allele_counts[:, 0]
)
total_diffs = np.sum(total_diffs, 0)
persite_diffs: NDArray = np.zeros(n_sites)
for i in range(allelism):
for j in range(allelism):
if i != j:
persite_diffs += pop1_allele_counts[:, i] * pop2_allele_counts[:, j]

# the total number of pairwise comparisons between sites
total_comps: int = (pop1_allele_counts[:, 0] + pop1_allele_counts[:, 1]) * (
pop2_allele_counts[:, 0] + pop2_allele_counts[:, 1]
)
total_comps = np.sum(total_comps, 0)
total_diffs: int = np.sum(persite_diffs)

# the total number of actual pairwise comparisons between sites, excluding missing calls
persite_comps: NDArray = np.sum(pop1_allele_counts, axis=1) * np.sum(pop2_allele_counts, axis=1)
assert persite_comps.shape == (n_sites,)
total_comps: int = np.sum(persite_comps)

# the total count of possible pairwise comparisons at all sites
total_possible: int = (pop1_n_haps * pop2_n_haps) * len(pop1_allele_counts)
total_possible: int = (pop1_n_haps * pop2_n_haps) * n_sites

# the amount of missing is possible comps - actual ('total') comps
total_missing: int = total_possible - total_comps
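
To make the per-site arithmetic concrete, here is a toy check (invented allele counts, not part of the PR) of the generalized difference count: the sum of `pop1[i] * pop2[j]` over ordered allele pairs with `i != j`, which equals the product of the per-population totals of observed allele counts minus the matching-allele products.

```python
import numpy as np

pop1 = np.array([3, 1, 0])  # allele counts in population 1 at one site
pop2 = np.array([1, 2, 1])  # allele counts in population 2 at the same site

diffs = pop1.sum() * pop2.sum() - np.sum(pop1 * pop2)  # same as the i != j double loop
comps = pop1.sum() * pop2.sum()
print(diffs, comps)  # 11 16
```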
17 changes: 15 additions & 2 deletions pixy/core.py
@@ -311,6 +311,8 @@ def read_and_filter_genotypes(
# a string representation of the target region of the current window
window_region = chromosome + ":" + str(window_pos_1) + "-" + str(window_pos_2)

include_multiallelic_snps: bool = args.include_multiallelic_snps == "yes"

# read in data from the source VCF for the current window
callset = allel.read_vcf(
args.vcf,
@@ -348,9 +350,20 @@
pos_array = allel.SortedIndex(callset["variants/POS"])

# create a mask for biallelic snps and invariant sites
is_biallelic_snp = np.logical_and(
callset["variants/is_snp"][:] == 1,
callset["variants/numalt"][:] == 1,
)
is_multiallelic_snp = np.logical_and(
callset["variants/is_snp"][:] == 1,
callset["variants/numalt"][:] > 1,
)
is_invariant_site = callset["variants/numalt"][:] == 0

snp_invar_mask = np.logical_or(
np.logical_and(callset["variants/is_snp"][:] == 1, callset["variants/numalt"][:] == 1),
callset["variants/numalt"][:] == 0,
is_biallelic_snp,
is_invariant_site,
np.logical_and(include_multiallelic_snps, is_multiallelic_snp),
)

# remove rows that are NOT snps or invariant sites from the genotype array
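
As a toy illustration (invented arrays, not part of the PR) of how the mask composes: biallelic SNPs and invariant sites always pass, while multiallelic SNPs pass only when `--include_multiallelic_snps yes` is given.

```python
import numpy as np

is_snp = np.array([1, 1, 0])
numalt = np.array([1, 2, 0])  # a biallelic SNP, a multiallelic SNP, an invariant site

for include_multiallelic_snps in (False, True):
    mask = (
        ((is_snp == 1) & (numalt == 1))
        | (numalt == 0)
        | (include_multiallelic_snps & (is_snp == 1) & (numalt > 1))
    )
    print(include_multiallelic_snps, mask)
# False [ True False  True]
# True [ True  True  True]
```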
10 changes: 9 additions & 1 deletion pixy/models.py
@@ -61,7 +61,15 @@ def __str__(self) -> str:

@dataclass
class PiResult:
"""A result from calculating pi."""
"""
A result from calculating pi.
Attributes:
avg_pi: proportion of total differences across total comparisons. "NA" if no valid data.
total_diffs: sum of the number of differences within the population
total_comps: sum of the number of comparisons within the population
total_missing: sum of the number of missing within the population
"""

avg_pi: Union[float, NA]
total_diffs: Union[int, NA]
1 change: 1 addition & 0 deletions stubs/allel.pyi
@@ -55,6 +55,7 @@ class AlleleCountsArray(NumpyArrayWrapper):
def __init__(self, data: NDArray, copy: bool = False, **kwargs: Any) -> None: ...
def __getitem__(self, item: Any) -> Any: ...
def to_frequencies(self, fill: float = np.nan) -> NDArray: ...
def max_allele(self) -> NDArray: ...

class GenotypeArray(Genotypes):
def __init__(
1 change: 1 addition & 0 deletions tests/args_validation/test_args_validation.py
@@ -24,6 +24,7 @@ def test_check_and_validate_args(
args.chromosomes = "X"
args.n_cores = 1
args.bypass_invariant_check = bypass_variant_check
args.include_multiallelic_snps = False
args.fst_type = "wc"
args.output_prefix = "test"
args.chunk_size = 100000