Skip to content

Commit

Permalink
Improve SplitVcfBySamples to correctly handle allele-specific annotat…
Browse files Browse the repository at this point in the history
…ions
  • Loading branch information
bbimber committed Feb 3, 2025
1 parent 2dd4da5 commit 5d9c678
Show file tree
Hide file tree
Showing 8 changed files with 296 additions and 9 deletions.
184 changes: 176 additions & 8 deletions src/main/java/com/github/discvrseq/walkers/SplitVcfBySamples.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@
import com.opencsv.RFC4180ParserBuilder;
import com.opencsv.exceptions.CsvValidationException;
import htsjdk.samtools.util.IOUtil;
import htsjdk.variant.variantcontext.VariantContext;
import htsjdk.variant.variantcontext.*;
import htsjdk.variant.variantcontext.writer.VariantContextWriter;
import htsjdk.variant.variantcontext.writer.VariantContextWriterBuilder;
import htsjdk.variant.vcf.VCFHeader;
import htsjdk.variant.vcf.VCFHeaderLine;
import htsjdk.variant.vcf.*;
import org.apache.commons.collections4.list.UnmodifiableList;
import org.apache.commons.io.FilenameUtils;
import org.apache.commons.lang3.StringUtils;
Expand All @@ -22,11 +21,16 @@
import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
import org.broadinstitute.hellbender.engine.*;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.tools.walkers.genotyper.AlleleSubsettingUtils;
import org.broadinstitute.hellbender.tools.walkers.genotyper.GenotypeAssignmentMethod;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.variant.GATKVCFConstants;
import org.broadinstitute.hellbender.utils.variant.GATKVariantContextUtils;

import java.io.File;
import java.io.IOException;
import java.util.*;
import java.util.stream.Collectors;

/**
* This takes an input VCF and subsets it into new VCFs where each contains a subset of the original samples. You can either provide samplesPerVcf, in which case it will be subset based on # of samples, or you can provide a
Expand Down Expand Up @@ -73,6 +77,15 @@ public class SplitVcfBySamples extends VariantWalker {
@Argument(doc="This is a TSV file with two columns and no header, where column 1 is the filepath for an output VCF. Column 2 is a sample ID to write to this file. The file can contain multiple rows per output file. The purpose is to supply a list of samples->file, such that this tool can read the input VCF once, and write multiple output VCFs at the same time.", fullName = "sample-mapping-file", optional = true)
public GATKPath sampleMappingFile = null;

// When set, AC/AF/AN are recomputed for each output VCF from the genotypes of the samples it retains.
@Argument(fullName="recalculate-ac", doc="This will recalculate the AC, AF, and AN values after subsetting. See also --keep-original-ac", optional=true)
private boolean recalculateChrCounts = false;

// NOTE: this flag is only consulted together with recalculateChrCounts (both must be true); on its own it has no effect.
@Argument(fullName="keep-original-ac", doc="Store the original AC, AF, and AN values after subsetting", optional=true)
private boolean keepOriginalChrCounts = false;

// Suffix appended to the AC/AF/AN keys when the original values are preserved (default ".Orig").
@Argument(fullName="original-ac-suffix", doc="If --keep-original-ac is selected, the original AC, AF, and AN values will be stored, but with this suffix (e.g., a suffix of .Orig would result in AF -> AF.Orig)", optional=true)
private String originalChrCountsSuffix = ".Orig";

List<List<String>> batches = new ArrayList<>();
List<VariantContextWriter> writers = new ArrayList<>();

Expand Down Expand Up @@ -134,11 +147,37 @@ private void prepareOutputsForSampleFile(List<String> samples, VCFHeader header)
writers.add(writer);
batches.add(new UnmodifiableList<>(fileToSamples.get(outputFile).stream().toList()));

VCFHeader outputHeader = new VCFHeader(header.getMetaDataInInputOrder(), fileToSamples.get(outputFile));
VCFHeader outputHeader = new VCFHeader(getHeaderLines(header), fileToSamples.get(outputFile));
writer.writeHeader(outputHeader);
}
}

/**
 * Builds the metadata lines to use for an output VCF header.
 *
 * Starts from the input header's lines (in input order). When chromosome counts are being
 * recalculated, the standard AC/AF/AN INFO lines are appended if the input header does not
 * already declare them, and, when the original values are also being kept, INFO lines for the
 * suffixed AC/AF/AN keys are appended as well.
 *
 * @param header the header of the input VCF
 * @return an insertion-ordered set of header lines for the output VCF
 */
