diff --git a/README.md b/README.md index e4b052b..d7c0dbc 100644 --- a/README.md +++ b/README.md @@ -174,7 +174,7 @@ Likelihood options: Experimental options: --run-rate Calculate relative reliability for each abundance estimate using RATE (default: false). ---ignore-zeros Ignore target clusters that did not have any reads align against them (default: false): +--min-hits Only consider target groups that have at least this many reads align to any sequence in them (default: 0). ``` # References diff --git a/include/Likelihood.hpp b/include/Likelihood.hpp index 47d267b..2265395 100644 --- a/include/Likelihood.hpp +++ b/include/Likelihood.hpp @@ -106,26 +106,22 @@ class LL_WOR21 : public Likelihood { return ll_mat; } - void fill_ll_mat(const telescope::Alignment &alignment, const std::vector &group_sizes, const size_t n_groups, const bool mask_groups) { + void fill_ll_mat(const telescope::Alignment &alignment, const std::vector &group_sizes, const size_t n_groups, const size_t min_hits) { size_t num_ecs = alignment.n_ecs(); + bool mask_groups = min_hits > 0; this->groups_mask = std::vector(n_groups, !mask_groups); + std::vector masked_group_sizes; if (mask_groups) { + std::vector group_hit_counts(n_groups, (size_t)0); // Create mask identifying groups that have at least 1 alignment for (size_t i = 0; i < num_ecs; ++i) { for (size_t j = 0; j < n_groups; ++j) { - this->groups_mask[j] = groups_mask[j] || (alignment(j, i) > 0); + group_hit_counts[j] += (alignment(j, i) > 0); } } - } - size_t n_masked_groups = 0; - for (size_t i = 0; i < n_groups; ++i) { - n_masked_groups += groups_mask[i]; - } - - std::vector masked_group_sizes; - if (mask_groups) { for (size_t i = 0; i < n_groups; ++i) { + this->groups_mask[i] = groups_mask[i] || (group_hit_counts[i] >= min_hits); if (this->groups_mask[i]) { masked_group_sizes.push_back(group_sizes[i]); } @@ -133,6 +129,7 @@ class LL_WOR21 : public Likelihood { } else { masked_group_sizes = group_sizes; } + size_t n_masked_groups = masked_group_sizes.size(); this->update_bb_parameters(masked_group_sizes, n_masked_groups, this->bb_constants); const seamat::DenseMatrix &precalc_lls_mat = this->precalc_lls(masked_group_sizes, n_masked_groups); @@ -174,15 +171,15 @@ class LL_WOR21 : public Likelihood { public: LL_WOR21() = default; - LL_WOR21(const std::vector &group_sizes, const telescope::Alignment &alignment, const size_t n_groups, const T tol, const T frac_mu, const bool mask_groups, const T _zero_inflation) { + LL_WOR21(const std::vector &group_sizes, const telescope::Alignment &alignment, const size_t n_groups, const T tol, const T frac_mu, const size_t min_hits, const T _zero_inflation) { this->bb_constants[0] = tol; this->bb_constants[1] = frac_mu; this->zero_inflation = _zero_inflation; - this->from_grouped_alignment(alignment, group_sizes, n_groups, mask_groups); + this->from_grouped_alignment(alignment, group_sizes, n_groups, min_hits); } - void from_grouped_alignment(const telescope::Alignment &alignment, const std::vector &group_sizes, const size_t n_groups, const bool mask_groups) { - this->fill_ll_mat(alignment, group_sizes, n_groups, mask_groups); + void from_grouped_alignment(const telescope::Alignment &alignment, const std::vector &group_sizes, const size_t n_groups, const size_t min_hits) { + this->fill_ll_mat(alignment, group_sizes, n_groups, min_hits); this->fill_ec_counts(alignment); } @@ -296,49 +293,49 @@ class LL_WOR21 : public Likelihood { const std::vector& groups_considered() const override { return this->groups_mask; }; }; template -std::unique_ptr> ConstructAdaptiveLikelihood(const telescope::Alignment &alignment, const Grouping &grouping, const T q, const T e, const bool mask_groups, const T zero_inflation) { +std::unique_ptr> ConstructAdaptiveLikelihood(const telescope::Alignment &alignment, const Grouping &grouping, const T q, const T e, const size_t min_hits, const T zero_inflation) { size_t max_group_size = grouping.max_group_size(); size_t n_groups = grouping.get_n_groups(); std::unique_ptr> log_likelihoods; if (max_group_size <= std::numeric_limits::max()) { if (n_groups <= std::numeric_limits::max()) { - log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation)); + log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation)); } else if (n_groups <= std::numeric_limits::max()) { - log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation)); + log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation)); } else if (n_groups <= std::numeric_limits::max()) { - log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation)); + log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation)); } else { - log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation)); + log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation)); } } else if (max_group_size <= std::numeric_limits::max()) { if (n_groups <= std::numeric_limits::max()) { - log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation)); + log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation)); } else if (n_groups <= std::numeric_limits::max()) { - log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation)); + log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation)); } else if (n_groups <= std::numeric_limits::max()) { - log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation)); + log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation)); } else { - log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation)); + log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation)); } } else if (max_group_size <= std::numeric_limits::max()) { if (n_groups <= std::numeric_limits::max()) { - log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation)); + log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation)); } else if (n_groups <= std::numeric_limits::max()) { - log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation)); + log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation)); } else if (n_groups <= std::numeric_limits::max()) { - log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation)); + log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation)); } else { - log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation)); + log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation)); } } else { if (n_groups <= std::numeric_limits::max()) { - log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation)); + log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation)); } else if (n_groups <= std::numeric_limits::max()) { - log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation)); + log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation)); } else if (n_groups <= std::numeric_limits::max()) { - log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation)); + log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation)); } else { - log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, mask_groups, zero_inflation)); + log_likelihoods.reset(new mSWEEP::LL_WOR21(static_cast*>(&grouping)->get_sizes(), alignment, n_groups, q, e, min_hits, zero_inflation)); } } return log_likelihoods; diff --git a/src/mSWEEP.cpp b/src/mSWEEP.cpp index 1bc2927..04d666c 100644 --- a/src/mSWEEP.cpp +++ b/src/mSWEEP.cpp @@ -141,7 +141,7 @@ void parse_args(int argc, char* argv[], cxxargs::Arguments &args) { args.set_not_required("alphas"); args.add_long_argument("run-rate", "Calculate relative reliability for each abundance estimate using RATE (default: false).", false); - args.add_long_argument("ignore-zeros", "Ignore target clusters that did not have any reads align against them (default: false).", false); + args.add_long_argument("min-hits", "Only consider target groups that have at least this many reads align to any sequence in them (default: 0).", (size_t)0); if (CmdOptionPresent(argv, argv+argc, "--help")) { // Print help message and continue. @@ -367,7 +367,7 @@ int main (int argc, char *argv[]) { // Use the alignment data to populate the log_likelihoods matrix. try { - log_likelihoods = mSWEEP::ConstructAdaptiveLikelihood(*alignment, reference->get_grouping(i), args.value('q'), args.value('e'), args.value("ignore-zeros"), args.value("zero-inflation")); + log_likelihoods = mSWEEP::ConstructAdaptiveLikelihood(*alignment, reference->get_grouping(i), args.value('q'), args.value('e'), args.value("min-hits"), args.value("zero-inflation")); } catch (std::exception &e) { finalize("Building the log-likelihood array failed:\n " + std::string(e.what()) + "\nexiting\n", log, true); return 1; @@ -435,8 +435,8 @@ int main (int argc, char *argv[]) { sample->dirichlet_kld(log_likelihoods->log_counts()); } - if (args.value("ignore-zeros")) { - std::cerr << "WARNING: --ignore-zeros is an experimental option that has not been thoroughly tested and is subject to change.\n" << std::endl; + if (args.value("min-hits") > 0) { + std::cerr << "WARNING: --min-hits > 0 is an experimental option that has not been thoroughly tested and is subject to change.\n" << std::endl; } // Run binning if requested and write results to files. @@ -444,7 +444,7 @@ int main (int argc, char *argv[]) { // Turn the probs into relative abundances sample->store_abundances(rcgpar::mixture_components(sample->get_probs(), log_likelihoods->log_counts())); - if (args.value("ignore-zeros")) { + if (args.value("min-hits") > 0) { for (size_t j = 0; j < reference->group_names(i).size(); ++j) { if (log_likelihoods->groups_considered()[j]) { estimated_reference_names.push_back(reference->group_names(i)[j]); @@ -496,14 +496,14 @@ int main (int argc, char *argv[]) { // Note: this ignores the printing_output variable because // we might want to print the probs even when writing to // pipe them somewhere. - if (args.value("ignore-zeros")) { + if (args.value("min-hits") > 0) { sample->write_probs2(estimated_reference_names, zero_reference_names, &std::cout); } else { sample->write_probs(estimated_reference_names, &std::cout); } } if (args.value("write-probs")) { - if (args.value("ignore-zeros")) { + if (args.value("min-hits") > 0) { sample->write_probs2(estimated_reference_names, zero_reference_names, out.probs()); } else { sample->write_probs(estimated_reference_names, out.probs()); @@ -568,7 +568,7 @@ int main (int argc, char *argv[]) { throw std::runtime_error("Can't write to abundances file."); } } else { - if (args.value("ignore-zeros")) { + if (args.value("min-hits") > 0) { sample->write_abundances2(estimated_reference_names, zero_reference_names, out.abundances()); } else { sample->write_abundances(estimated_reference_names, out.abundances());