Skip to content

Commit

Permalink
Query refactoring
Browse files Browse the repository at this point in the history
Weights are now stored together with term IDs and resolved at
construction time according to one of the policies. In our tools, we use
the default policy that removes duplicates and sets the weight to the
number of occurrences of the term in a query. Other policies are, for
the time being, only available programmatically via the library API.

Some legacy code used to parse and process queries has been removed in
favor of the text analyzer and the new query parser.

Because weights are resolved when a query object is created, I also
refactored cursor creation: the weight is now simply taken from the
query.
  • Loading branch information
elshize committed Dec 30, 2023
1 parent 35330c4 commit 088a40f
Show file tree
Hide file tree
Showing 58 changed files with 625 additions and 749 deletions.
41 changes: 12 additions & 29 deletions include/pisa/cursor/block_max_scored_cursor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#include <vector>

#include "cursor/max_scored_cursor.hpp"
#include "query/queries.hpp"
#include "query.hpp"
#include "scorer/index_scorer.hpp"
#include "util/compiler_attribute.hpp"

Expand All @@ -30,7 +30,7 @@ class BlockMaxScoredCursor: public MaxScoredCursor<Cursor> {
~BlockMaxScoredCursor() = default;

[[nodiscard]] PISA_ALWAYSINLINE auto block_max_score() -> float {
return m_wdata.score() * this->query_weight();
return m_wdata.score() * this->weight();
}

[[nodiscard]] PISA_ALWAYSINLINE auto block_max_docid() -> std::uint32_t {
Expand All @@ -45,38 +45,21 @@ class BlockMaxScoredCursor: public MaxScoredCursor<Cursor> {

template <typename Index, typename WandType, typename Scorer>
[[nodiscard]] auto make_block_max_scored_cursors(
Index const& index, WandType const& wdata, Scorer const& scorer, Query query, bool weighted = false
Index const& index, WandType const& wdata, Scorer const& scorer, Query const& query, bool weighted = false
) {
auto terms = query.terms;
auto query_term_freqs = query_freqs(terms);

std::vector<BlockMaxScoredCursor<typename Index::document_enumerator, WandType>> cursors;
cursors.reserve(query_term_freqs.size());
cursors.reserve(query.terms().size());
std::transform(
query_term_freqs.begin(),
query_term_freqs.end(),
query.terms().begin(),
query.terms().end(),
std::back_inserter(cursors),
[&](auto&& term) {
auto term_weight = 1.0F;
auto term_id = term.first;
auto max_weight = wdata.max_term_weight(term_id);

if (weighted) {
term_weight = term.second;
max_weight = term_weight * max_weight;
return BlockMaxScoredCursor<typename Index::document_enumerator, WandType>(
index[term_id],
[scorer = scorer.term_scorer(term_id), weight = term_weight](
uint32_t doc, uint32_t freq
) { return weight * scorer(doc, freq); },
term_weight,
max_weight,
wdata.getenum(term_id)
);
}

[&](WeightedTerm const& term) {
return BlockMaxScoredCursor<typename Index::document_enumerator, WandType>(
index[term_id], scorer.term_scorer(term_id), term_weight, max_weight, wdata.getenum(term_id)
index[term.id],
scorer.term_scorer(term.id),
weighted ? term.weight : 1.0F,
wdata.max_term_weight(term.id),
wdata.getenum(term.id)
);
}
);
Expand Down
21 changes: 14 additions & 7 deletions include/pisa/cursor/cursor.hpp
Original file line number Diff line number Diff line change
@@ -1,21 +1,28 @@
#pragma once

#include "query/queries.hpp"
#include <vector>

#include "query.hpp"

namespace pisa {

/** Creates cursors for query terms.
*
* These are frequency-only cursors. If you want to calculate scores, use
* `make_scored_cursors`, `make_max_scored_cursors`, or `make_block_max_scored_cursors`.
*/
template <typename Index>
[[nodiscard]] auto make_cursors(Index const& index, Query query) {
auto terms = query.terms;
remove_duplicate_terms(terms);
using cursor = typename Index::document_enumerator;

std::vector<cursor> cursors;
cursors.reserve(terms.size());
std::transform(terms.begin(), terms.end(), std::back_inserter(cursors), [&](auto&& term) {
return index[term];
});
cursors.reserve(query.terms().size());
std::transform(
query.terms().begin(),
query.terms().end(),
std::back_inserter(cursors),
[&](auto&& term) { return index[term.id]; }
);

return cursors;
}
Expand Down
45 changes: 15 additions & 30 deletions include/pisa/cursor/max_scored_cursor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#include <vector>

#include "cursor/scored_cursor.hpp"
#include "query/queries.hpp"
#include "query.hpp"
#include "util/compiler_attribute.hpp"

namespace pisa {
Expand All @@ -13,54 +13,39 @@ class MaxScoredCursor: public ScoredCursor<Cursor> {
public:
using base_cursor_type = Cursor;

MaxScoredCursor(Cursor cursor, TermScorer term_scorer, float query_weight, float max_score)
: ScoredCursor<Cursor>(std::move(cursor), std::move(term_scorer), query_weight),
MaxScoredCursor(Cursor cursor, TermScorer term_scorer, float weight, float max_score)
: ScoredCursor<Cursor>(std::move(cursor), std::move(term_scorer), weight),
m_max_score(max_score) {}
MaxScoredCursor(MaxScoredCursor const&) = delete;
MaxScoredCursor(MaxScoredCursor&&) = default;
MaxScoredCursor& operator=(MaxScoredCursor const&) = delete;
MaxScoredCursor& operator=(MaxScoredCursor&&) = default;
~MaxScoredCursor() = default;

[[nodiscard]] PISA_ALWAYSINLINE auto max_score() const noexcept -> float { return m_max_score; }
[[nodiscard]] PISA_ALWAYSINLINE auto max_score() const noexcept -> float {
return this->weight() * m_max_score;
}

private:
float m_max_score;
};

template <typename Index, typename WandType, typename Scorer>
[[nodiscard]] auto make_max_scored_cursors(
Index const& index, WandType const& wdata, Scorer const& scorer, Query query, bool weighted = false
Index const& index, WandType const& wdata, Scorer const& scorer, Query const& query, bool weighted = false
) {
auto terms = query.terms;
auto query_term_freqs = query_freqs(terms);

std::vector<MaxScoredCursor<typename Index::document_enumerator>> cursors;
cursors.reserve(query_term_freqs.size());
cursors.reserve(query.terms().size());
std::transform(
query_term_freqs.begin(),
query_term_freqs.end(),
query.terms().begin(),
query.terms().end(),
std::back_inserter(cursors),
[&](auto&& term) {
auto term_weight = 1.0F;
auto term_id = term.first;
auto max_weight = wdata.max_term_weight(term_id);

if (weighted) {
term_weight = term.second;
max_weight = term_weight * max_weight;
return MaxScoredCursor<typename Index::document_enumerator>(
index[term_id],
[scorer = scorer.term_scorer(term_id), weight = term_weight](
uint32_t doc, uint32_t freq
) { return weight * scorer(doc, freq); },
term_weight,
max_weight
);
}

[&](WeightedTerm const& term) {
return MaxScoredCursor<typename Index::document_enumerator>(
index[term_id], scorer.term_scorer(term_id), term_weight, max_weight
index[term.id],
scorer.term_scorer(term.id),
weighted ? term.weight : 1.0F,
wdata.max_term_weight(term.id)
);
}
);
Expand Down
54 changes: 23 additions & 31 deletions include/pisa/cursor/scored_cursor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,37 @@

#include <vector>

#include "query/queries.hpp"
#include "query.hpp"
#include "scorer/index_scorer.hpp"
#include "util/compiler_attribute.hpp"

namespace pisa {

/** Wraps a term scorer so that its result is scaled by `weight`.
 *
 * When `weight` is exactly 1.0, `scorer` itself is returned (converted to
 * `TermScorer`) so that no extra multiplication is performed per call.
 * Otherwise, the returned callable invokes `scorer(doc, freq)` and multiplies
 * the result by `weight`.
 */
template <typename Scorer>
auto resolve_term_scorer(Scorer scorer, float weight) -> TermScorer {
    if (weight != 1.0F) {
        // A non-unit weight must be applied to every score produced by the scorer.
        return [scorer, weight](uint32_t docid, uint32_t term_freq) {
            return weight * scorer(docid, term_freq);
        };
    }
    // Unit weight: scaling would be a no-op, so hand back the scorer unchanged.
    return scorer;
}

template <typename Cursor>
class ScoredCursor {
public:
using base_cursor_type = Cursor;

ScoredCursor(Cursor cursor, TermScorer term_scorer, float query_weight)
ScoredCursor(Cursor cursor, TermScorer term_scorer, float weight)
: m_base_cursor(std::move(cursor)),
m_term_scorer(std::move(term_scorer)),
m_query_weight(query_weight) {}
m_weight(weight),
m_term_scorer(resolve_term_scorer(term_scorer, weight)) {}
ScoredCursor(ScoredCursor const&) = delete;
ScoredCursor(ScoredCursor&&) = default;
ScoredCursor& operator=(ScoredCursor const&) = delete;
ScoredCursor& operator=(ScoredCursor&&) = default;
~ScoredCursor() = default;

[[nodiscard]] PISA_ALWAYSINLINE auto query_weight() const noexcept -> float {
return m_query_weight;
}
[[nodiscard]] PISA_ALWAYSINLINE auto weight() const noexcept -> float { return m_weight; }
[[nodiscard]] PISA_ALWAYSINLINE auto docid() const -> std::uint32_t {
return m_base_cursor.docid();
}
Expand All @@ -37,38 +44,23 @@ class ScoredCursor {

private:
Cursor m_base_cursor;
float m_weight = 1.0;
TermScorer m_term_scorer;
float m_query_weight = 1.0;
};

template <typename Index, typename Scorer>
[[nodiscard]] auto
make_scored_cursors(Index const& index, Scorer const& scorer, Query query, bool weighted = false) {
auto terms = query.terms;
auto query_term_freqs = query_freqs(terms);

[[nodiscard]] auto make_scored_cursors(
Index const& index, Scorer const& scorer, Query const& query, bool weighted = false
) {
std::vector<ScoredCursor<typename Index::document_enumerator>> cursors;
cursors.reserve(query_term_freqs.size());
cursors.reserve(query.terms().size());
std::transform(
query_term_freqs.begin(),
query_term_freqs.end(),
query.terms().begin(),
query.terms().end(),
std::back_inserter(cursors),
[&](auto&& term) {
auto term_weight = 1.0F;
auto term_id = term.first;

if (weighted) {
term_weight = term.second;
return ScoredCursor<typename Index::document_enumerator>(
index[term_id],
[scorer = scorer.term_scorer(term_id), weight = term_weight](
uint32_t doc, uint32_t freq
) { return weight * scorer(doc, freq); },
term_weight
);
}
[&](WeightedTerm const& term) {
return ScoredCursor<typename Index::document_enumerator>(
index[term_id], scorer.term_scorer(term_id), term_weight
index[term.id], scorer.term_scorer(term.id), weighted ? term.weight : 1.0
);
}
);
Expand Down
3 changes: 0 additions & 3 deletions include/pisa/forward_index_builder.hpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
#pragma once

#include <cctype>
#include <functional>
#include <optional>
#include <string>
#include <unordered_map>

#include "document_record.hpp"
#include "forward_index_builder.hpp"
#include "query/term_processor.hpp"
#include "text_analyzer.hpp"
#include "type_safe.hpp"

Expand Down
21 changes: 12 additions & 9 deletions include/pisa/intersection.hpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#include <bitset>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

#include "cursor/scored_cursor.hpp"
#include "query/algorithm/and_query.hpp"
#include "query/queries.hpp"
#include "scorer/scorer.hpp"

namespace pisa {
Expand All @@ -24,20 +24,23 @@ namespace intersection {

/// Returns a filtered copy of `query` containing only terms indicated by ones in the bit mask.
[[nodiscard]] inline auto filter(Query const& query, Mask const& mask) -> Query {
if (query.terms.size() > MAX_QUERY_LEN) {
if (query.terms().size() > MAX_QUERY_LEN) {
throw std::invalid_argument("Queries can be at most 2^32 terms long");
}
std::vector<std::uint32_t> terms;
std::vector<float> weights;
for (std::size_t bitpos = 0; bitpos < query.terms.size(); ++bitpos) {
for (std::size_t bitpos = 0; bitpos < query.terms().size(); ++bitpos) {
if (((1U << bitpos) & mask.to_ulong()) > 0) {
terms.push_back(query.terms.at(bitpos));
if (bitpos < query.term_weights.size()) {
weights.push_back(query.term_weights[bitpos]);
}
auto term = query.terms().at(bitpos);
terms.push_back(term.id);
weights.push_back(term.weight);
}
}
return Query{query.id, terms, weights};
return Query(
query.id().has_value() ? std::make_optional<std::string>(*query.id()) : std::nullopt,
terms,
weights
);
}
} // namespace intersection

Expand Down Expand Up @@ -80,7 +83,7 @@ inline auto Intersection::compute(
/// `Fn` takes `Query` and `Mask`.
template <typename Fn>
auto for_all_subsets(Query const& query, std::optional<std::uint8_t> max_term_count, Fn func) {
auto subset_count = 1U << query.terms.size();
auto subset_count = 1U << query.terms().size();
for (auto subset = 1U; subset < subset_count; ++subset) {
auto mask = intersection::Mask(subset);
if (!max_term_count || mask.count() <= *max_term_count) {
Expand Down
Loading

0 comments on commit 088a40f

Please sign in to comment.