private Set<VCFHeaderLine> getHeaderLines(VCFHeader header) {
    final Set<VCFHeaderLine> result = new LinkedHashSet<>(header.getMetaDataInInputOrder());

    // Nothing else is added unless the counts are being recalculated
    if (!recalculateChrCounts) {
        return result;
    }

    final String[] chrCountKeys = {
            VCFConstants.ALLELE_COUNT_KEY,
            VCFConstants.ALLELE_FREQUENCY_KEY,
            VCFConstants.ALLELE_NUMBER_KEY
    };
    for (final String key : chrCountKeys) {
        if (header.getInfoHeaderLine(key) == null) {
            result.add(VCFStandardHeaderLines.getInfoLine(key));
        }
    }

    if (keepOriginalChrCounts) {
        final String originalAcKey = VCFConstants.ALLELE_COUNT_KEY + originalChrCountsSuffix;
        final String originalAfKey = VCFConstants.ALLELE_FREQUENCY_KEY + originalChrCountsSuffix;
        final String originalAnKey = VCFConstants.ALLELE_NUMBER_KEY + originalChrCountsSuffix;

        result.add(new VCFInfoHeaderLine(originalAcKey, VCFHeaderLineCount.A, VCFHeaderLineType.Integer, "Original AC"));
        result.add(new VCFInfoHeaderLine(originalAfKey, VCFHeaderLineCount.A, VCFHeaderLineType.Float, "Original AF"));
        result.add(new VCFInfoHeaderLine(originalAnKey, 1, VCFHeaderLineType.Integer, "Original AN"));
    }

    return result;
}

