refactor MatchedQueriesPhase to return different processors depending on context

Signed-off-by: Dharin Shah <8616130+Dharin-shah@users.noreply.github.com>
Dharin-shah committed Jan 31, 2024
1 parent 2bf0cc2 commit bb0c0d8
Showing 3 changed files with 103 additions and 69 deletions.
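For orientation before the diff, here is a minimal, self-contained sketch of the dispatch pattern the commit introduces: one entry point that hands back either a scoring or a non-scoring per-hit processor, chosen once up front instead of re-checking the flag for every hit. The names and types below (HitProcessor, processorFor, the per-doc arrays) are simplified stand-ins for illustration, not the actual OpenSearch classes.

// Minimal stand-alone sketch (illustrative only, simplified stand-in types).
import java.util.Map;

public class MatchedQueriesDispatchSketch {

    interface HitProcessor {
        void process(int docId);
    }

    // Stand-in for MatchedQueriesPhase.getProcessor: the scoring variant reports
    // query name -> score, the non-scoring variant reports matching names only.
    static HitProcessor processorFor(
        boolean includeNamedQueriesScore,
        Map<String, float[]> scoreByQuery,   // hypothetical per-doc scores, one entry per named query
        Map<String, boolean[]> matchByQuery  // hypothetical per-doc match bits, one entry per named query
    ) {
        if (includeNamedQueriesScore) {
            return docId -> scoreByQuery.forEach(
                (name, scores) -> { if (matchByQuery.get(name)[docId]) System.out.println(name + " -> " + scores[docId]); }
            );
        }
        return docId -> matchByQuery.forEach(
            (name, bits) -> { if (bits[docId]) System.out.println(name); }
        );
    }

    public static void main(String[] args) {
        Map<String, float[]> scores = Map.of("title_match", new float[] { 0.0f, 1.3f });
        Map<String, boolean[]> matches = Map.of("title_match", new boolean[] { false, true });

        processorFor(true, scores, matches).process(1);   // prints: title_match -> 1.3
        processorFor(false, scores, matches).process(1);  // prints: title_match
    }
}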
@@ -91,7 +91,7 @@

/**
* Fetch phase of a search request, used to fetch the actual top matching documents to be returned to the client, identified
* after reducing all of the matches returned by the query phase
* after reducing all the matches returned by the query phase
*
* @opensearch.api
*/
@@ -1,34 +1,3 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.search.fetch.subphase;

import org.apache.lucene.index.LeafReaderContext;
@@ -37,71 +6,69 @@
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Bits;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.search.fetch.FetchContext;
import org.opensearch.search.fetch.FetchSubPhase;
import org.opensearch.search.fetch.FetchSubPhaseProcessor;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
* Fetches queries that match the document during search phase
*
* @opensearch.internal
*/
public final class MatchedQueriesPhase implements FetchSubPhase {

@Override
public FetchSubPhaseProcessor getProcessor(FetchContext context) throws IOException {
Map<String, Query> namedQueries = new HashMap<>();
if (context.parsedQuery() != null) {
namedQueries.putAll(context.parsedQuery().namedFilters());
}
if (context.parsedPostFilter() != null) {
namedQueries.putAll(context.parsedPostFilter().namedFilters());
}
Map<String, Query> namedQueries = collectNamedQueries(context);
if (namedQueries.isEmpty()) {
return null;
}

Map<String, Weight> weights = prepareWeights(context, namedQueries);

return context.includeNamedQueriesScore() ? createScoringProcessor(weights) : createNonScoringProcessor(weights);
}

private Map<String, Query> collectNamedQueries(FetchContext context) {
Map<String, Query> namedQueries = new HashMap<>();
Optional.ofNullable(context.parsedQuery()).ifPresent(parsedQuery -> namedQueries.putAll(parsedQuery.namedFilters()));
Optional.ofNullable(context.parsedPostFilter()).ifPresent(parsedPostFilter -> namedQueries.putAll(parsedPostFilter.namedFilters()));
return namedQueries;
}

private Map<String, Weight> prepareWeights(FetchContext context, Map<String, Query> namedQueries) throws IOException {
Map<String, Weight> weights = new HashMap<>();
for (Map.Entry<String, Query> entry : namedQueries.entrySet()) {
weights.put(
entry.getKey(),
context.includeNamedQueriesScore()
? context.searcher().createWeight(context.searcher().rewrite(entry.getValue()), ScoreMode.COMPLETE, 1)
: context.searcher().createWeight(context.searcher().rewrite(entry.getValue()), ScoreMode.COMPLETE_NO_SCORES, 1)
);
ScoreMode scoreMode = context.includeNamedQueriesScore() ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
weights.put(entry.getKey(), context.searcher().createWeight(context.searcher().rewrite(entry.getValue()), scoreMode, 1));
}
return new FetchSubPhaseProcessor() {
return weights;
}

final Map<String, Scorer> matchingIterators = new HashMap<>();
private FetchSubPhaseProcessor createScoringProcessor(Map<String, Weight> weights) {
return new FetchSubPhaseProcessor() {
final Map<String, Scorer> matchingScorers = new HashMap<>();

@Override
public void setNextReader(LeafReaderContext readerContext) throws IOException {
matchingIterators.clear();
for (Map.Entry<String, Weight> entry : weights.entrySet()) {
ScorerSupplier ss = entry.getValue().scorerSupplier(readerContext);
if (ss != null) {
Scorer scorer = ss.get(0L);
if (scorer != null) {
matchingIterators.put(entry.getKey(), scorer);
}
}
}
setupScorers(readerContext, weights, matchingScorers);
}

@Override
public void process(HitContext hitContext) throws IOException {
Map<String, Float> matches = new LinkedHashMap<>();
int doc = hitContext.docId();
for (Map.Entry<String, Scorer> entry : matchingIterators.entrySet()) {
int docId = hitContext.docId();
for (Map.Entry<String, Scorer> entry : matchingScorers.entrySet()) {
Scorer scorer = entry.getValue();
if (scorer.iterator().docID() < doc) {
scorer.iterator().advance(doc);
if (scorer.iterator().docID() < docId) {
scorer.iterator().advance(docId);
}
if (scorer.iterator().docID() == doc) {
if (scorer.iterator().docID() == docId) {
matches.put(entry.getKey(), scorer.score());
}
}
@@ -110,4 +77,52 @@ public void process(HitContext hitContext) throws IOException {
};
}

private FetchSubPhaseProcessor createNonScoringProcessor(Map<String, Weight> weights) {
return new FetchSubPhaseProcessor() {
final Map<String, Bits> matchingBits = new HashMap<>();

@Override
public void setNextReader(LeafReaderContext readerContext) throws IOException {
setupMatchingBits(readerContext, weights, matchingBits);
}

@Override
public void process(HitContext hitContext) {
List<String> matches = new ArrayList<>();
int docId = hitContext.docId();
for (Map.Entry<String, Bits> entry : matchingBits.entrySet()) {
if (entry.getValue().get(docId)) {
matches.add(entry.getKey());
}
}
hitContext.hit().matchedQueries(matches.toArray(new String[0]));
}
};
}

private void setupScorers(LeafReaderContext readerContext, Map<String, Weight> weights, Map<String, Scorer> scorers)
throws IOException {
scorers.clear();
for (Map.Entry<String, Weight> entry : weights.entrySet()) {
ScorerSupplier scorerSupplier = entry.getValue().scorerSupplier(readerContext);
if (scorerSupplier != null) {
Scorer scorer = scorerSupplier.get(0L);
if (scorer != null) {
scorers.put(entry.getKey(), scorer);
}
}
}
}

private void setupMatchingBits(LeafReaderContext readerContext, Map<String, Weight> weights, Map<String, Bits> bitsMap)
throws IOException {
bitsMap.clear();
for (Map.Entry<String, Weight> entry : weights.entrySet()) {
ScorerSupplier scorerSupplier = entry.getValue().scorerSupplier(readerContext);
if (scorerSupplier != null) {
Bits bits = Lucene.asSequentialAccessBits(readerContext.reader().maxDoc(), scorerSupplier);
bitsMap.put(entry.getKey(), bits);
}
}
}
}
@@ -305,9 +305,28 @@ public final void assignRescoreDocIds(RescoreDocIds rescoreDocIds) {

public abstract boolean trackScores();

public abstract SearchContext includeNamedQueriesScore(boolean includeNamedQueriesScore);
/**
* Determines whether named queries' scores should be included in the search results.
* By default, this is set to return false, indicating that scores from named queries are not included.
*
* @param includeNamedQueriesScore true to include scores from named queries, false otherwise.
*/
public SearchContext includeNamedQueriesScore(boolean includeNamedQueriesScore) {
// Default implementation does nothing and returns this for chaining.
// Implementations of SearchContext should override this method to actually store the value.
return this;
}

public abstract boolean includeNamedQueriesScore();
/**
* Checks if scores from named queries are included in the search results.
*
* @return true if scores from named queries are included, false otherwise.
*/
public boolean includeNamedQueriesScore() {
// Default implementation returns false.
// Implementations of SearchContext should override this method to return the actual value.
return false;
}

public abstract SearchContext trackTotalHitsUpTo(int trackTotalHits);

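As the new javadoc above notes, the abstract base now ships no-op defaults, and a concrete SearchContext is expected to override both methods and actually hold the flag. Below is a minimal sketch of that override pattern, using simplified stand-in types rather than the real org.opensearch.search.internal classes.

// Simplified stand-in for the abstract base: same defaults as in the diff above.
abstract class ContextBaseSketch {
    public ContextBaseSketch includeNamedQueriesScore(boolean includeNamedQueriesScore) {
        return this; // no-op default, returns this for chaining
    }

    public boolean includeNamedQueriesScore() {
        return false; // default: named query scores are not included
    }
}

// A concrete context overrides both methods so the fetch phase can read the flag back.
class ConcreteContextSketch extends ContextBaseSketch {
    private boolean includeNamedQueriesScore;

    @Override
    public ConcreteContextSketch includeNamedQueriesScore(boolean includeNamedQueriesScore) {
        this.includeNamedQueriesScore = includeNamedQueriesScore; // store the requested behavior
        return this;
    }

    @Override
    public boolean includeNamedQueriesScore() {
        return includeNamedQueriesScore;
    }
}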
