diff --git a/include/pisa/cursor/block_max_scored_cursor.hpp b/include/pisa/cursor/block_max_scored_cursor.hpp index a6a3eb15..30996815 100644 --- a/include/pisa/cursor/block_max_scored_cursor.hpp +++ b/include/pisa/cursor/block_max_scored_cursor.hpp @@ -3,7 +3,7 @@ #include #include "cursor/max_scored_cursor.hpp" -#include "query/queries.hpp" +#include "query.hpp" #include "scorer/index_scorer.hpp" #include "util/compiler_attribute.hpp" @@ -30,7 +30,7 @@ class BlockMaxScoredCursor: public MaxScoredCursor { ~BlockMaxScoredCursor() = default; [[nodiscard]] PISA_ALWAYSINLINE auto block_max_score() -> float { - return m_wdata.score() * this->query_weight(); + return m_wdata.score() * this->weight(); } [[nodiscard]] PISA_ALWAYSINLINE auto block_max_docid() -> std::uint32_t { @@ -45,38 +45,21 @@ class BlockMaxScoredCursor: public MaxScoredCursor { template [[nodiscard]] auto make_block_max_scored_cursors( - Index const& index, WandType const& wdata, Scorer const& scorer, Query query, bool weighted = false + Index const& index, WandType const& wdata, Scorer const& scorer, Query const& query, bool weighted = false ) { - auto terms = query.terms; - auto query_term_freqs = query_freqs(terms); - std::vector> cursors; - cursors.reserve(query_term_freqs.size()); + cursors.reserve(query.terms().size()); std::transform( - query_term_freqs.begin(), - query_term_freqs.end(), + query.terms().begin(), + query.terms().end(), std::back_inserter(cursors), - [&](auto&& term) { - auto term_weight = 1.0F; - auto term_id = term.first; - auto max_weight = wdata.max_term_weight(term_id); - - if (weighted) { - term_weight = term.second; - max_weight = term_weight * max_weight; - return BlockMaxScoredCursor( - index[term_id], - [scorer = scorer.term_scorer(term_id), weight = term_weight]( - uint32_t doc, uint32_t freq - ) { return weight * scorer(doc, freq); }, - term_weight, - max_weight, - wdata.getenum(term_id) - ); - } - + [&](WeightedTerm const& term) { return BlockMaxScoredCursor( - index[term_id], scorer.term_scorer(term_id), term_weight, max_weight, wdata.getenum(term_id) + index[term.id], + scorer.term_scorer(term.id), + weighted ? term.weight : 1.0F, + wdata.max_term_weight(term.id), + wdata.getenum(term.id) ); } ); diff --git a/include/pisa/cursor/cursor.hpp b/include/pisa/cursor/cursor.hpp index 7703a268..1feccd35 100644 --- a/include/pisa/cursor/cursor.hpp +++ b/include/pisa/cursor/cursor.hpp @@ -1,21 +1,28 @@ #pragma once -#include "query/queries.hpp" #include +#include "query.hpp" + namespace pisa { +/** Creates cursors for query terms. + * + * These are frequency-only cursors. If you want to calculate scores, use + * `make_scored_cursors`, `make_max_scored_cursors`, or `make_block_max_scored_cursors`. 
+ */ template [[nodiscard]] auto make_cursors(Index const& index, Query query) { - auto terms = query.terms; - remove_duplicate_terms(terms); using cursor = typename Index::document_enumerator; std::vector cursors; - cursors.reserve(terms.size()); - std::transform(terms.begin(), terms.end(), std::back_inserter(cursors), [&](auto&& term) { - return index[term]; - }); + cursors.reserve(query.terms().size()); + std::transform( + query.terms().begin(), + query.terms().end(), + std::back_inserter(cursors), + [&](auto&& term) { return index[term.id]; } + ); return cursors; } diff --git a/include/pisa/cursor/max_scored_cursor.hpp b/include/pisa/cursor/max_scored_cursor.hpp index efacf001..af05ef2c 100644 --- a/include/pisa/cursor/max_scored_cursor.hpp +++ b/include/pisa/cursor/max_scored_cursor.hpp @@ -3,7 +3,7 @@ #include #include "cursor/scored_cursor.hpp" -#include "query/queries.hpp" +#include "query.hpp" #include "util/compiler_attribute.hpp" namespace pisa { @@ -13,8 +13,8 @@ class MaxScoredCursor: public ScoredCursor { public: using base_cursor_type = Cursor; - MaxScoredCursor(Cursor cursor, TermScorer term_scorer, float query_weight, float max_score) - : ScoredCursor(std::move(cursor), std::move(term_scorer), query_weight), + MaxScoredCursor(Cursor cursor, TermScorer term_scorer, float weight, float max_score) + : ScoredCursor(std::move(cursor), std::move(term_scorer), weight), m_max_score(max_score) {} MaxScoredCursor(MaxScoredCursor const&) = delete; MaxScoredCursor(MaxScoredCursor&&) = default; @@ -22,7 +22,9 @@ class MaxScoredCursor: public ScoredCursor { MaxScoredCursor& operator=(MaxScoredCursor&&) = default; ~MaxScoredCursor() = default; - [[nodiscard]] PISA_ALWAYSINLINE auto max_score() const noexcept -> float { return m_max_score; } + [[nodiscard]] PISA_ALWAYSINLINE auto max_score() const noexcept -> float { + return this->weight() * m_max_score; + } private: float m_max_score; @@ -30,37 +32,20 @@ class MaxScoredCursor: public ScoredCursor { template [[nodiscard]] auto make_max_scored_cursors( - Index const& index, WandType const& wdata, Scorer const& scorer, Query query, bool weighted = false + Index const& index, WandType const& wdata, Scorer const& scorer, Query const& query, bool weighted = false ) { - auto terms = query.terms; - auto query_term_freqs = query_freqs(terms); - std::vector> cursors; - cursors.reserve(query_term_freqs.size()); + cursors.reserve(query.terms().size()); std::transform( - query_term_freqs.begin(), - query_term_freqs.end(), + query.terms().begin(), + query.terms().end(), std::back_inserter(cursors), - [&](auto&& term) { - auto term_weight = 1.0F; - auto term_id = term.first; - auto max_weight = wdata.max_term_weight(term_id); - - if (weighted) { - term_weight = term.second; - max_weight = term_weight * max_weight; - return MaxScoredCursor( - index[term_id], - [scorer = scorer.term_scorer(term_id), weight = term_weight]( - uint32_t doc, uint32_t freq - ) { return weight * scorer(doc, freq); }, - term_weight, - max_weight - ); - } - + [&](WeightedTerm const& term) { return MaxScoredCursor( - index[term_id], scorer.term_scorer(term_id), term_weight, max_weight + index[term.id], + scorer.term_scorer(term.id), + weighted ? 
term.weight : 1.0F, + wdata.max_term_weight(term.id) ); } ); diff --git a/include/pisa/cursor/scored_cursor.hpp b/include/pisa/cursor/scored_cursor.hpp index b0671338..8e3c797a 100644 --- a/include/pisa/cursor/scored_cursor.hpp +++ b/include/pisa/cursor/scored_cursor.hpp @@ -2,30 +2,37 @@ #include -#include "query/queries.hpp" +#include "query.hpp" #include "scorer/index_scorer.hpp" #include "util/compiler_attribute.hpp" namespace pisa { +template +auto resolve_term_scorer(Scorer scorer, float weight) -> TermScorer { + if (weight == 1.0F) { + // Optimization: no multiplication necessary if weight is 1.0 + return scorer; + } + return [scorer, weight](uint32_t doc, uint32_t freq) { return weight * scorer(doc, freq); }; +} + template class ScoredCursor { public: using base_cursor_type = Cursor; - ScoredCursor(Cursor cursor, TermScorer term_scorer, float query_weight) + ScoredCursor(Cursor cursor, TermScorer term_scorer, float weight) : m_base_cursor(std::move(cursor)), - m_term_scorer(std::move(term_scorer)), - m_query_weight(query_weight) {} + m_weight(weight), + m_term_scorer(resolve_term_scorer(term_scorer, weight)) {} ScoredCursor(ScoredCursor const&) = delete; ScoredCursor(ScoredCursor&&) = default; ScoredCursor& operator=(ScoredCursor const&) = delete; ScoredCursor& operator=(ScoredCursor&&) = default; ~ScoredCursor() = default; - [[nodiscard]] PISA_ALWAYSINLINE auto query_weight() const noexcept -> float { - return m_query_weight; - } + [[nodiscard]] PISA_ALWAYSINLINE auto weight() const noexcept -> float { return m_weight; } [[nodiscard]] PISA_ALWAYSINLINE auto docid() const -> std::uint32_t { return m_base_cursor.docid(); } @@ -37,38 +44,23 @@ class ScoredCursor { private: Cursor m_base_cursor; + float m_weight = 1.0; TermScorer m_term_scorer; - float m_query_weight = 1.0; }; template -[[nodiscard]] auto -make_scored_cursors(Index const& index, Scorer const& scorer, Query query, bool weighted = false) { - auto terms = query.terms; - auto query_term_freqs = query_freqs(terms); - +[[nodiscard]] auto make_scored_cursors( + Index const& index, Scorer const& scorer, Query const& query, bool weighted = false +) { std::vector> cursors; - cursors.reserve(query_term_freqs.size()); + cursors.reserve(query.terms().size()); std::transform( - query_term_freqs.begin(), - query_term_freqs.end(), + query.terms().begin(), + query.terms().end(), std::back_inserter(cursors), - [&](auto&& term) { - auto term_weight = 1.0F; - auto term_id = term.first; - - if (weighted) { - term_weight = term.second; - return ScoredCursor( - index[term_id], - [scorer = scorer.term_scorer(term_id), weight = term_weight]( - uint32_t doc, uint32_t freq - ) { return weight * scorer(doc, freq); }, - term_weight - ); - } + [&](WeightedTerm const& term) { return ScoredCursor( - index[term_id], scorer.term_scorer(term_id), term_weight + index[term.id], scorer.term_scorer(term.id), weighted ? 
term.weight : 1.0 ); } ); diff --git a/include/pisa/forward_index_builder.hpp b/include/pisa/forward_index_builder.hpp index f98e6cc4..49e1e733 100644 --- a/include/pisa/forward_index_builder.hpp +++ b/include/pisa/forward_index_builder.hpp @@ -1,14 +1,11 @@ #pragma once -#include #include #include #include #include #include "document_record.hpp" -#include "forward_index_builder.hpp" -#include "query/term_processor.hpp" #include "text_analyzer.hpp" #include "type_safe.hpp" diff --git a/include/pisa/intersection.hpp b/include/pisa/intersection.hpp index 9ce2f42d..18ed7258 100644 --- a/include/pisa/intersection.hpp +++ b/include/pisa/intersection.hpp @@ -1,11 +1,11 @@ #include #include #include +#include #include #include "cursor/scored_cursor.hpp" #include "query/algorithm/and_query.hpp" -#include "query/queries.hpp" #include "scorer/scorer.hpp" namespace pisa { @@ -24,20 +24,23 @@ namespace intersection { /// Returns a filtered copy of `query` containing only terms indicated by ones in the bit mask. [[nodiscard]] inline auto filter(Query const& query, Mask const& mask) -> Query { - if (query.terms.size() > MAX_QUERY_LEN) { + if (query.terms().size() > MAX_QUERY_LEN) { throw std::invalid_argument("Queries can be at most 2^32 terms long"); } std::vector terms; std::vector weights; - for (std::size_t bitpos = 0; bitpos < query.terms.size(); ++bitpos) { + for (std::size_t bitpos = 0; bitpos < query.terms().size(); ++bitpos) { if (((1U << bitpos) & mask.to_ulong()) > 0) { - terms.push_back(query.terms.at(bitpos)); - if (bitpos < query.term_weights.size()) { - weights.push_back(query.term_weights[bitpos]); - } + auto term = query.terms().at(bitpos); + terms.push_back(term.id); + weights.push_back(term.weight); } } - return Query{query.id, terms, weights}; + return Query( + query.id().has_value() ? std::make_optional(*query.id()) : std::nullopt, + terms, + weights + ); } } // namespace intersection @@ -80,7 +83,7 @@ inline auto Intersection::compute( /// `Fn` takes `Query` and `Mask`. template auto for_all_subsets(Query const& query, std::optional max_term_count, Fn func) { - auto subset_count = 1U << query.terms.size(); + auto subset_count = 1U << query.terms().size(); for (auto subset = 1U; subset < subset_count; ++subset) { auto mask = intersection::Mask(subset); if (!max_term_count || mask.count() <= *max_term_count) { diff --git a/include/pisa/query.hpp b/include/pisa/query.hpp new file mode 100644 index 00000000..4685eaff --- /dev/null +++ b/include/pisa/query.hpp @@ -0,0 +1,173 @@ +// Copyright 2024 PISA Developers +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "type_alias.hpp" + +namespace pisa { + +/** Term ID along with its weight. + * + * Typically, a weight would be equal to the number of occurrences of the term in a query. + * Partial scores coming from this term will be multiplied by this weight. 
+ */ +struct WeightedTerm { + TermId id; + Score weight; + + /** Tuple conversion for structured bindings support. + * + * ``` + * WeightedTerm wt{0, 1.0}; + * auto [term_id, weight] = weighted_term; + * ``` + */ + [[nodiscard]] explicit constexpr operator std::pair() const noexcept; +}; + +[[nodiscard]] inline auto operator==(WeightedTerm const& lhs, WeightedTerm const& rhs) -> bool { + return std::tie(lhs.id, lhs.weight) == std::tie(rhs.id, rhs.weight); +} + +namespace query { + + /** + * Tells `Query` how to process the terms passed to the constructor. + * + * By default, duplicate terms will be removed, and the weight of each term will be equal to + * the number of occurrences of that term in the query. Furthermore, the order of the terms + * will be preserved (if there are duplicates, the term will be at the position of its first + * occurrence. + * + * This policy can be modified with the following options: + * - `keep_duplicates`: duplicates will be preserved, each with weight 1.0; + * - `unweighted`: forces each weight to be 1.0 even if duplicates are removed; + * - `sort`: sorts terms by ID. + * + * Policies can be combined similar to bitsets. For example, `unweighted | sort` will both + * force unit weights and sort the terms. + */ + struct TermPolicy { + std::uint32_t policy; + + /** Checks if this policy contains the other policy. */ + [[nodiscard]] constexpr auto contains(TermPolicy const& other) const noexcept -> bool; + }; + + /** Merges two policies; the resulting policy will policies from both arguments. */ + [[nodiscard]] auto operator|(TermPolicy lhs, TermPolicy rhs) noexcept -> TermPolicy; + + /** Keep duplicates. */ + static constexpr TermPolicy default_policy = {0b000}; + + /** Keep duplicates. */ + static constexpr TermPolicy keep_duplicates = {0b001}; + + /** Use weight 1.0 for each resulting term. */ + static constexpr TermPolicy unweighted = {0b010}; + + /** Sort by term ID. */ + static constexpr TermPolicy sort = {0b100}; + +} // namespace query + +/** + * A query issued to the system. + */ +class Query { + std::optional m_id{}; + std::vector m_terms{}; + + void postprocess(query::TermPolicy policy); + + public: + /** Constructs a query with the given ID from the terms and weights given by the iterators. + */ + template + Query( + std::optional id, + TermIterator first_term, + TermIterator last_term, + WeightIterator first_weight, + query::TermPolicy policy = query::default_policy + ) + : m_id(std::move(id)) { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_constructible_v); + std::transform( + first_term, + last_term, + first_weight, + std::back_inserter(m_terms), + [](auto id, auto weight) { + return WeightedTerm{id, weight}; + } + ); + postprocess(policy); + } + + /** Constructs a query with the given ID from the terms given by the iterators. */ + template + Query( + std::optional id, + Iterator first, + Iterator last, + query::TermPolicy policy = query::default_policy + ) + : m_id(std::move(id)) { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_constructible_v); + std::transform(first, last, std::back_inserter(m_terms), [](auto id) { + return WeightedTerm{id, 1.0}; + }); + postprocess(policy); + } + + /** Constructs a query with the given ID from the terms and weights passed as collections. 
+ */ + template + explicit Query( + std::optional id, + Terms const& terms, + Weights const& weights, + query::TermPolicy policy = {0} + ) + : Query(std::move(id), terms.begin(), terms.end(), weights.begin(), policy) {} + + /** Constructs a query with the given ID from the terms passed as a collection. */ + template + explicit Query( + std::optional id, + Collection const& terms, + query::TermPolicy policy = query::default_policy + ) + : Query(std::move(id), terms.begin(), terms.end(), policy) {} + + /** Returns the ID of the query if defined. */ + [[nodiscard]] auto id() const noexcept -> std::optional; + + /** Returns the reference to all weighted terms. */ + [[nodiscard]] auto terms() const noexcept -> std::vector const&; +}; + +} // namespace pisa diff --git a/include/pisa/query/algorithm/and_query.hpp b/include/pisa/query/algorithm/and_query.hpp index 7d2077d1..9f8ea1f4 100644 --- a/include/pisa/query/algorithm/and_query.hpp +++ b/include/pisa/query/algorithm/and_query.hpp @@ -4,9 +4,6 @@ #include #include -#include "query/queries.hpp" -#include "util/do_not_optimize_away.hpp" - namespace pisa { /** diff --git a/include/pisa/query/algorithm/block_max_maxscore_query.hpp b/include/pisa/query/algorithm/block_max_maxscore_query.hpp index c247d4ae..1dcc036c 100644 --- a/include/pisa/query/algorithm/block_max_maxscore_query.hpp +++ b/include/pisa/query/algorithm/block_max_maxscore_query.hpp @@ -2,7 +2,6 @@ #include -#include "query/queries.hpp" #include "topk_queue.hpp" namespace pisa { diff --git a/include/pisa/query/algorithm/block_max_ranked_and_query.hpp b/include/pisa/query/algorithm/block_max_ranked_and_query.hpp index 71c59998..6f5cd9d7 100644 --- a/include/pisa/query/algorithm/block_max_ranked_and_query.hpp +++ b/include/pisa/query/algorithm/block_max_ranked_and_query.hpp @@ -1,9 +1,9 @@ #pragma once -#include "query/queries.hpp" -#include "topk_queue.hpp" #include +#include "topk_queue.hpp" + namespace pisa { struct block_max_ranked_and_query { diff --git a/include/pisa/query/algorithm/block_max_wand_query.hpp b/include/pisa/query/algorithm/block_max_wand_query.hpp index 6f081767..fa43dd93 100644 --- a/include/pisa/query/algorithm/block_max_wand_query.hpp +++ b/include/pisa/query/algorithm/block_max_wand_query.hpp @@ -1,8 +1,9 @@ #pragma once -#include "query/queries.hpp" -#include "topk_queue.hpp" #include + +#include "topk_queue.hpp" + namespace pisa { struct block_max_wand_query { diff --git a/include/pisa/query/algorithm/maxscore_query.hpp b/include/pisa/query/algorithm/maxscore_query.hpp index 6d575b4e..5b127e82 100644 --- a/include/pisa/query/algorithm/maxscore_query.hpp +++ b/include/pisa/query/algorithm/maxscore_query.hpp @@ -2,9 +2,9 @@ #include #include +#include #include -#include "query/queries.hpp" #include "topk_queue.hpp" #include "util/compiler_attribute.hpp" diff --git a/include/pisa/query/algorithm/or_query.hpp b/include/pisa/query/algorithm/or_query.hpp index 465d21ab..41604eb7 100644 --- a/include/pisa/query/algorithm/or_query.hpp +++ b/include/pisa/query/algorithm/or_query.hpp @@ -1,7 +1,10 @@ #pragma once -#include "query/queries.hpp" -#include +#include +#include +#include + +#include "util/do_not_optimize_away.hpp" namespace pisa { diff --git a/include/pisa/query/algorithm/range_query.hpp b/include/pisa/query/algorithm/range_query.hpp index c32907b5..8b9ae280 100644 --- a/include/pisa/query/algorithm/range_query.hpp +++ b/include/pisa/query/algorithm/range_query.hpp @@ -1,6 +1,5 @@ #pragma once -#include "query/queries.hpp" #include "topk_queue.hpp" 
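For reference, a minimal usage sketch of the new `pisa::Query` and `WeightedTerm` API declared in `include/pisa/query.hpp` above. The shown behavior (duplicates folded into weights, first-occurrence order preserved) follows the documented default policy; the `main` driver and the literal values are illustrative only, not part of the patch.

```cpp
#include <cassert>
#include <iostream>
#include <optional>
#include <vector>

#include "query.hpp"

int main() {
    using namespace pisa;

    // Default policy: duplicates are folded into weights, first-occurrence order is kept.
    Query q("Q1", std::vector<TermId>{4, 7, 4, 1});
    assert(q.id().has_value() && *q.id() == "Q1");
    assert((q.terms() == std::vector<WeightedTerm>{{4, 2.0F}, {7, 1.0F}, {1, 1.0F}}));

    // WeightedTerm is a plain aggregate, so structured bindings work directly.
    for (auto const& [term_id, weight]: q.terms()) {
        std::cout << term_id << " -> " << weight << '\n';
    }

    // Explicit weights can be supplied alongside the term IDs.
    Query weighted(
        std::nullopt, std::vector<TermId>{6, 1, 5}, std::vector<float>{0.1F, 0.4F, 1.0F}
    );
    assert(weighted.terms().size() == 3);
    return 0;
}
```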
namespace pisa { diff --git a/include/pisa/query/algorithm/range_taat_query.hpp b/include/pisa/query/algorithm/range_taat_query.hpp index eaee9f5a..cd4aa774 100644 --- a/include/pisa/query/algorithm/range_taat_query.hpp +++ b/include/pisa/query/algorithm/range_taat_query.hpp @@ -2,7 +2,6 @@ #include "accumulator/partial_score_accumulator.hpp" #include "concepts.hpp" -#include "query/queries.hpp" #include "topk_queue.hpp" namespace pisa { diff --git a/include/pisa/query/algorithm/ranked_and_query.hpp b/include/pisa/query/algorithm/ranked_and_query.hpp index 4c4e9d39..318ca9f0 100644 --- a/include/pisa/query/algorithm/ranked_and_query.hpp +++ b/include/pisa/query/algorithm/ranked_and_query.hpp @@ -1,9 +1,9 @@ #pragma once -#include "query/queries.hpp" -#include "topk_queue.hpp" #include +#include "topk_queue.hpp" + namespace pisa { struct ranked_and_query { diff --git a/include/pisa/query/algorithm/ranked_or_query.hpp b/include/pisa/query/algorithm/ranked_or_query.hpp index 3fd5eef7..f10583e3 100644 --- a/include/pisa/query/algorithm/ranked_or_query.hpp +++ b/include/pisa/query/algorithm/ranked_or_query.hpp @@ -1,9 +1,7 @@ #pragma once -#include #include -#include "query/queries.hpp" #include "topk_queue.hpp" namespace pisa { diff --git a/include/pisa/query/algorithm/ranked_or_taat_query.hpp b/include/pisa/query/algorithm/ranked_or_taat_query.hpp index f2338596..e7c0bced 100644 --- a/include/pisa/query/algorithm/ranked_or_taat_query.hpp +++ b/include/pisa/query/algorithm/ranked_or_taat_query.hpp @@ -2,7 +2,6 @@ #include "accumulator/partial_score_accumulator.hpp" #include "concepts.hpp" -#include "query/queries.hpp" #include "topk_queue.hpp" namespace pisa { diff --git a/include/pisa/query/algorithm/wand_query.hpp b/include/pisa/query/algorithm/wand_query.hpp index dd610cd1..4e7b9527 100644 --- a/include/pisa/query/algorithm/wand_query.hpp +++ b/include/pisa/query/algorithm/wand_query.hpp @@ -2,7 +2,6 @@ #include -#include "query/queries.hpp" #include "topk_queue.hpp" namespace pisa { diff --git a/include/pisa/query/queries.hpp b/include/pisa/query/queries.hpp deleted file mode 100644 index 3c804f9b..00000000 --- a/include/pisa/query/queries.hpp +++ /dev/null @@ -1,49 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include - -#include "query/term_processor.hpp" -#include "tokenizer.hpp" - -namespace pisa { - -using term_id_type = uint32_t; -using term_id_vec = std::vector; -using term_freq_pair = std::pair; -using term_freq_vec = std::vector; - -struct Query { - std::optional id; - std::vector terms; - std::vector term_weights; -}; - -[[nodiscard]] auto split_query_at_colon(std::string_view query_string) - -> std::pair, std::string_view>; - -[[nodiscard]] auto parse_query_terms( - std::string const& query_string, Tokenizer const& tokenizer, TermProcessor term_processor -) -> Query; - -[[nodiscard]] auto parse_query_ids(std::string const& query_string) -> Query; - -[[nodiscard]] std::function resolve_query_parser( - std::vector& queries, - std::unique_ptr tokenizer, - std::optional const& terms_file, - std::optional const& stopwords_filename, - std::optional const& stemmer_type -); - -bool read_query(term_id_vec& ret, std::istream& is = std::cin); - -void remove_duplicate_terms(term_id_vec& terms); - -term_freq_vec query_freqs(term_id_vec terms); - -} // namespace pisa diff --git a/include/pisa/query/query_parser.hpp b/include/pisa/query/query_parser.hpp index b41ee8bf..21c7103e 100644 --- a/include/pisa/query/query_parser.hpp +++ 
b/include/pisa/query/query_parser.hpp @@ -1,17 +1,30 @@ #pragma once -#include "query/queries.hpp" +#include "query.hpp" #include "term_map.hpp" #include "text_analyzer.hpp" namespace pisa { +/** Query parser. + * + * Parses a string and maps tokens to term IDs. + */ class QueryParser { TextAnalyzer m_analyzer; std::unique_ptr m_term_map; public: - explicit QueryParser(TextAnalyzer analyzer, std::unique_ptr term_map = nullptr); + /** Constructs a parser. + * + * If term map is not passed, then each token will be parsed as a number and treated as + * term ID. + */ + explicit QueryParser(TextAnalyzer analyzer, std::unique_ptr term_map); + + /** Constructs a parser with `IntMap`, which parses numbers to term IDs. */ + explicit QueryParser(TextAnalyzer analyzer); + [[nodiscard]] auto parse(std::string_view query) -> Query; [[nodiscard]] auto parse(std::string const& query) -> Query; }; diff --git a/include/pisa/query/query_stemmer.hpp b/include/pisa/query/query_stemmer.hpp deleted file mode 100644 index ad817c62..00000000 --- a/include/pisa/query/query_stemmer.hpp +++ /dev/null @@ -1,37 +0,0 @@ -#pragma once -#include -#include -#include - -#include - -#include "query/queries.hpp" -#include "query/term_processor.hpp" -#include "tokenizer.hpp" - -namespace pisa { - -class QueryStemmer { - public: - explicit QueryStemmer(std::optional const& stemmer_name) - : m_stemmer(term_transformer_builder(stemmer_name)()) {} - std::string operator()(std::string const& query_string) { - std::stringstream tokenized_query; - auto [id, raw_query] = split_query_at_colon(query_string); - std::vector stemmed_terms; - EnglishTokenStream tokenizer(raw_query); - for (auto token: tokenizer) { - stemmed_terms.push_back(m_stemmer(token)); - } - if (id) { - tokenized_query << *(id) << ":"; - } - using boost::algorithm::join; - tokenized_query << join(stemmed_terms, " "); - return tokenized_query.str(); - } - - TermTransformer m_stemmer; -}; - -} // namespace pisa diff --git a/include/pisa/query/term_processor.hpp b/include/pisa/query/term_processor.hpp deleted file mode 100644 index f1716285..00000000 --- a/include/pisa/query/term_processor.hpp +++ /dev/null @@ -1,64 +0,0 @@ -#pragma once - -#include -#include -#include - -#include "io.hpp" -#include "memory_source.hpp" -#include "payload_vector.hpp" - -namespace pisa { - -using term_id_type = uint32_t; -using TermTransformer = std::function; -using TermTransformerBuilder = std::function; - -auto term_transformer_builder(std::optional const& type) -> TermTransformerBuilder; - -class TermProcessor { - private: - std::unordered_set stopwords; - - // Method implemented in constructor according to the specified stemmer. - std::function(std::string)> _to_id; - - public: - TermProcessor( - std::optional const& terms_file, - std::optional const& stopwords_filename, - std::optional const& stemmer_type - ) { - auto source = std::make_shared(MemorySource::mapped_file(*terms_file)); - auto terms = Payload_Vector<>::from(*source); - auto to_id = [source = std::move(source), terms](auto str) -> std::optional { - // Note: the lexicographical order of the terms matters. - return pisa::binary_search(terms.begin(), terms.end(), std::string_view(str)); - }; - - // Implements '_to_id' method. - _to_id = [=](auto str) { return to_id(term_transformer_builder(stemmer_type)()(str)); }; - // Loads stopwords. 
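A sketch of how the reworked `QueryParser` above might be used. Per its comment, the single-argument constructor falls back to `IntMap`, so each token is read as a numeric term ID; `WhitespaceTokenizer` and the include paths are assumptions taken from the tests further down, not something this hunk introduces.

```cpp
#include <iostream>
#include <memory>
#include <string>

#include "query/query_parser.hpp"
#include "text_analyzer.hpp"
#include "tokenizer.hpp"

int main() {
    using namespace pisa;

    // No term map given: the parser uses IntMap and treats each token as a term ID.
    // WhitespaceTokenizer is assumed here; any Tokenizer implementation would do.
    QueryParser parser(TextAnalyzer(std::make_unique<WhitespaceTokenizer>()));

    std::string line = "42:7 3 7";
    auto query = parser.parse(line);
    std::cout << *query.id() << '\n';  // prints "42"
    for (auto const& term: query.terms()) {
        // Term 7 occurs twice, so the default policy folds it into weight 2.
        std::cout << term.id << " -> " << term.weight << '\n';
    }
    return 0;
}
```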
- if (stopwords_filename) { - std::ifstream is(*stopwords_filename); - io::for_each_line(is, [&](auto&& word) { - if (auto processed_term = _to_id(std::move(word)); processed_term.has_value()) { - stopwords.insert(*processed_term); - } - }); - } - } - - std::optional operator()(std::string token) { return _to_id(token); } - - bool is_stopword(const term_id_type term) { return stopwords.find(term) != stopwords.end(); } - - std::vector get_stopwords() { - std::vector v; - v.insert(v.end(), stopwords.begin(), stopwords.end()); - sort(v.begin(), v.end()); - return v; - } -}; - -} // namespace pisa diff --git a/include/pisa/string.hpp b/include/pisa/string.hpp new file mode 100644 index 00000000..b403c7d3 --- /dev/null +++ b/include/pisa/string.hpp @@ -0,0 +1,26 @@ +// Copyright 2024 PISA Developers +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace pisa { + +/** Splits the given string at a colon. */ +[[nodiscard]] auto split_at_colon(std::string_view str) + -> std::pair, std::string_view>; + +} // namespace pisa diff --git a/include/pisa/taily_stats.hpp b/include/pisa/taily_stats.hpp index 05d92371..c8bf90e2 100644 --- a/include/pisa/taily_stats.hpp +++ b/include/pisa/taily_stats.hpp @@ -1,23 +1,17 @@ #pragma once -#include -#include - #include #include #include #include "binary_freq_collection.hpp" #include "memory_source.hpp" -#include "query/queries.hpp" -#include "scorer/scorer.hpp" +#include "query.hpp" #include "timer.hpp" #include "type_safe.hpp" +#include "util/compiler_attribute.hpp" #include "util/progress.hpp" #include "vec_map.hpp" -#include "wand_data.hpp" -#include "wand_data_compressed.hpp" -#include "wand_data_raw.hpp" namespace pisa { @@ -31,20 +25,20 @@ class TailyStats { [[nodiscard]] auto num_documents() const -> std::uint64_t { return read_at(0); } [[nodiscard]] auto num_terms() const -> std::uint64_t { return read_at(8); } - [[nodiscard]] auto term_stats(term_id_type term_id) const -> taily::Feature_Statistics { + [[nodiscard]] auto term_stats(TermId term_id) const -> taily::Feature_Statistics { std::size_t offset = 16 + term_id * 24; auto expected_value = read_at(offset); auto variance = read_at(offset + sizeof(double)); auto frequency = read_at(offset + 2 * sizeof(double)); return taily::Feature_Statistics{expected_value, variance, frequency}; } - [[nodiscard]] auto query_stats(pisa::Query const& query) const -> taily::Query_Statistics { + [[nodiscard]] auto query_stats(Query const& query) const -> taily::Query_Statistics { std::vector stats; std::transform( - query.terms.begin(), - query.terms.end(), + query.terms().begin(), + query.terms().end(), std::back_inserter(stats), - [this](auto&& term_id) { return this->term_stats(term_id); } + [this](auto&& term) { return this->term_stats(term.id); } ); return taily::Query_Statistics{std::move(stats), static_cast(num_documents())}; } @@ -121,8 +115,8 @@ template void taily_score_shards( std::string const& global_stats_path, VecMap const& shard_stats_paths, - 
std::vector<::pisa::Query> const& global_queries, - VecMap> const& shard_queries, + std::vector const& global_queries, + VecMap> const& shard_queries, std::size_t k, Fn func ) { diff --git a/include/pisa/token_filter.hpp b/include/pisa/token_filter.hpp index ad98d12f..0aa4efb0 100644 --- a/include/pisa/token_filter.hpp +++ b/include/pisa/token_filter.hpp @@ -73,4 +73,6 @@ class StopWordRemover final: public TokenFilter { [[nodiscard]] auto filter(CowString input) const -> std::unique_ptr override; }; +[[nodiscard]] auto stemmer_from_name(std::string_view name) -> std::unique_ptr; + } // namespace pisa diff --git a/include/pisa/type_alias.hpp b/include/pisa/type_alias.hpp index 48700cd8..109b26ae 100644 --- a/include/pisa/type_alias.hpp +++ b/include/pisa/type_alias.hpp @@ -1,3 +1,17 @@ +// Copyright 2024 PISA Developers +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #pragma once #include @@ -5,6 +19,7 @@ namespace pisa { using DocId = std::uint32_t; +using TermId = std::uint32_t; using Score = float; } // namespace pisa diff --git a/src/forward_index_builder.cpp b/src/forward_index_builder.cpp index 1f7df326..5dd59450 100644 --- a/src/forward_index_builder.cpp +++ b/src/forward_index_builder.cpp @@ -3,13 +3,16 @@ #include #include +#include #include #include #include #include #include -#include "binary_collection.hpp" +#include "pisa/binary_collection.hpp" +#include "pisa/io.hpp" +#include "pisa/payload_vector.hpp" namespace pisa { diff --git a/src/query.cpp b/src/query.cpp new file mode 100644 index 00000000..c23e40d1 --- /dev/null +++ b/src/query.cpp @@ -0,0 +1,83 @@ +// Copyright 2024 PISA Developers +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "query.hpp" + +#include +#include + +namespace pisa { + +namespace query { + + constexpr auto TermPolicy::contains(TermPolicy const& other) const noexcept -> bool { + return (this->policy & other.policy) > 0; + } + + auto operator|(query::TermPolicy lhs, query::TermPolicy rhs) noexcept -> query::TermPolicy { + return query::TermPolicy{lhs.policy | rhs.policy}; + } + +} // namespace query + +auto Query::id() const noexcept -> std::optional { + if (m_id) { + return std::string_view(*m_id); + } + return std::nullopt; +} + +auto Query::terms() const noexcept -> std::vector const& { + return m_terms; +} + +/** The first occurrence of each term is assigned the accumulated weight. 
*/ +void accumualte_weights(std::vector& terms) { + std::unordered_map positions; + for (auto it = terms.begin(); it != terms.end(); ++it) { + if (auto pos = positions.find(it->id); pos != positions.end()) { + pos->second->weight += 1.0; + } else { + positions[it->id] = it; + } + } +} + +void dedup_by_term_id(std::vector& terms) { + auto out = terms.begin(); + std::unordered_set seen_terms; + for (auto const& [term_id, weight]: terms) { + if (seen_terms.find(term_id) == seen_terms.end()) { + *out++ = {term_id, weight}; + seen_terms.insert(term_id); + } + } + terms.erase(out, terms.end()); +} + +void Query::postprocess(query::TermPolicy policy) { + if (!policy.contains(query::keep_duplicates)) { + if (!policy.contains(query::unweighted)) { + accumualte_weights(m_terms); + } + dedup_by_term_id(m_terms); + } + if (policy.contains(query::sort)) { + std::sort(m_terms.begin(), m_terms.end(), [](auto const& lhs, auto const& rhs) { + return lhs.id < rhs.id; + }); + } +} + +} // namespace pisa diff --git a/src/query/queries.cpp b/src/query/queries.cpp deleted file mode 100644 index 6d763271..00000000 --- a/src/query/queries.cpp +++ /dev/null @@ -1,117 +0,0 @@ -#include "query/queries.hpp" - -#include -#include -#include - -#include "index_types.hpp" -#include "tokenizer.hpp" -#include "topk_queue.hpp" -#include "util/util.hpp" - -namespace pisa { - -auto split_query_at_colon(std::string_view query_string) - -> std::pair, std::string_view> { - // query id : terms (or ids) - auto colon = std::find(query_string.begin(), query_string.end(), ':'); - std::optional id; - if (colon != query_string.end()) { - id = std::string(query_string.begin(), colon); - } - auto pos = colon == query_string.end() ? query_string.begin() : std::next(colon); - auto raw_query = std::string_view(&*pos, std::distance(pos, query_string.end())); - return {std::move(id), raw_query}; -} - -auto parse_query_terms( - std::string const& query_string, Tokenizer const& tokenizer, TermProcessor term_processor -) -> Query { - auto [id, raw_query] = split_query_at_colon(query_string); - auto tokens = tokenizer.tokenize(raw_query); - std::vector parsed_query; - for (auto raw_term: *tokens) { - auto term = term_processor(raw_term); - if (term) { - if (!term_processor.is_stopword(*term)) { - parsed_query.push_back(*term); - } else { - spdlog::warn("Term `{}` is a stopword and will be ignored", raw_term); - } - } else { - spdlog::warn("Term `{}` not found and will be ignored", raw_term); - } - } - return {std::move(id), std::move(parsed_query), {}}; -} - -auto parse_query_ids(std::string const& query_string) -> Query { - auto [id, raw_query] = split_query_at_colon(query_string); - std::vector parsed_query; - std::vector term_ids; - boost::split(term_ids, raw_query, boost::is_any_of("\t, ,\v,\f,\r,\n")); - - auto is_empty = [](const std::string& val) { return val.empty(); }; - // remove_if move matching elements to the end, preparing them for erase. 
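A small sketch of how the `TermPolicy` flags handled by `Query::postprocess` above combine; the expected contents follow the accumulate/dedup/sort logic in `src/query.cpp` and are illustrative only.

```cpp
#include <cassert>
#include <optional>
#include <vector>

#include "query.hpp"

int main() {
    using namespace pisa;
    std::vector<TermId> ids{5, 2, 5, 9};

    // Default: weights of duplicates are accumulated, first-occurrence order is kept.
    Query dedup(std::nullopt, ids);
    assert((dedup.terms() == std::vector<WeightedTerm>{{5, 2.0F}, {2, 1.0F}, {9, 1.0F}}));

    // keep_duplicates: every occurrence is kept, each with weight 1.
    Query dups(std::nullopt, ids, query::keep_duplicates);
    assert(dups.terms().size() == 4);

    // unweighted | sort: duplicates removed, unit weights, sorted by term ID.
    Query sorted(std::nullopt, ids, query::unweighted | query::sort);
    assert((sorted.terms() == std::vector<WeightedTerm>{{2, 1.0F}, {5, 1.0F}, {9, 1.0F}}));
    return 0;
}
```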
- term_ids.erase(std::remove_if(term_ids.begin(), term_ids.end(), is_empty), term_ids.end()); - - try { - auto to_int = [](const std::string& val) { return std::stoi(val); }; - std::transform(term_ids.begin(), term_ids.end(), std::back_inserter(parsed_query), to_int); - } catch (std::invalid_argument& err) { - spdlog::error("Could not parse term identifiers of query `{}`", raw_query); - exit(1); - } - return {std::move(id), std::move(parsed_query), {}}; -} - -std::function resolve_query_parser( - std::vector& queries, - std::unique_ptr tokenizer, - std::optional const& terms_file, - std::optional const& stopwords_filename, - std::optional const& stemmer_type -) { - if (terms_file) { - auto term_processor = TermProcessor(terms_file, stopwords_filename, stemmer_type); - return [&queries, - tokenizer = std::shared_ptr(std::move(tokenizer)), - term_processor = std::move(term_processor)](std::string const& query_line) { - queries.push_back(parse_query_terms(query_line, *tokenizer, term_processor)); - }; - } - return [&queries](std::string const& query_line) { - queries.push_back(parse_query_ids(query_line)); - }; -} - -bool read_query(term_id_vec& ret, std::istream& is) { - ret.clear(); - std::string line; - if (!std::getline(is, line)) { - return false; - } - ret = parse_query_ids(line).terms; - return true; -} - -void remove_duplicate_terms(term_id_vec& terms) { - std::sort(terms.begin(), terms.end()); - terms.erase(std::unique(terms.begin(), terms.end()), terms.end()); -} - -term_freq_vec query_freqs(term_id_vec terms) { - term_freq_vec query_term_freqs; - std::sort(terms.begin(), terms.end()); - // count query term frequencies - for (size_t i = 0; i < terms.size(); ++i) { - if (i == 0 || terms[i] != terms[i - 1]) { - query_term_freqs.emplace_back(terms[i], 1); - } else { - query_term_freqs.back().second += 1; - } - } - return query_term_freqs; -} - -} // namespace pisa diff --git a/src/query/query_parser.cpp b/src/query/query_parser.cpp index 7ecb7048..7420d69d 100644 --- a/src/query/query_parser.cpp +++ b/src/query/query_parser.cpp @@ -1,21 +1,38 @@ +#include +#include + +#include "query.hpp" #include "query/query_parser.hpp" -#include "query/queries.hpp" +#include "string.hpp" +#include "term_map.hpp" namespace pisa { +[[nodiscard]] auto parse_term_id(std::string const& token) -> TermId { + std::size_t idx; + TermId tid = std::stol(token, &idx); + if (idx < token.size()) { + throw std::invalid_argument("invalid term ID: " + token); + } + return tid; +} + QueryParser::QueryParser(TextAnalyzer analyzer, std::unique_ptr term_map) : m_analyzer(std::move(analyzer)), m_term_map(std::move(term_map)) {} +QueryParser::QueryParser(TextAnalyzer analyzer) + : m_analyzer(std::move(analyzer)), m_term_map(std::make_unique()) {} + auto QueryParser::parse(std::string_view query) -> Query { - auto [id, raw_query] = split_query_at_colon(query); + auto [qid, raw_query] = split_at_colon(query); auto tokens = m_analyzer.analyze(raw_query); - std::vector query_ids; + std::vector term_ids; for (auto token: *tokens) { - if (auto id = (*m_term_map)(token); id) { - query_ids.push_back(*id); + if (auto tid = (*m_term_map)(token); tid) { + term_ids.push_back(*tid); } } - return {std::move(id), std::move(query_ids), {}}; + return Query(qid ? 
std::optional(*qid) : std::nullopt, term_ids); } auto QueryParser::parse(std::string const& query) -> Query { diff --git a/src/query/term_processor.cpp b/src/query/term_processor.cpp deleted file mode 100644 index dd64639a..00000000 --- a/src/query/term_processor.cpp +++ /dev/null @@ -1,38 +0,0 @@ -#include "query/term_processor.hpp" - -#include -#include -#include - -namespace pisa { - -auto term_transformer_builder(std::optional const& type) -> TermTransformerBuilder { - if (not type) { - return [] { - return [](std::string&& term) -> std::string { - boost::algorithm::to_lower(term); - return std::move(term); - }; - }; - } - if (*type == "porter2") { - return [] { - return [](std::string&& term) -> std::string { - boost::algorithm::to_lower(term); - return porter2::Stemmer{}.stem(term); - }; - }; - } - if (*type == "krovetz") { - return []() { - return [kstemmer = std::make_shared()](std::string&& term - ) mutable -> std::string { - boost::algorithm::to_lower(term); - return kstemmer->kstem_stemmer(term); - }; - }; - } - throw std::invalid_argument(fmt::format("Unknown stemmer type: {}", *type)); -}; - -} // namespace pisa diff --git a/src/string.cpp b/src/string.cpp new file mode 100644 index 00000000..55e5a5e2 --- /dev/null +++ b/src/string.cpp @@ -0,0 +1,34 @@ +// Copyright 2024 PISA Developers +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "pisa/string.hpp" + +namespace pisa { + +auto split_at_colon(std::string_view str) + -> std::pair, std::string_view> { + auto colon = std::find(str.begin(), str.end(), ':'); + std::optional id; + if (colon != str.end()) { + id = std::string_view(str.begin(), colon); + } + auto pos = colon == str.end() ? 
str.begin() : std::next(colon); + auto raw_query = std::string_view(&*pos, std::distance(pos, str.end())); + return {id, raw_query}; +} + +} // namespace pisa diff --git a/src/token_filter.cpp b/src/token_filter.cpp index 278d286a..94f2d9f6 100644 --- a/src/token_filter.cpp +++ b/src/token_filter.cpp @@ -1,8 +1,7 @@ #include "pisa/token_filter.hpp" -#include - #include +#include namespace pisa { @@ -69,4 +68,14 @@ auto StopWordRemover::filter(CowString input) const -> std::unique_ptr std::unique_ptr { + if (name == "porter2") { + return std::make_unique(); + } + if (name == "krovetz") { + return std::make_unique(); + } + throw std::domain_error(fmt::format("invalid stemmer name: %s", name)); +} + } // namespace pisa diff --git a/test/test_bmw_queries.cpp b/test/test_bmw_queries.cpp index ecbaa8a9..d5df3bca 100644 --- a/test/test_bmw_queries.cpp +++ b/test/test_bmw_queries.cpp @@ -1,16 +1,18 @@ +#include "query/query_parser.hpp" +#include "term_map.hpp" #include #define CATCH_CONFIG_MAIN #include "catch2/catch.hpp" +#include #include #include -#include "test_common.hpp" - +#include "binary_collection.hpp" #include "cursor/block_max_scored_cursor.hpp" #include "cursor/max_scored_cursor.hpp" -#include "cursor/scored_cursor.hpp" #include "index_types.hpp" +#include "io.hpp" #include "pisa_config.hpp" #include "query/algorithm.hpp" #include "wand_data.hpp" @@ -48,10 +50,13 @@ struct IndexData { ); } builder.build(index); - term_id_vec q; + std::vector q; + QueryParser parser( + TextAnalyzer(std::make_unique()), std::make_unique() + ); std::ifstream qfile(PISA_SOURCE_DIR "/test/test_data/queries"); auto push_query = [&](std::string const& query_line) { - queries.push_back(parse_query_ids(query_line)); + queries.push_back(parser.parse(query_line)); }; io::for_each_line(qfile, push_query); } diff --git a/test/test_cursors.cpp b/test/test_cursors.cpp index 761a6879..d1713de0 100644 --- a/test/test_cursors.cpp +++ b/test/test_cursors.cpp @@ -2,8 +2,6 @@ #include -#include "pisa/cursor/block_max_scored_cursor.hpp" -#include "pisa/cursor/cursor.hpp" #include "pisa/cursor/max_scored_cursor.hpp" #include "pisa/cursor/scored_cursor.hpp" #include "pisa/scorer/quantized.hpp" @@ -36,7 +34,7 @@ TEST_CASE("TODO") { }; InMemoryWand wand{{1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, 10}; quantized scorer(wand); - Query query{"Q1", {0, 1, 1, 2}, {}}; + Query query("Q1", std::vector{0, 1, 1, 2}); auto collect_scores = [&index](auto&& cursor) { std::vector scores; diff --git a/test/test_forward_index_builder.cpp b/test/test_forward_index_builder.cpp index 78e82291..7baf79d0 100644 --- a/test/test_forward_index_builder.cpp +++ b/test/test_forward_index_builder.cpp @@ -1,7 +1,6 @@ #define CATCH_CONFIG_MAIN #include -#include #include #include @@ -11,8 +10,9 @@ #include "binary_collection.hpp" #include "filesystem.hpp" #include "forward_index_builder.hpp" +#include "io.hpp" #include "parser.hpp" -#include "parsing/html.hpp" +#include "payload_vector.hpp" #include "pisa_config.hpp" #include "temporary_directory.hpp" #include "text_analyzer.hpp" diff --git a/test/test_intersection.cpp b/test/test_intersection.cpp index 17ffc38a..cd77adc6 100644 --- a/test/test_intersection.cpp +++ b/test/test_intersection.cpp @@ -13,25 +13,25 @@ using namespace pisa::intersection; TEST_CASE("filter query", "[intersection][unit]") { GIVEN("Four-term query") { - Query query{ + Query query( "Q1", // query ID - {6, 1, 5}, // terms - {0.1, 0.4, 1.0} // weights - }; + std::vector{6, 1, 5}, // terms + std::vector{0.1, 0.4, 1.0} // weights + ); auto 
[mask, expected] = GENERATE(table({ - {0b001, Query{"Q1", {6}, {0.1}}}, - {0b010, Query{"Q1", {1}, {0.4}}}, - {0b100, Query{"Q1", {5}, {1.0}}}, - {0b011, Query{"Q1", {6, 1}, {0.1, 0.4}}}, - {0b101, Query{"Q1", {6, 5}, {0.1, 1.0}}}, - {0b110, Query{"Q1", {1, 5}, {0.4, 1.0}}}, - {0b111, Query{"Q1", {6, 1, 5}, {0.1, 0.4, 1.0}}}, + {0b001, Query{"Q1", std::vector{6}, std::vector{0.1}}}, + {0b010, Query{"Q1", std::vector{1}, std::vector{0.4}}}, + {0b100, Query{"Q1", std::vector{5}, std::vector{1.0}}}, + {0b011, Query{"Q1", std::vector{6, 1}, std::vector{0.1, 0.4}}}, + {0b101, Query{"Q1", std::vector{6, 5}, std::vector{0.1, 1.0}}}, + {0b110, Query{"Q1", std::vector{1, 5}, std::vector{0.4, 1.0}}}, + {0b111, + Query{"Q1", std::vector{6, 1, 5}, std::vector{0.1, 0.4, 1.0}}}, })); WHEN("Filtered with mask " << mask) { auto actual = filter(query, mask); - CHECK(actual.id == expected.id); - CHECK(actual.terms == expected.terms); - CHECK(actual.term_weights == expected.term_weights); + CHECK(actual.id() == expected.id()); + CHECK(actual.terms() == expected.terms()); } } } @@ -118,8 +118,8 @@ TEST_CASE("compute intersection", "[intersection][unit]") { Query query{ "Q1", // query ID - {6, 1, 5}, // terms - {0.1, 0.4, 1.0} // weights + std::vector{6, 1, 5}, // terms + std::vector{0.1, 0.4, 1.0} // weights }; auto [mask, len, max] = GENERATE(table({ {0b001, 3, 1.84583F}, @@ -144,8 +144,8 @@ TEST_CASE("for_all_subsets", "[intersection][unit]") { auto accumulate = [&](Query const&, Mask const& mask) { masks.push_back(mask); }; Query query{ "Q1", // query ID - {6, 1, 5}, // terms - {0.1, 0.4, 1.0} // weights + std::vector{6, 1, 5}, // terms + std::vector{0.1, 0.4, 1.0} // weights }; WHEN("Executed with limit 0") { for_all_subsets(query, 0, accumulate); diff --git a/test/test_partition_fwd_index.cpp b/test/test_partition_fwd_index.cpp index a5ba7fff..ab176806 100644 --- a/test/test_partition_fwd_index.cpp +++ b/test/test_partition_fwd_index.cpp @@ -16,11 +16,9 @@ #include #include -#include "binary_freq_collection.hpp" -#include "filesystem.hpp" +#include "binary_collection.hpp" #include "forward_index_builder.hpp" #include "invert.hpp" -#include "parser.hpp" #include "payload_vector.hpp" #include "pisa_config.hpp" #include "sharding.hpp" diff --git a/test/test_queries.cpp b/test/test_queries.cpp deleted file mode 100644 index 9a521e9a..00000000 --- a/test/test_queries.cpp +++ /dev/null @@ -1,160 +0,0 @@ -#define CATCH_CONFIG_MAIN - -#include - -#include "query/algorithm.hpp" -#include "temporary_directory.hpp" - -using namespace pisa; - -TEST_CASE("Parse query term ids without query id") { - auto raw_query = "1 2\t3 4"; - auto q = parse_query_ids(raw_query); - REQUIRE(q.id.has_value() == false); - REQUIRE(q.terms == std::vector{1, 2, 3, 4}); -} - -TEST_CASE("Parse query term ids with query id") { - auto raw_query = "1: 1\t2 3\t4"; - auto q = parse_query_ids(raw_query); - REQUIRE(q.id == "1"); - REQUIRE(q.terms == std::vector{1, 2, 3, 4}); -} - -TEST_CASE("Compute parsing function") { - pisa::TemporaryDirectory tmpdir; - - auto lexfile = tmpdir.path() / "lex"; - encode_payload_vector( - gsl::make_span(std::vector{"a", "account", "he", "she", "usa", "world"}) - ) - .to_file(lexfile.string()); - auto stopwords_filename = tmpdir.path() / "stop"; - { - std::ofstream os(stopwords_filename.string()); - os << "a\nthe\n"; - } - - std::vector queries; - - WHEN("No stopwords, terms, or stemmer") { - // Note we don't need a tokenizer because ID parsing does not use it - auto parse = resolve_query_parser(queries, nullptr, 
std::nullopt, std::nullopt, std::nullopt); - THEN("Parse query IDs") { - parse("1:0 2 4"); - REQUIRE(queries[0].id == std::optional("1")); - REQUIRE(queries[0].terms == std::vector{0, 2, 4}); - REQUIRE(queries[0].term_weights.empty()); - } - } - WHEN("With terms and stopwords. No stemmer") { - auto parse = resolve_query_parser( - queries, - std::make_unique(), - lexfile.string(), - stopwords_filename.string(), - std::nullopt - ); - THEN("Parse query IDs") { - parse("1:a he usa"); - REQUIRE(queries[0].id == std::optional("1")); - REQUIRE(queries[0].terms == std::vector{2, 4}); - REQUIRE(queries[0].term_weights.empty()); - } - } - WHEN("With terms, stopwords, and stemmer") { - auto parse = resolve_query_parser( - queries, - std::make_unique(), - lexfile.string(), - stopwords_filename.string(), - "porter2" - ); - THEN("Parse query IDs") { - parse("1:a he usa"); - REQUIRE(queries[0].id == std::optional("1")); - REQUIRE(queries[0].terms == std::vector{2, 4}); - REQUIRE(queries[0].term_weights.empty()); - } - } - WHEN("Parser with whitespace tokenizer") { - auto parse = resolve_query_parser( - queries, std::make_unique(), lexfile.string(), std::nullopt, std::nullopt - ); - THEN("Parses usa's as usa's (and does not find it in lexicon)") { - parse("1:a he usa's"); - REQUIRE(queries[0].terms == std::vector{0, 2}); - } - } - WHEN("Parser with English tokenizer") { - auto parse = resolve_query_parser( - queries, std::make_unique(), lexfile.string(), std::nullopt, std::nullopt - ); - THEN("Parses usa's as usa (and finds it in lexicon)") { - parse("1:a he usa's"); - REQUIRE(queries[0].terms == std::vector{0, 2, 4}); - } - } -} - -TEST_CASE("Load stopwords in term processor with all stopwords present in the lexicon") { - pisa::TemporaryDirectory tmpdir; - auto lexfile = tmpdir.path() / "lex"; - encode_payload_vector( - gsl::make_span(std::vector{"a", "account", "he", "she", "usa", "world"}) - ) - .to_file(lexfile.string()); - - auto stopwords_filename = (tmpdir.path() / "stopwords").string(); - std::ofstream is(stopwords_filename); - is << "a\nshe\nhe"; - is.close(); - - TermProcessor tprocessor( - std::make_optional(lexfile.string()), std::make_optional(stopwords_filename), std::nullopt - ); - REQUIRE(tprocessor.get_stopwords() == std::vector{0, 2, 3}); -} - -TEST_CASE("Load stopwords in term processor with some stopwords not present in the lexicon") { - pisa::TemporaryDirectory tmpdir; - auto lexfile = tmpdir.path() / "lex"; - encode_payload_vector( - gsl::make_span(std::vector{"account", "coffee", "he", "she", "usa", "world"}) - ) - .to_file(lexfile.string()); - - auto stopwords_filename = (tmpdir.path() / "stopwords").string(); - std::ofstream is(stopwords_filename); - is << "\nis\nto\na\nshe\nhe"; - is.close(); - - TermProcessor tprocessor( - std::make_optional(lexfile.string()), std::make_optional(stopwords_filename), std::nullopt - ); - REQUIRE(tprocessor.get_stopwords() == std::vector{2, 3}); -} - -TEST_CASE("Check if term is stopword") { - pisa::TemporaryDirectory tmpdir; - auto lexfile = tmpdir.path() / "lex"; - encode_payload_vector( - gsl::make_span(std::vector{"account", "coffee", "he", "she", "usa", "world"}) - ) - .to_file(lexfile.string()); - - auto stopwords_filename = (tmpdir.path() / "stopwords").string(); - std::ofstream is(stopwords_filename); - is << "\nis\nto\na\nshe\nhe"; - is.close(); - - TermProcessor tprocessor( - std::make_optional(lexfile.string()), std::make_optional(stopwords_filename), std::nullopt - ); - REQUIRE(!tprocessor.is_stopword(0)); - 
REQUIRE(!tprocessor.is_stopword(1)); - REQUIRE(tprocessor.is_stopword(2)); - REQUIRE(tprocessor.is_stopword(3)); - REQUIRE(!tprocessor.is_stopword(4)); - REQUIRE(!tprocessor.is_stopword(5)); -} diff --git a/test/test_query_stemmer.cpp b/test/test_query_stemmer.cpp deleted file mode 100644 index dffbc0c1..00000000 --- a/test/test_query_stemmer.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#define CATCH_CONFIG_MAIN -#include "catch2/catch.hpp" -#include "parser.hpp" -#include "query/query_stemmer.hpp" -#include - -using namespace pisa; - -TEST_CASE("Stem query", "[stemming][unit]") { - auto [input, expected] = GENERATE(table( - {{"1:playing cards", "1:play card"}, - {"playing cards", "play card"}, - {"play card", "play card"}, - {"1:this:that", "1:this that"}} - )); - QueryStemmer query_stemmer("porter2"); - GIVEN("Input: " << input) { - CHECK(query_stemmer(input) == expected); - } -} diff --git a/test/test_ranked_queries.cpp b/test/test_ranked_queries.cpp index 85194e87..d4154f80 100644 --- a/test/test_ranked_queries.cpp +++ b/test/test_ranked_queries.cpp @@ -1,3 +1,4 @@ +#include "query/query_parser.hpp" #include "type_safe.hpp" #define CATCH_CONFIG_MAIN @@ -10,6 +11,7 @@ #include "cursor/block_max_scored_cursor.hpp" #include "cursor/scored_cursor.hpp" #include "index_types.hpp" +#include "io.hpp" #include "pisa_config.hpp" #include "query/algorithm/block_max_maxscore_query.hpp" #include "query/algorithm/block_max_ranked_and_query.hpp" @@ -54,10 +56,13 @@ struct IndexData { } builder.build(index); - term_id_vec q; + std::vector q; + QueryParser parser( + TextAnalyzer(std::make_unique()), std::make_unique() + ); std::ifstream qfile(PISA_SOURCE_DIR "/test/test_data/queries"); auto push_query = [&](std::string const& query_line) { - queries.push_back(parse_query_ids(query_line)); + queries.push_back(parser.parse(query_line)); }; io::for_each_line(qfile, push_query); diff --git a/test/test_taily_stats.cpp b/test/test_taily_stats.cpp index 1776a8ee..83da32e5 100644 --- a/test/test_taily_stats.cpp +++ b/test/test_taily_stats.cpp @@ -11,6 +11,7 @@ #include "binary_freq_collection.hpp" #include "io.hpp" #include "memory_source.hpp" +#include "query.hpp" #include "scorer/scorer.hpp" #include "taily_stats.hpp" #include "temporary_directory.hpp" @@ -149,7 +150,8 @@ TEST_CASE("Write Taily feature stats", "[taily][unit]") { REQUIRE_THROWS_AS(stats.term_stats(3), std::out_of_range); - auto query_stats = stats.query_stats(pisa::Query{{}, {0, 1, 2}, {}}); + auto query_stats = + stats.query_stats(pisa::Query{std::nullopt, std::vector{0, 1, 2}}); REQUIRE(query_stats.collection_size == 10); diff --git a/test/test_tokenizer.cpp b/test/test_tokenizer.cpp index adda88e8..17ae3bc9 100644 --- a/test/test_tokenizer.cpp +++ b/test/test_tokenizer.cpp @@ -1,15 +1,17 @@ #define CATCH_CONFIG_MAIN #include -#include #include #include #include #include "payload_vector.hpp" -#include "query/queries.hpp" +#include "query.hpp" +#include "query/query_parser.hpp" #include "temporary_directory.hpp" +#include "term_map.hpp" +#include "token_filter.hpp" #include "tokenizer.hpp" using namespace pisa; @@ -88,18 +90,20 @@ TEST_CASE("Parse query terms to ids") { .to_file(lexfile.string()); auto [query, id, parsed] = - GENERATE(table, std::vector>( - {{"17:obama family tree", "17", {1, 3}}, - {"obama family tree", std::nullopt, {1, 3}}, - {"obama, family, trees", std::nullopt, {1, 3}}, - {"obama + family + tree", std::nullopt, {1, 3}}, - {"lol's", std::nullopt, {0}}, - {"U.S.A.!?", std::nullopt, {4}}} + GENERATE(table, std::vector>( + 
{{"17:obama family tree", "17", {{1, 1.0}, {3, 1.0}}}, + {"obama family tree", std::nullopt, {{1, 1.0}, {3, 1.0}}}, + {"obama, family, trees", std::nullopt, {{1, 1.0}, {3, 1.0}}}, + {"obama + family + tree", std::nullopt, {{1, 1.0}, {3, 1.0}}}, + {"lol's", std::nullopt, {{0, 1.0}}}, + {"U.S.A.!?", std::nullopt, {{4, 1.0}}}} )); CAPTURE(query); - TermProcessor term_processor(std::make_optional(lexfile.string()), std::nullopt, "krovetz"); - EnglishTokenizer tokenizer; - auto q = parse_query_terms(query, tokenizer, term_processor); - REQUIRE(q.id == id); - REQUIRE(q.terms == parsed); + + auto analyzer = TextAnalyzer(std::make_unique()); + analyzer.emplace_token_filter(); + QueryParser parser(std::move(analyzer), std::make_unique(lexfile.string())); + auto q = parser.parse(query); + REQUIRE(q.id() == id); + REQUIRE(q.terms() == parsed); } diff --git a/tools/app.hpp b/tools/app.hpp index 5c6342dd..d241db88 100644 --- a/tools/app.hpp +++ b/tools/app.hpp @@ -14,10 +14,10 @@ #include #include "io.hpp" +#include "pisa/query.hpp" #include "pisa/query/query_parser.hpp" #include "pisa/term_map.hpp" #include "pisa/text_analyzer.hpp" -#include "query/queries.hpp" #include "scorer/scorer.hpp" #include "sharding.hpp" #include "tokenizer.hpp" @@ -119,7 +119,7 @@ namespace arg { } [[nodiscard]] auto queries() const -> std::vector<::pisa::Query> { - std::vector<::pisa::Query> q; + std::vector<::pisa::Query> qs; std::unique_ptr term_map = [this]() -> std::unique_ptr { if (this->m_term_lexicon) { return std::make_unique(*this->m_term_lexicon); @@ -127,14 +127,14 @@ namespace arg { return std::make_unique(); }(); QueryParser parser(text_analyzer(), std::move(term_map)); - auto parse_query = [&q, &parser](auto&& line) { q.push_back(parser.parse(line)); }; + auto parse_query = [&qs, &parser](auto&& line) { qs.push_back(parser.parse(line)); }; if (m_query_file) { std::ifstream is(*m_query_file); io::for_each_line(is, parse_query); } else { io::for_each_line(std::cin, parse_query); } - return q; + return qs; } [[nodiscard]] auto k() const -> int { return m_k; } diff --git a/tools/compute_intersection.cpp b/tools/compute_intersection.cpp index a995a5c7..732fa46c 100644 --- a/tools/compute_intersection.cpp +++ b/tools/compute_intersection.cpp @@ -2,7 +2,6 @@ #include #include -#include "mappable/mapper.hpp" #include #include #include @@ -12,6 +11,7 @@ #include "app.hpp" #include "index_types.hpp" #include "intersection.hpp" +#include "mappable/mapper.hpp" #include "wand_data.hpp" #include "wand_data_raw.hpp" @@ -50,7 +50,7 @@ void intersect( auto intersection = Intersection::compute(index, wdata, query, mask); std::cout << fmt::format( "{}\t{}\t{}\t{}\n", - query.id ? *query.id : std::to_string(qid), + query.id() ? *query.id() : std::to_string(qid), mask.to_ulong(), intersection.length, intersection.max_score @@ -64,7 +64,7 @@ void intersect( auto intersection = Intersection::compute(index, wdata, query); std::cout << fmt::format( "{}\t{}\t{}\n", - query.id ? *query.id : std::to_string(qid), + query.id() ? 
*query.id() : std::to_string(qid), intersection.length, intersection.max_score ); @@ -105,7 +105,7 @@ int main(int argc, const char** argv) { auto queries = app.queries(); auto filtered_queries = ranges::views::filter(queries, [&](auto&& query) { - auto size = query.terms.size(); + auto size = query.terms().size(); return size < min_query_len || size > max_query_len; }); diff --git a/tools/count_postings.cpp b/tools/count_postings.cpp index 0f234eef..d0c1e63f 100644 --- a/tools/count_postings.cpp +++ b/tools/count_postings.cpp @@ -9,7 +9,6 @@ #include #include "app.hpp" -#include "binary_collection.hpp" #include "index_types.hpp" using namespace pisa; @@ -17,7 +16,7 @@ using namespace pisa; template void extract( std::string const& index_filename, - std::vector const& queries, + std::vector const& queries, std::string const& separator, bool sum, bool print_qid @@ -27,18 +26,18 @@ void extract( if (sum) { return std::function([&](auto const& query) { auto count = std::accumulate( - query.terms.begin(), - query.terms.end(), + query.terms().begin(), + query.terms().end(), 0, - [&](auto s, auto term_id) { return s + index[term_id].size(); } + [&](auto s, auto term) { return s + index[term.id].size(); } ); std::cout << count << '\n'; }); } return std::function([&](auto const& query) { std::cout << boost::algorithm::join( - query.terms | boost::adaptors::transformed([&index](auto term_id) { - return std::to_string(index[term_id].size()); + query.terms() | boost::adaptors::transformed([&index](auto term) { + return std::to_string(index[term.id].size()); }), separator ); @@ -46,8 +45,8 @@ void extract( }); }(); for (auto const& query: queries) { - if (print_qid && query.id) { - std::cout << *query.id << ":"; + if (print_qid && query.id()) { + std::cout << *query.id() << ":"; } body(query); } diff --git a/tools/evaluate_queries.cpp b/tools/evaluate_queries.cpp index 02e6f575..ce1e8e56 100644 --- a/tools/evaluate_queries.cpp +++ b/tools/evaluate_queries.cpp @@ -160,7 +160,7 @@ void evaluate_queries( for (size_t query_idx = 0; query_idx < raw_results.size(); ++query_idx) { auto results = raw_results[query_idx]; - auto qid = queries[query_idx].id; + auto qid = queries[query_idx].id(); for (auto&& [rank, result]: enumerate(results)) { std::cout << fmt::format( "{} {} {} {} {} {}\n", diff --git a/tools/extract_maxscores.cpp b/tools/extract_maxscores.cpp index ef71b137..a63bf06c 100644 --- a/tools/extract_maxscores.cpp +++ b/tools/extract_maxscores.cpp @@ -1,6 +1,4 @@ -#include #include -#include #include #include @@ -19,18 +17,18 @@ using namespace pisa; template void extract( std::string const& wand_data_path, - std::vector const& queries, + std::vector const& queries, std::string const& separator, bool print_query_id ) { Wand wdata(MemorySource::mapped_file(wand_data_path)); for (auto const& query: queries) { - if (print_query_id and query.id) { - std::cout << *(query.id) << ":"; + if (print_query_id and query.id()) { + std::cout << *(query.id()) << ":"; } std::cout << boost::algorithm::join( - query.terms | boost::adaptors::transformed([&wdata](auto term_id) { - return std::to_string(wdata.max_term_weight(term_id)); + query.terms() | boost::adaptors::transformed([&wdata](auto term) { + return std::to_string(wdata.max_term_weight(term.id)); }), separator ); diff --git a/tools/kth_threshold.cpp b/tools/kth_threshold.cpp index 9360f81a..e5685834 100644 --- a/tools/kth_threshold.cpp +++ b/tools/kth_threshold.cpp @@ -22,8 +22,6 @@ #include "query/algorithm/wand_query.hpp" #include "scorer/scorer.hpp" 
-#include "CLI/CLI.hpp" - using namespace pisa; std::set parse_tuple(std::string const& line, size_t k) { @@ -110,22 +108,20 @@ void kt_thresholds( for (auto const& query: queries) { float threshold = 0; - auto terms = query.terms; + auto terms = query.terms(); topk_queue topk(k); wand_query wand_q(topk); for (auto&& term: terms) { - Query query; - query.terms.push_back(term); + Query query{std::nullopt, std::array{term.id}}; wand_q(make_max_scored_cursors(index, wdata, *scorer, query), index.num_docs()); threshold = std::max(threshold, topk.size() == k ? topk.true_threshold() : 0.0F); topk.clear(); } for (size_t i = 0; i < terms.size(); ++i) { for (size_t j = i + 1; j < terms.size(); ++j) { - if (pairs_set.count({terms[i], terms[j]}) > 0 or all_pairs) { - Query query; - query.terms = {terms[i], terms[j]}; + if (pairs_set.count({terms[i].id, terms[j].id}) > 0 or all_pairs) { + Query query{std::nullopt, std::array{terms[i].id, terms[j].id}}; wand_q(make_max_scored_cursors(index, wdata, *scorer, query), index.num_docs()); threshold = std::max(threshold, topk.size() == k ? topk.true_threshold() : 0.0F); topk.clear(); @@ -135,9 +131,12 @@ void kt_thresholds( for (size_t i = 0; i < terms.size(); ++i) { for (size_t j = i + 1; j < terms.size(); ++j) { for (size_t s = j + 1; s < terms.size(); ++s) { - if (triples_set.count({terms[i], terms[j], terms[s]}) > 0 or all_triples) { - Query query; - query.terms = {terms[i], terms[j], terms[s]}; + if (triples_set.count({terms[i].id, terms[j].id, terms[s].id}) > 0 + or all_triples) { + Query query{ + std::nullopt, + std::array{terms[i].id, terms[j].id, terms[s].id} + }; wand_q( make_max_scored_cursors(index, wdata, *scorer, query), index.num_docs() ); diff --git a/tools/map_queries.cpp b/tools/map_queries.cpp index 80376d9a..722a7b89 100644 --- a/tools/map_queries.cpp +++ b/tools/map_queries.cpp @@ -3,7 +3,7 @@ #include #include "app.hpp" -#include "query/queries.hpp" +#include "query.hpp" #include "spdlog/sinks/stdout_color_sinks.h" #include "spdlog/spdlog.h" @@ -23,11 +23,11 @@ int main(int argc, const char** argv) { using boost::adaptors::transformed; using boost::algorithm::join; for (auto&& q: app.queries()) { - if (app.print_query_id() and q.id) { - std::cout << *(q.id) << ":"; + if (app.print_query_id() and q.id()) { + std::cout << *(q.id()) << ":"; } - std::cout - << join(q.terms | transformed([](auto d) { return std::to_string(d); }), app.separator()) - << '\n'; + std::cout << join( + q.terms() | transformed([](auto d) { return std::to_string(d.id); }), app.separator() + ) << '\n'; } } diff --git a/tools/parse_collection.cpp b/tools/parse_collection.cpp index 0dce1204..4d0c1378 100644 --- a/tools/parse_collection.cpp +++ b/tools/parse_collection.cpp @@ -9,7 +9,6 @@ #include "app.hpp" #include "forward_index_builder.hpp" #include "parser.hpp" -#include "query/term_processor.hpp" using namespace pisa; diff --git a/tools/profile_queries.cpp b/tools/profile_queries.cpp index 78b2c0a0..4d5b061f 100644 --- a/tools/profile_queries.cpp +++ b/tools/profile_queries.cpp @@ -4,6 +4,7 @@ #include "boost/algorithm/string/classification.hpp" #include "boost/algorithm/string/split.hpp" +#include "query/query_parser.hpp" #include "spdlog/spdlog.h" #include "mio/mmap.hpp" @@ -19,6 +20,7 @@ #include "query/algorithm/ranked_and_query.hpp" #include "query/algorithm/wand_query.hpp" #include "scorer/scorer.hpp" +#include "tokenizer.hpp" #include "wand_data.hpp" #include "wand_data_raw.hpp" @@ -164,20 +166,21 @@ int main(int argc, const char** argv) { } std::vector 
queries; - term_id_vec q; + std::string line; + QueryParser parser(TextAnalyzer(std::make_unique())); if (std::string(argv[args]) == "--file") { args++; args++; std::filebuf fb; if (fb.open(argv[args], std::ios::in) != nullptr) { std::istream is(&fb); - while (read_query(q, is)) { - queries.push_back({std::nullopt, q, {}}); + while (std::getline(is, line)) { + queries.push_back(parser.parse(line)); } } } else { - while (read_query(q)) { - queries.push_back({std::nullopt, q, {}}); + while (std::getline(std::cin, line)) { + queries.push_back(parser.parse(line)); } } diff --git a/tools/queries.cpp b/tools/queries.cpp index 7e780db2..822b4593 100644 --- a/tools/queries.cpp +++ b/tools/queries.cpp @@ -35,6 +35,7 @@ #include "timer.hpp" #include "topk_queue.hpp" #include "type_alias.hpp" +#include "util/do_not_optimize_away.hpp" #include "util/util.hpp" #include "wand_data.hpp" #include "wand_data_compressed.hpp" @@ -62,7 +63,7 @@ void extract_times( ).count(); }); auto mean = std::accumulate(times.begin(), times.end(), std::size_t{0}, std::plus<>()) / runs; - os << fmt::format("{}\t{}\n", query.id.value_or(std::to_string(qid)), mean); + os << fmt::format("{}\t{}\n", query.id().value_or(std::to_string(qid)), mean); } } @@ -143,9 +144,9 @@ void perftest( IndexType index(MemorySource::mapped_file(index_filename)); spdlog::info("Warming up posting lists"); - std::unordered_set warmed_up; + std::unordered_set warmed_up; for (auto const& q: queries) { - for (auto t: q.terms) { + for (auto [t, _]: q.terms()) { if (!warmed_up.count(t)) { index.warmup(t); warmed_up.insert(t); diff --git a/tools/selective_queries.cpp b/tools/selective_queries.cpp index 3df3ba5c..137767bc 100644 --- a/tools/selective_queries.cpp +++ b/tools/selective_queries.cpp @@ -3,12 +3,11 @@ #include "mappable/mapper.hpp" #include "mio/mmap.hpp" -#include "CLI/CLI.hpp" #include "app.hpp" #include "cursor/cursor.hpp" #include "index_types.hpp" -#include "query/algorithm.hpp" -#include "query/queries.hpp" +#include "query/algorithm/and_query.hpp" +#include "query/algorithm/or_query.hpp" #include #include @@ -34,7 +33,7 @@ void selective_queries( double selectiveness = double(and_results) / double(or_results); if (selectiveness < 0.005) { std::cout - << join(query.terms | transformed([](auto d) { return std::to_string(d); }), " ") + << join(query.terms() | transformed([](auto d) { return std::to_string(d.id); }), " ") << '\n'; } } diff --git a/tools/shards.cpp b/tools/shards.cpp index b81d3be3..0d2d05a3 100644 --- a/tools/shards.cpp +++ b/tools/shards.cpp @@ -154,8 +154,8 @@ int main(int argc, char** argv) { if (taily_rank->parsed()) { auto shards = resolve_shards(taily_rank_args.shard_stats()); pisa::VecMap shard_stats; - std::vector<::pisa::Query> global_queries; - pisa::VecMap> shard_queries; + std::vector global_queries; + pisa::VecMap> shard_queries; for (auto shard: shards) { auto shard_args = taily_rank_args; shard_args.apply_shard(shard); diff --git a/tools/stem_queries.cpp b/tools/stem_queries.cpp index d3afeced..6d8c85c5 100644 --- a/tools/stem_queries.cpp +++ b/tools/stem_queries.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -6,12 +7,15 @@ #include "app.hpp" #include "io.hpp" -#include "pisa/query/query_stemmer.hpp" +#include "pisa/text_analyzer.hpp" +#include "pisa/token_filter.hpp" +#include "query/query_parser.hpp" +#include "tokenizer.hpp" int main(int argc, char const* argv[]) { std::string input_filename; std::string output_filename; - std::optional stemmer; + std::string stemmer; pisa::App app{"A tool for 
stemming PISA queries."};

     app.add_option("-i,--input", input_filename, "Query input file")->required();
@@ -28,9 +32,22 @@ int main(int argc, char const* argv[]) {
     auto input_file = std::ifstream(input_filename);
 
     try {
-        pisa::QueryStemmer query_stemmer(stemmer);
+        auto analyzer = pisa::TextAnalyzer(std::make_unique());
+        analyzer.add_token_filter(pisa::stemmer_from_name(stemmer));
+        pisa::QueryParser parser(std::move(analyzer));
         pisa::io::for_each_line(input_file, [&](std::string const& line) {
-            output_file << query_stemmer(line) << "\n";
+            auto query = parser.parse(line);
+            if (query.id()) {
+                output_file << *query.id() << ":";
+            }
+            auto const& terms = query.terms();
+            if (!terms.empty()) {
+                output_file << terms.front().id;
+                for (auto pos = std::next(terms.begin()); pos != terms.end(); std::advance(pos, 1)) {
+                    output_file << ' ' << pos->id;
+                }
+            }
+            output_file << '\n';
         });
     } catch (const std::invalid_argument& ex) {
         spdlog::error(ex.what());
diff --git a/tools/taily_stats.hpp b/tools/taily_stats.hpp
index e559e110..740545ba 100644
--- a/tools/taily_stats.hpp
+++ b/tools/taily_stats.hpp
@@ -4,7 +4,9 @@
 #include "app.hpp"
 
 #include "pisa/taily_stats.hpp"
-#include "pisa/util/compiler_attribute.hpp"
+#include "pisa/wand_data.hpp"
+#include "pisa/wand_data_compressed.hpp"
+#include "pisa/wand_data_raw.hpp"
 
 namespace pisa {