private void prepareOutputsForBatches(List<String> samples, VCFHeader header) {
Lists.partition(new ArrayList<>(samples), samplesPerVcf).forEach(l -> {
batches.add(List.copyOf(l)); // this returns an unmodifiable list
Expand All @@ -161,7 +200,6 @@ private void prepareOutputsForBatches(List<String> samples, VCFHeader header) {

int idx = 0;
File inputFile = getDrivingVariantsFeatureInput().toPath().toFile();
final Set<VCFHeaderLine> headerLines = new LinkedHashSet<>(header.getMetaDataInInputOrder());
for (List<String> batch : batches) {
idx++;
String name = inputFile.getName();
Expand All @@ -176,19 +214,19 @@ private void prepareOutputsForBatches(List<String> samples, VCFHeader header) {
VariantContextWriter writer = new VariantContextWriterBuilder().setOutputFile(output).setReferenceDictionary(getReferenceDictionary()).build();
writers.add(writer);

VCFHeader outputHeader = new VCFHeader(headerLines, new ArrayList<>(batch));
VCFHeader outputHeader = new VCFHeader(getHeaderLines(header), new ArrayList<>(batch));
writer.writeHeader(outputHeader);
}
}

@Override
public void apply(VariantContext variant, ReadsContext readsContext, ReferenceContext referenceContext, FeatureContext featureContext) {
int idx = 0;
for (List<String> batch : batches) {
for (List<String> sampleNames : batches) {
final VariantContextWriter writer = writers.get(idx);
idx++;

final VariantContext sub = variant.subContextFromSamples(new LinkedHashSet<>(batch), removeUnusedAlternates);
final VariantContext sub = subsetGenotypesBySampleNames(variant, new TreeSet<>(sampleNames), removeUnusedAlternates);
if (discardNonVariantSites) {
if (sub.getCalledChrCount() == 0) {
continue;
Expand All @@ -202,6 +240,136 @@ else if (sub.getGenotypes().stream().noneMatch(g -> !g.isFiltered() && g.isCalle
}
}

/**
 * Subsets a variant to the given samples and, optionally, to the alleles those samples use,
 * keeping genotype fields and per-allele INFO annotations consistent with the retained alleles.
 *
 * NOTE: this is basically copied from SelectVariants. This is necessary to ensure we also subset
 * INFO fields correctly.
 *
 * @param vc the original variant
 * @param samples the names of the samples to retain
 * @param removeUnusedAlternates if true, alternate alleles not used by the retained samples are dropped
 * @return the subset variant, or {@code vc} itself when neither samples nor alleles can change
 */
private VariantContext subsetGenotypesBySampleNames(final VariantContext vc, final SortedSet<String> samples, final boolean removeUnusedAlternates) {
    // If no subsetting of samples or alleles happened, exit now
    if (!removeUnusedAlternates && samples.size() == vc.getNSamples()) {
        return vc;
    }

    // strip out the alternate alleles that aren't being used
    final VariantContext sub = vc.subContextFromSamples(samples, removeUnusedAlternates);

    final GenotypesContext newGC;
    if (sub.getNAlleles() != vc.getNAlleles()) {
        // fix the PL and AD values if sub has fewer alleles than original vc.
        // We need sub for the right samples, but its PLs still go with the old alleles.
        newGC = AlleleSubsettingUtils.subsetAlleles(sub.getGenotypes(), 0, vc.getAlleles(),
                sub.getAlleles(), null, GenotypeAssignmentMethod.DO_NOT_ASSIGN_GENOTYPES);
    } else {
        newGC = sub.getGenotypes();
    }

    // since the VC has been subset (either by sample or allele), we need to strip out the MLE tags
    final VariantContextBuilder builder = new VariantContextBuilder(sub);
    builder.rmAttributes(Arrays.asList(GATKVCFConstants.MLE_ALLELE_COUNT_KEY, GATKVCFConstants.MLE_ALLELE_FREQUENCY_KEY));
    builder.genotypes(newGC);
    addAnnotations(builder, vc, sub.getSampleNames());

    if (!vc.getAlleles().equals(sub.getAlleles())) {
        // For each retained allele, its index in the original allele list (without / with the reference)
        final int[] oldToKeyNoRef = sub.getAlternateAlleles().stream().mapToInt(a -> vc.getAlleleIndex(a) - 1).toArray();
        final int[] oldToKeyWithRef = sub.getAlleles().stream().mapToInt(vc::getAlleleIndex).toArray();

        // Keys whose values addAnnotations() already produced for the new allele set. Re-deriving
        // them from the pre-subset values in 'sub' would clobber the recalculated AC/AF with the
        // original cohort-wide values whenever the input header declares them as Number=A.
        final Set<String> alreadyConsistent = new HashSet<>();
        if (recalculateChrCounts) {
            alreadyConsistent.add(VCFConstants.ALLELE_COUNT_KEY);
            alreadyConsistent.add(VCFConstants.ALLELE_FREQUENCY_KEY);
            if (keepOriginalChrCounts) {
                alreadyConsistent.add(VCFConstants.ALLELE_COUNT_KEY + originalChrCountsSuffix);
                alreadyConsistent.add(VCFConstants.ALLELE_FREQUENCY_KEY + originalChrCountsSuffix);
            }
        }

        // Iterate over a snapshot of the keys, since the builder is updated inside the loop
        for (final String attr : new ArrayList<>(builder.getAttributes().keySet())) {
            // Skip values that are already correct, and keys with no source value to reorder
            // (reordering a missing value would overwrite the builder's value with null)
            if (alreadyConsistent.contains(attr) || !sub.hasAttribute(attr)) {
                continue;
            }

            final VCFInfoHeaderLine line = getHeaderForVariants().getInfoHeaderLine(attr);
            if (line == null) {
                continue;
            }

            if (line.getCountType() == VCFHeaderLineCount.A) {
                builder.attribute(attr, getReorderedAttributes(sub.getAttribute(attr), oldToKeyNoRef));
            }
            else if (line.getCountType() == VCFHeaderLineCount.R) {
                builder.attribute(attr, getReorderedAttributes(sub.getAttribute(attr), oldToKeyWithRef));
            }
        }
    }

    final VariantContext subset = builder.make();
    return removeUnusedAlternates ? GATKVariantContextUtils.trimAlleles(subset, true, true) : subset;
}

/**
 * Updates the INFO annotations on the builder after sample/allele subsetting.
 *
 * When both recalculation and keep-original are enabled, the original AC/AF/AN values are stored
 * under their suffixed keys (AC/AF restricted to the alleles present on the builder). When
 * recalculation is enabled, the chromosome counts are recomputed from the builder's genotypes.
 * Finally, DP is set to the summed genotype DP of the selected, non-filtered samples, if any of
 * them carries a DP value.
 *
 * @param builder the builder for the subset variant; modified in place
 * @param originalVC the variant before subsetting
 * @param selectedSampleNames the names of the samples retained in the subset
 */
private void addAnnotations(final VariantContextBuilder builder, final VariantContext originalVC, final Set<String> selectedSampleNames) {
    if (recalculateChrCounts && keepOriginalChrCounts) {
        final List<Allele> currentAlleles = builder.getAlleles();

        // null means the allele lists line up, so the original values can be copied as-is
        int[] originalAltIndexForNewAlt = null;
        if (originalVC.getNAlleles() != currentAlleles.size()) {
            // otherwise, locate each retained ALT in the original ALT list (the reference at position 0 is ignored)
            final List<Allele> originalAlts = originalVC.getAlternateAlleles();
            originalAltIndexForNewAlt = new int[currentAlleles.size() - 1];
            Arrays.fill(originalAltIndexForNewAlt, -1);

            for (int newIdx = 1; newIdx < currentAlleles.size(); newIdx++) {
                final Allele candidate = currentAlleles.get(newIdx);
                for (int oldIdx = 0; oldIdx < originalAlts.size(); oldIdx++) {
                    if (candidate.equals(originalAlts.get(oldIdx), false)) {
                        originalAltIndexForNewAlt[newIdx - 1] = oldIdx;
                        break;
                    }
                }
            }
        }

        // AC and AF are per-ALT values and must follow the retained alleles
        for (final String perAltKey : new String[]{VCFConstants.ALLELE_COUNT_KEY, VCFConstants.ALLELE_FREQUENCY_KEY}) {
            if (originalVC.hasAttribute(perAltKey)) {
                final Object reordered = getReorderedAttributes(originalVC.getAttribute(perAltKey), originalAltIndexForNewAlt);
                builder.attribute(perAltKey + originalChrCountsSuffix, reordered);
            }
        }

        // AN is a single value and is copied unchanged
        if (originalVC.hasAttribute(VCFConstants.ALLELE_NUMBER_KEY)) {
            builder.attribute(VCFConstants.ALLELE_NUMBER_KEY + originalChrCountsSuffix, originalVC.getAttribute(VCFConstants.ALLELE_NUMBER_KEY));
        }
    }

    if (recalculateChrCounts) {
        VariantContextUtils.calculateChromosomeCounts(builder, false);
    }

    // Sum the per-genotype depth over the selected samples, ignoring filtered genotypes
    int totalDepth = 0;
    boolean anyDepthSeen = false;
    for (final String sampleName : selectedSampleNames) {
        final Genotype genotype = originalVC.getGenotype(sampleName);
        if (genotype.isFiltered() || !genotype.hasDP()) {
            continue;
        }
        totalDepth += genotype.getDP();
        anyDepthSeen = true;
    }

    if (anyDepthSeen) {
        builder.attribute(VCFConstants.DEPTH_KEY, totalDepth);
    }
}

/**
 * Selects and reorders the elements of a multi-valued attribute.
 *
 * @param attribute the attribute value: an object array, a List, or any other object whose string
 *                  form is split on the VCF INFO array separator
 * @param oldToNewIndexOrdering for each position in the result, the index of the element to take
 *                              from the original attribute; may be null
 * @return the attribute itself when either argument is null, otherwise a List holding the selected elements
 * @throws IllegalArgumentException (via Utils.validateArg — TODO confirm exception type) if any
 *         index is negative or beyond the end of the original attribute
 */
private Object getReorderedAttributes(final Object attribute, final int[] oldToNewIndexOrdering) {
    if (oldToNewIndexOrdering == null || attribute == null) {
        return attribute;
    }

    // break the original attributes into separate tokens; unfortunately, this means being smart about class types
    final Object[] tokens;
    if (attribute.getClass().isArray()) {
        tokens = (Object[]) attribute;
    } else if (attribute instanceof List) {
        tokens = ((List<?>) attribute).toArray();
    } else {
        tokens = attribute.toString().split(VCFConstants.INFO_FIELD_ARRAY_SEPARATOR);
    }

    // Reject negative indexes too: callers use -1 to mark an allele that was not found in the
    // original list, which would otherwise surface as a bare ArrayIndexOutOfBoundsException below
    Utils.validateArg(Arrays.stream(oldToNewIndexOrdering).allMatch(index -> index >= 0 && index < tokens.length), () ->
            "the old attribute has an incorrect number of elements: " + attribute);
    return Arrays.stream(oldToNewIndexOrdering).mapToObj(index -> tokens[index]).collect(Collectors.toList());
}

@Override
public void closeTool() {
super.closeTool();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,57 @@ private ArgumentsBuilder getBaseArgs(File outDir) {
return getBaseArgs(outDir, 1);
}

/**
 * Runs the tool on mergeVcfWithAlts.vcf with samplesPerVcf=1, recalculating AC/AF/AN and keeping
 * the originals under the ".new" suffix, then compares each output VCF to its expected file by MD5.
 */
@Test
public void testInfoSubsetting() throws Exception {
    final File outDir = IOUtils.createTempDir("splitVcfBySamples.");
    final ArgumentsBuilder args = getBaseArgs(outDir, 1, "mergeVcfWithAlts.vcf");
    for (final String rawArg : new String[]{"--recalculate-ac", "--keep-original-ac", "--original-ac-suffix", ".new"}) {
        args.addRaw(rawArg);
    }

    runCommandLine(args);

    // The expected files share their names with the generated outputs
    for (final String fileName : new String[]{"mergeVcfWithAlts.1of2.vcf", "mergeVcfWithAlts.2of2.vcf"}) {
        final String actualMD5 = Utils.calculateFileMD5(new File(outDir, fileName));
        final String expectedMD5 = Utils.calculateFileMD5(getTestFile(fileName));
        Assert.assertEquals(actualMD5, expectedMD5);
    }
}

/**
 * Same scenario as the plain INFO-subsetting test, but additionally removes unused alternate
 * alleles and passes the discardNonVariantSites flag; each output VCF is compared to its
 * "RemoveAlts" expected file by MD5.
 */
@Test
public void testInfoSubsettingRemoveAlts() throws Exception {
    final File outDir = IOUtils.createTempDir("splitVcfBySamples.");
    final ArgumentsBuilder args = getBaseArgs(outDir, 1, "mergeVcfWithAlts.vcf");
    for (final String rawArg : new String[]{"--recalculate-ac", "--keep-original-ac", "--original-ac-suffix", ".new", "--remove-unused-alternates"}) {
        args.addRaw(rawArg);
    }
    args.addFlag("discardNonVariantSites");

    runCommandLine(args);

    // Each pair is {generated output name, expected file name}
    final String[][] outputToExpected = {
            {"mergeVcfWithAlts.1of2.vcf", "mergeVcfWithAltsRemoveAlts.1of2.vcf"},
            {"mergeVcfWithAlts.2of2.vcf", "mergeVcfWithAltsRemoveAlts.2of2.vcf"}
    };
    for (final String[] pair : outputToExpected) {
        final String actualMD5 = Utils.calculateFileMD5(new File(outDir, pair[0]));
        final String expectedMD5 = Utils.calculateFileMD5(getTestFile(pair[1]));
        Assert.assertEquals(actualMD5, expectedMD5);
    }
}

// Convenience overload: builds the base arguments against the default test input VCF.
private ArgumentsBuilder getBaseArgs(File outDir, @Nullable Integer samplesPerVcf) {
    final String defaultInputVcf = "mergeVcf3.vcf";
    return getBaseArgs(outDir, samplesPerVcf, defaultInputVcf);
}

private ArgumentsBuilder getBaseArgs(File outDir, @Nullable Integer samplesPerVcf, String inputVcf) {
ArgumentsBuilder args = new ArgumentsBuilder();

args.addRaw("--variant");
File input = new File(testBaseDir, "mergeVcf3.vcf");
File input = new File(testBaseDir, inputVcf);
ensureVcfIndex(input);
args.addRaw(normalizePath(input));

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
##fileformat=VCFv4.2
##INFO=<ID=RR,Number=R,Type=String,Description="Test REF annotation">
##INFO=<ID=RA,Number=A,Type=String,Description="Test ALT annotation">
##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">
##contig=<ID=1,length=16000>
##contig=<ID=2,length=16000>
#CHROM POS ID REF ALT QUAL FILTER INFO FORMAT Sample1 Sample3
1 61 . GT G 724.43 PASS RR=GT,G;RA=G GT 0/1 0/1
1 72 . T A,C 100 PASS RR=T,A,C;RA=A,C GT 0/1 1/1
1 73 . TT AG 100 PASS RR=TT,AG;RA=AG GT 0/1 0/0
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
##fileformat=VCFv4.2
##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">
##INFO=<ID=AC,Number=A,Type=Integer,Description="Allele count in genotypes, for each ALT allele, in the same order as listed">
##INFO=<ID=AC.new,Number=A,Type=Integer,Description="Original AC">
##INFO=<ID=AF,Number=A,Type=Float,Description="Allele Frequency, for each ALT allele, in the same order as listed">
##INFO=<ID=AF.new,Number=A,Type=Float,Description="Original AF">
##INFO=<ID=AN,Number=1,Type=Integer,Description="Total number of alleles in called genotypes">
##INFO=<ID=AN.new,Number=1,Type=Integer,Description="Original AN">
##INFO=<ID=RA,Number=A,Type=String,Description="Test ALT annotation">
##INFO=<ID=RR,Number=R,Type=String,Description="Test REF annotation">
##contig=<ID=1,length=16000>
##contig=<ID=2,length=16000>
#CHROM POS ID REF ALT QUAL FILTER INFO FORMAT Sample1
1 61 . GT G 724.43 PASS AC=1;AF=0.500;AN=2;RA=G;RR=GT,G GT 0/1
1 72 . T A,C 100 PASS AC=1,0;AF=0.500,0.00;AN=2;RA=A,C;RR=T,A,C GT 0/1
1 73 . TT AG 100 PASS AC=1;AF=0.500;AN=2;RA=AG;RR=TT,AG GT 0/1
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
##fileformat=VCFv4.2
##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">
##INFO=<ID=AC,Number=A,Type=Integer,Description="Allele count in genotypes, for each ALT allele, in the same order as listed">
##INFO=<ID=AC.new,Number=A,Type=Integer,Description="Original AC">
##INFO=<ID=AF,Number=A,Type=Float,Description="Allele Frequency, for each ALT allele, in the same order as listed">
##INFO=<ID=AF.new,Number=A,Type=Float,Description="Original AF">
##INFO=<ID=AN,Number=1,Type=Integer,Description="Total number of alleles in called genotypes">
##INFO=<ID=AN.new,Number=1,Type=Integer,Description="Original AN">
##INFO=<ID=RA,Number=A,Type=String,Description="Test ALT annotation">
##INFO=<ID=RR,Number=R,Type=String,Description="Test REF annotation">
##contig=<ID=1,length=16000>
##contig=<ID=2,length=16000>
#CHROM POS ID REF ALT QUAL FILTER INFO FORMAT Sample3
1 61 . GT G 724.43 PASS AC=1;AF=0.500;AN=2;RA=G;RR=GT,G GT 0/1
1 72 . T A,C 100 PASS AC=2,0;AF=1.00,0.00;AN=2;RA=A,C;RR=T,A,C GT 1/1
1 73 . TT AG 100 PASS AC=0;AF=0.00;AN=2;RA=AG;RR=TT,AG GT 0/0
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
##fileformat=VCFv4.2
##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">
##INFO=<ID=AC,Number=A,Type=Integer,Description="Allele count in genotypes, for each ALT allele, in the same order as listed">
##INFO=<ID=AC.new,Number=A,Type=Integer,Description="Original AC">
##INFO=<ID=AF,Number=A,Type=Float,Description="Allele Frequency, for each ALT allele, in the same order as listed">
##INFO=<ID=AF.new,Number=A,Type=Float,Description="Original AF">
##INFO=<ID=AN,Number=1,Type=Integer,Description="Total number of alleles in called genotypes">
##INFO=<ID=AN.new,Number=1,Type=Integer,Description="Original AN">
##INFO=<ID=RA,Number=A,Type=String,Description="Test ALT annotation">
##INFO=<ID=RR,Number=R,Type=String,Description="Test REF annotation">
##contig=<ID=1,length=16000>
##contig=<ID=2,length=16000>
#CHROM POS ID REF ALT QUAL FILTER INFO FORMAT Sample1
1 61 . GT G 724.43 PASS AC=1;AF=0.500;AN=2;RA=G;RR=GT,G GT 0/1
1 72 . T A 100 PASS AC=1;AF=0.500;AN=2;RA=A;RR=T,A GT 0/1
1 73 . TT AG 100 PASS AC=1;AF=0.500;AN=2;RA=AG;RR=TT,AG GT 0/1
Loading

0 comments on commit 5d9c678

Please sign in to comment.