diff --git a/haptools/data/genotypes.py b/haptools/data/genotypes.py index 7838c665..678d00b7 100644 --- a/haptools/data/genotypes.py +++ b/haptools/data/genotypes.py @@ -113,7 +113,7 @@ def read(self, region: str = None, samples: list[str] = None): # transpose the GT matrix so that samples are rows and variants are columns self.data = self.data.transpose((1, 0, 2)) - def check_biallelic(self): + def check_biallelic(self, discard_also=False): """ Check that each genotype is composed of only two alleles @@ -126,6 +126,11 @@ def check_biallelic(self): converted to bool ValueError If any of the genotypes have more than two alleles + + Parameters + ---------- + discard_also : bool, optional + If True, discard any multiallelic variants without raising a ValueError """ if self.data.dtype == np.bool_: raise AssertionError("All genotypes are already biallelic") @@ -134,11 +139,15 @@ def check_biallelic(self): multiallelic = np.any(self.data[:, :, :2] > 1, axis=2) if np.any(multiallelic): samp_idx, variant_idx = np.nonzero(multiallelic) - raise ValueError( - "Variant with ID {} at POS {}:{} is multiallelic for sample {}".format( - *tuple(self.variants[variant_idx[0]])[:3], self.samples[samp_idx[0]] + if discard_also: + self.data = np.delete(self.data, variant_idx, axis=1) + self.variants = np.delete(self.variants, variant_idx) + else: + raise ValueError( + "Variant with ID {} at POS {}:{} is multiallelic for sample {}".format( + *tuple(self.variants[variant_idx[0]])[:3], self.samples[samp_idx[0]] + ) ) - ) self.data = self.data.astype(np.bool_) def check_phase(self): diff --git a/pyproject.toml b/pyproject.toml index 2340a3fc..b2416d57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,7 @@ haptools = 'haptools.__main__:main' [tool.black] line-length = 88 -experimental-string-processing = true +preview = true [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/tests/test_data.py b/tests/test_data.py index 284c3b70..1093441e 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -78,6 +78,27 @@ def test_load_genotypes(): gts.to_MAC() +def test_load_genotypes_discard_multiallelic(): + expected = get_expected_genotypes() + + # can we load the data from the VCF? + gts = Genotypes(DATADIR.joinpath("simple.vcf")) + gts.read() + + # make a copy for later + data_copy = gts.data.copy().astype(np.bool_) + variant_shape = list(gts.variants.shape) + variant_shape[0] -= 1 + + # force one of the SNPs to have more than one allele and check that it gets dicarded + gts.data[1, 1, 1] = 2 + gts.check_biallelic(discard_also=True) + + data_copy_without_biallelic = np.delete(data_copy, [1], axis=1) + np.testing.assert_equal(gts.data, data_copy_without_biallelic) + assert gts.variants.shape == tuple(variant_shape) + + def test_load_genotypes_subset(): expected = get_expected_genotypes()