Skip to content

Commit

Permalink
Rename ignore-zeros -> min-hits, allow higher ignore thresholds.
Browse files Browse the repository at this point in the history
  • Loading branch information
tmaklin committed May 29, 2024
1 parent e868756 commit 59fe08c
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 40 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ Likelihood options:
Experimental options:
--run-rate Calculate relative reliability for each abundance estimate using RATE (default: false).
--ignore-zeros Ignore target clusters that did not have any reads align against them (default: false).
--min-hits Only consider target groups that have at least this many reads align to any sequence in them (default: 0).
```

# References
Expand Down
59 changes: 28 additions & 31 deletions include/Likelihood.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,33 +106,30 @@ class LL_WOR21 : public Likelihood<T> {
return ll_mat;
}

void fill_ll_mat(const telescope::Alignment &alignment, const std::vector<V> &group_sizes, const size_t n_groups, const bool mask_groups) {
void fill_ll_mat(const telescope::Alignment &alignment, const std::vector<V> &group_sizes, const size_t n_groups, const size_t min_hits) {
size_t num_ecs = alignment.n_ecs();

bool mask_groups = min_hits > 0;
this->groups_mask = std::vector<bool>(n_groups, !mask_groups);
std::vector<V> masked_group_sizes;
if (mask_groups) {
std::vector<size_t> group_hit_counts(n_groups, (size_t)0);
// Count, for each group, the equivalence classes with at least one alignment; the mask keeps groups whose count reaches min_hits
for (size_t i = 0; i < num_ecs; ++i) {
for (size_t j = 0; j < n_groups; ++j) {
this->groups_mask[j] = groups_mask[j] || (alignment(j, i) > 0);
group_hit_counts[j] += (alignment(j, i) > 0);
}
}
}
size_t n_masked_groups = 0;
for (size_t i = 0; i < n_groups; ++i) {
n_masked_groups += groups_mask[i];
}

std::vector<V> masked_group_sizes;
if (mask_groups) {
for (size_t i = 0; i < n_groups; ++i) {
this->groups_mask[i] = groups_mask[i] || (group_hit_counts[i] >= min_hits);
if (this->groups_mask[i]) {
masked_group_sizes.push_back(group_sizes[i]);
}
}
} else {
masked_group_sizes = group_sizes;
}
size_t n_masked_groups = masked_group_sizes.size();

this->update_bb_parameters(masked_group_sizes, n_masked_groups, this->bb_constants);
const seamat::DenseMatrix<T> &precalc_lls_mat = this->precalc_lls(masked_group_sizes, n_masked_groups);
Expand Down Expand Up @@ -174,15 +171,15 @@ class LL_WOR21 : public Likelihood<T> {
public:
LL_WOR21() = default;

LL_WOR21(const std::vector<V> &group_sizes, const telescope::Alignment &alignment, const size_t n_groups, const T tol, const T frac_mu, const bool mask_groups, const T _zero_inflation) {
LL_WOR21(const std::vector<V> &group_sizes, const telescope::Alignment &alignment, const size_t n_groups, const T tol, const T frac_mu, const size_t min_hits, const T _zero_inflation) {
this->bb_constants[0] = tol;
this->bb_constants[1] = frac_mu;
this->zero_inflation = _zero_inflation;
this->from_grouped_alignment(alignment, group_sizes, n_groups, mask_groups);
this->from_grouped_alignment(alignment, group_sizes, n_groups, min_hits);
}

void from_grouped_alignment(const telescope::Alignment &alignment, const std::vector<V> &group_sizes, const size_t n_groups, const bool mask_groups) {
this->fill_ll_mat(alignment, group_sizes, n_groups, mask_groups);
void from_grouped_alignment(const telescope::Alignment &alignment, const std::vector<V> &group_sizes, const size_t n_groups, const size_t min_hits) {
this->fill_ll_mat(alignment, group_sizes, n_groups, min_hits);
this->fill_ec_counts(alignment);
}

Expand Down Expand Up @@ -296,49 +293,49 @@ class LL_WOR21 : public Likelihood<T> {
// Read-only access to the per-group mask (groups_mask); an entry is true
// when the corresponding group is included in the likelihood calculation.
const std::vector<bool>& groups_considered() const override {
    return this->groups_mask;
}
};
template <typename T>
std::unique_ptr<Likelihood<T>> ConstructAdaptiveLikelihood(const telescope::Alignment &alignment, const Grouping &grouping, const T q, const T e, const bool mask_groups, const T zero_inflation) {
std::unique_ptr<Likelihood<T>> ConstructAdaptiveLikelihood(const telescope::Alignment &alignment, const Grouping &grouping, const T q, const T e, const size_t min_hits, const T zero_inflation) {
size_t max_group_size = grouping.max_group_size();
size_t n_groups = grouping.get_n_groups();
std::unique_ptr<Likelihood<T>> log_likelihoods;
if (max_group_size <= std::numeric_limits<uint8_t>::max()) {
if (n_groups <= std::numeric_limits<uint8_t>::max()) {
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint8_t>(static_cast<const AdaptiveGrouping<uint8_t, uint8_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation));
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint8_t>(static_cast<const AdaptiveGrouping<uint8_t, uint8_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation));
} else if (n_groups <= std::numeric_limits<uint16_t>::max()) {
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint8_t>(static_cast<const AdaptiveGrouping<uint8_t, uint16_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation));
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint8_t>(static_cast<const AdaptiveGrouping<uint8_t, uint16_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation));
} else if (n_groups <= std::numeric_limits<uint32_t>::max()) {
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint8_t>(static_cast<const AdaptiveGrouping<uint8_t, uint32_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation));
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint8_t>(static_cast<const AdaptiveGrouping<uint8_t, uint32_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation));
} else {
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint8_t>(static_cast<const AdaptiveGrouping<uint8_t, uint64_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation));
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint8_t>(static_cast<const AdaptiveGrouping<uint8_t, uint64_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation));
}
} else if (max_group_size <= std::numeric_limits<uint16_t>::max()) {
if (n_groups <= std::numeric_limits<uint8_t>::max()) {
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint16_t>(static_cast<const AdaptiveGrouping<uint16_t, uint8_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation));
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint16_t>(static_cast<const AdaptiveGrouping<uint16_t, uint8_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation));
} else if (n_groups <= std::numeric_limits<uint16_t>::max()) {
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint16_t>(static_cast<const AdaptiveGrouping<uint16_t, uint16_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation));
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint16_t>(static_cast<const AdaptiveGrouping<uint16_t, uint16_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation));
} else if (n_groups <= std::numeric_limits<uint32_t>::max()) {
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint16_t>(static_cast<const AdaptiveGrouping<uint16_t, uint32_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation));
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint16_t>(static_cast<const AdaptiveGrouping<uint16_t, uint32_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation));
} else {
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint16_t>(static_cast<const AdaptiveGrouping<uint16_t, uint64_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation));
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint16_t>(static_cast<const AdaptiveGrouping<uint16_t, uint64_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation));
}
} else if (max_group_size <= std::numeric_limits<uint32_t>::max()) {
if (n_groups <= std::numeric_limits<uint8_t>::max()) {
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint32_t>(static_cast<const AdaptiveGrouping<uint32_t, uint8_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation));
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint32_t>(static_cast<const AdaptiveGrouping<uint32_t, uint8_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation));
} else if (n_groups <= std::numeric_limits<uint16_t>::max()) {
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint32_t>(static_cast<const AdaptiveGrouping<uint32_t, uint16_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation));
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint32_t>(static_cast<const AdaptiveGrouping<uint32_t, uint16_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation));
} else if (n_groups <= std::numeric_limits<uint32_t>::max()) {
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint32_t>(static_cast<const AdaptiveGrouping<uint32_t, uint32_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation));
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint32_t>(static_cast<const AdaptiveGrouping<uint32_t, uint32_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation));
} else {
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint32_t>(static_cast<const AdaptiveGrouping<uint32_t, uint64_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation));
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint32_t>(static_cast<const AdaptiveGrouping<uint32_t, uint64_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation));
}
} else {
if (n_groups <= std::numeric_limits<uint8_t>::max()) {
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint64_t>(static_cast<const AdaptiveGrouping<uint64_t, uint8_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation));
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint64_t>(static_cast<const AdaptiveGrouping<uint64_t, uint8_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation));
} else if (n_groups <= std::numeric_limits<uint16_t>::max()) {
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint64_t>(static_cast<const AdaptiveGrouping<uint64_t, uint16_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation));
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint64_t>(static_cast<const AdaptiveGrouping<uint64_t, uint16_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation));
} else if (n_groups <= std::numeric_limits<uint32_t>::max()) {
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint64_t>(static_cast<const AdaptiveGrouping<uint64_t, uint32_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation));
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint64_t>(static_cast<const AdaptiveGrouping<uint64_t, uint32_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation));
} else {
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint64_t>(static_cast<const AdaptiveGrouping<uint64_t, uint64_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation));
log_likelihoods.reset(new mSWEEP::LL_WOR21<T, uint64_t>(static_cast<const AdaptiveGrouping<uint64_t, uint64_t>*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation));
}
}
return log_likelihoods;
Expand Down
16 changes: 8 additions & 8 deletions src/mSWEEP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ void parse_args(int argc, char* argv[], cxxargs::Arguments &args) {
args.set_not_required("alphas");

args.add_long_argument<bool>("run-rate", "Calculate relative reliability for each abundance estimate using RATE (default: false).", false);
args.add_long_argument<bool>("ignore-zeros", "Ignore target clusters that did not have any reads align against them (default: false).", false);
args.add_long_argument<size_t>("min-hits", "Only consider target groups that have at least this many reads align to any sequence in them (default: 0).", (size_t)0);

if (CmdOptionPresent(argv, argv+argc, "--help")) {
// Print help message and continue.
Expand Down Expand Up @@ -367,7 +367,7 @@ int main (int argc, char *argv[]) {

// Use the alignment data to populate the log_likelihoods matrix.
try {
log_likelihoods = mSWEEP::ConstructAdaptiveLikelihood<double>(*alignment, reference->get_grouping(i), args.value<double>('q'), args.value<double>('e'), args.value<bool>("ignore-zeros"), args.value<double>("zero-inflation"));
log_likelihoods = mSWEEP::ConstructAdaptiveLikelihood<double>(*alignment, reference->get_grouping(i), args.value<double>('q'), args.value<double>('e'), args.value<size_t>("min-hits"), args.value<double>("zero-inflation"));
} catch (std::exception &e) {
finalize("Building the log-likelihood array failed:\n " + std::string(e.what()) + "\nexiting\n", log, true);
return 1;
Expand Down Expand Up @@ -435,16 +435,16 @@ int main (int argc, char *argv[]) {
sample->dirichlet_kld(log_likelihoods->log_counts());
}

if (args.value<bool>("ignore-zeros")) {
std::cerr << "WARNING: --ignore-zeros is an experimental option that has not been thoroughly tested and is subject to change.\n" << std::endl;
if (args.value<size_t>("min-hits") > 0) {
std::cerr << "WARNING: --min-hits > 0 is an experimental option that has not been thoroughly tested and is subject to change.\n" << std::endl;
}

// Run binning if requested and write results to files.
if (rank == 0) { // root performs the rest.
// Turn the probs into relative abundances
sample->store_abundances(rcgpar::mixture_components(sample->get_probs(), log_likelihoods->log_counts()));

if (args.value<bool>("ignore-zeros")) {
if (args.value<size_t>("min-hits") > 0) {
for (size_t j = 0; j < reference->group_names(i).size(); ++j) {
if (log_likelihoods->groups_considered()[j]) {
estimated_reference_names.push_back(reference->group_names(i)[j]);
Expand Down Expand Up @@ -496,14 +496,14 @@ int main (int argc, char *argv[]) {
// Note: this ignores the printing_output variable because
// we might want to print the probs even when writing to
// pipe them somewhere.
if (args.value<bool>("ignore-zeros")) {
if (args.value<size_t>("min-hits") > 0) {
sample->write_probs2(estimated_reference_names, zero_reference_names, &std::cout);
} else {
sample->write_probs(estimated_reference_names, &std::cout);
}
}
if (args.value<bool>("write-probs")) {
if (args.value<bool>("ignore-zeros")) {
if (args.value<size_t>("min-hits") > 0) {
sample->write_probs2(estimated_reference_names, zero_reference_names, out.probs());
} else {
sample->write_probs(estimated_reference_names, out.probs());
Expand Down Expand Up @@ -568,7 +568,7 @@ int main (int argc, char *argv[]) {
throw std::runtime_error("Can't write to abundances file.");
}
} else {
if (args.value<bool>("ignore-zeros")) {
if (args.value<size_t>("min-hits") > 0) {
sample->write_abundances2(estimated_reference_names, zero_reference_names, out.abundances());
} else {
sample->write_abundances(estimated_reference_names, out.abundances());
Expand Down

0 comments on commit 59fe08c

Please sign in to comment.