Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.common.util;

import java.io.IOException;

/**
 * A supplier of results whose {@link #get()} is allowed to throw an {@link IOException}.
 * It has the same shape as {@code java.util.function.Supplier}, except that the
 * single abstract method declares a checked exception.
 *
 * @param <T> the type of value produced by this supplier
 */
@FunctionalInterface
public interface ThrowingSupplier<T> {

    /**
     * Produces a result.
     *
     * @return the supplied value
     * @throws IOException if producing the value fails with an I/O error
     */
    T get() throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
import org.opensearch.common.lucene.search.function.FunctionScoreQuery;
import org.opensearch.common.lucene.search.function.ScriptScoreQuery;
import org.opensearch.common.util.CachedSupplier;
import org.opensearch.common.util.ThrowingSupplier;
import org.opensearch.index.search.OpenSearchToParentBlockJoinQuery;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.approximate.ApproximateScoreQuery;
Expand Down Expand Up @@ -413,7 +414,7 @@ private static TopDocsCollector<?> createCollector(

protected final @Nullable SortAndFormats sortAndFormats;
private final Collector collector;
private final Supplier<TotalHits> totalHitsSupplier;
private final ThrowingSupplier<TotalHits> totalHitsSupplier;
private final Supplier<TopDocs> topDocsSupplier;
private final Supplier<Float> maxScoreSupplier;
private final ScoreDoc searchAfter;
Expand Down Expand Up @@ -468,8 +469,13 @@ private SimpleTopDocsCollectorContext(
totalHitsSupplier = () -> new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO);
hitCount = -1;
} else {
boolean deferShortcutTotalHitCount = deferShortcutTotalHitCount(hasFilterCollector, reader.hasDeletions(), query);
if (deferShortcutTotalHitCount) {
this.hitCount = 0;
} else {
this.hitCount = hasFilterCollector ? -1 : shortcutTotalHitCount(reader, query);
}
// implicit total hit counts are valid only when there is no filter collector in the chain
this.hitCount = hasFilterCollector ? -1 : shortcutTotalHitCount(reader, query);
if (this.hitCount == -1) {
topDocsCollector = createCollector(sortAndFormats, numHits, searchAfter, trackTotalHitsUpTo);
topDocsSupplier = new CachedSupplier<>(topDocsCollector::topDocs);
Expand All @@ -478,7 +484,7 @@ private SimpleTopDocsCollectorContext(
// don't compute hit counts via the collector
topDocsCollector = createCollector(sortAndFormats, numHits, searchAfter, 1);
topDocsSupplier = new CachedSupplier<>(topDocsCollector::topDocs);
totalHitsSupplier = () -> new TotalHits(this.hitCount, TotalHits.Relation.EQUAL_TO);
totalHitsSupplier = shortcutTotalHitCountSupplier(deferShortcutTotalHitCount, this.hitCount, reader, query);
}
}
MaxScoreCollector maxScoreCollector = null;
Expand Down Expand Up @@ -511,6 +517,21 @@ private SimpleTopDocsCollectorContext(
this.collector = MultiCollector.wrap(topDocsCollector, maxScoreCollector);
}

/**
 * Builds the supplier that reports an exact ({@code EQUAL_TO}) total hit count.
 *
 * @param deferShortcutTotalHitCount when {@code true}, the count is obtained by calling
 *                                   {@code shortcutTotalHitCount(reader, query)} each time
 *                                   the supplier is invoked; when {@code false}, the
 *                                   supplier reports {@code computedHitCount}
 * @param computedHitCount the hit count to report when the computation is not deferred
 * @param reader the reader passed to {@code shortcutTotalHitCount} in the deferred case
 * @param query the query passed to {@code shortcutTotalHitCount} in the deferred case
 * @return a supplier producing a new {@link TotalHits} with relation {@code EQUAL_TO}
 */
private ThrowingSupplier<TotalHits> shortcutTotalHitCountSupplier(
    boolean deferShortcutTotalHitCount,
    int computedHitCount,
    IndexReader reader,
    Query query
) {
    if (deferShortcutTotalHitCount == false) {
        // Count is already known: just wrap it on demand.
        return () -> new TotalHits(computedHitCount, TotalHits.Relation.EQUAL_TO);
    }
    // Count is looked up only when the supplier is actually asked for it.
    return () -> new TotalHits(shortcutTotalHitCount(reader, query), TotalHits.Relation.EQUAL_TO);
}

private class SimpleTopDocsCollectorManager
implements
CollectorManager<Collector, ReduceableSearchResult>,
Expand Down Expand Up @@ -609,7 +630,7 @@ Collector create(Collector in) {
return collector;
}

TopDocsAndMaxScore newTopDocs(final TopDocs topDocs, final float maxScore, final Integer terminatedAfter) {
TopDocsAndMaxScore newTopDocs(final TopDocs topDocs, final float maxScore, final Integer terminatedAfter) throws IOException {
TotalHits totalHits = null;

if (sortByScore && hasInfMaxScore) {
Expand All @@ -621,7 +642,7 @@ TopDocsAndMaxScore newTopDocs(final TopDocs topDocs, final float maxScore, final
if (hitCount == -1) {
totalHits = topDocs.totalHits;
} else {
totalHits = new TotalHits(hitCount, TotalHits.Relation.EQUAL_TO);
totalHits = totalHitsSupplier.get();
}
}

Expand Down Expand Up @@ -658,7 +679,7 @@ TopDocsAndMaxScore newTopDocs(final TopDocs topDocs, final float maxScore, final
}
}

TopDocsAndMaxScore newTopDocs() {
TopDocsAndMaxScore newTopDocs() throws IOException {
TopDocs in = topDocsSupplier.get();
float maxScore = maxScoreSupplier.get();
final TopDocs newTopDocs;
Expand Down Expand Up @@ -765,21 +786,7 @@ void postProcess(QuerySearchResult result) throws IOException {
* -1 otherwise.
*/
static int shortcutTotalHitCount(IndexReader reader, Query query) throws IOException {
while (true) {
// remove wrappers that don't matter for counts
// this is necessary so that we don't only optimize match_all
// queries but also match_all queries that are nested in
// a constant_score query
if (query instanceof ConstantScoreQuery constantScoreQuery) {
query = constantScoreQuery.getQuery();
} else if (query instanceof BoostQuery boostQuery) {
query = boostQuery.getQuery();
} else if (query instanceof ApproximateScoreQuery approximateScoreQuery) {
query = approximateScoreQuery.getOriginalQuery();
} else {
break;
}
}
query = removeWrappersFromQuery(query);
if (query.getClass() == MatchAllDocsQuery.class) {
return reader.numDocs();
} else if (query.getClass() == TermQuery.class && reader.hasDeletions() == false) {
Expand Down Expand Up @@ -817,6 +824,33 @@ static int shortcutTotalHitCount(IndexReader reader, Query query) throws IOExcep
}
}

/**
 * Decides whether the shortcut total hit count should be computed lazily.
 *
 * @param hasFilterCollector whether a filter collector is present; if so the answer is always {@code false}
 * @param hasDeletions whether the reader has deletions; if so the answer is {@code false}
 * @param query the query to inspect; wrappers are stripped via {@code removeWrappersFromQuery} first
 * @return {@code true} only when there is no filter collector, there are no deletions, and the
 *         unwrapped query's class is exactly {@code TermQuery}
 */
static boolean deferShortcutTotalHitCount(boolean hasFilterCollector, boolean hasDeletions, Query query) {
    if (hasFilterCollector) {
        return false;
    }
    final Query unwrapped = removeWrappersFromQuery(query);
    // Exact class comparison: subclasses of TermQuery do not qualify.
    final boolean isPlainTermQuery = unwrapped.getClass() == TermQuery.class;
    return isPlainTermQuery && hasDeletions == false;
}

/**
 * Strips wrapper queries that do not matter for counting, so that a query nested
 * inside them (for example a match_all inside a constant_score) can be recognized.
 * {@code ConstantScoreQuery}, {@code BoostQuery} and {@code ApproximateScoreQuery}
 * layers are peeled off repeatedly, in any nesting order.
 *
 * @param query the possibly wrapped query
 * @return the first query in the chain that is none of the wrapper types above
 */
static Query removeWrappersFromQuery(Query query) {
    Query current = query;
    for (;;) {
        if (current instanceof ConstantScoreQuery constantScore) {
            current = constantScore.getQuery();
            continue;
        }
        if (current instanceof BoostQuery boost) {
            current = boost.getQuery();
            continue;
        }
        if (current instanceof ApproximateScoreQuery approximate) {
            current = approximate.getOriginalQuery();
            continue;
        }
        // Not a wrapper we know how to peel: this is the query to count against.
        return current;
    }
}
Comment on lines +827 to +852
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# Search for test files related to TopDocsCollectorContext
fd -type f -name "*TopDocsCollectorContext*" -e java

Repository: opensearch-project/OpenSearch

Length of output: 243


🏁 Script executed:

# Search for test files that might test this class (using various naming patterns)
fd -type f -path "*/test*" -name "*TopDocs*" -e java

Repository: opensearch-project/OpenSearch

Length of output: 243


🏁 Script executed:

# Look at the directory structure to find test locations
find . -type d -name "test*" | head -20

Repository: opensearch-project/OpenSearch

Length of output: 1406


🏁 Script executed:

# Find TopDocsCollectorContext source file first
find . -path "*/server/src/main/java/*" -name "TopDocsCollectorContext.java"

Repository: opensearch-project/OpenSearch

Length of output: 152


🏁 Script executed:

# Search for any test file mentioning TopDocsCollectorContext
find . -path "*/test/*" -name "*TopDocsCollectorContext*"

Repository: opensearch-project/OpenSearch

Length of output: 55


🏁 Script executed:

# Look for test files in the server module
find ./server/src/test -name "*.java" | grep -i "topdocs"

Repository: opensearch-project/OpenSearch

Length of output: 163


🏁 Script executed:

# Search for references to TopDocsCollectorContext in test files
grep -r "TopDocsCollectorContext" ./server/src/test --include="*.java"

Repository: opensearch-project/OpenSearch

Length of output: 960


🏁 Script executed:

# Search for the specific method names mentioned in the review
grep -r "deferShortcutTotalHitCount\|removeWrappersFromQuery\|shortcutTotalHitCountSupplier" ./server/src/test --include="*.java"

Repository: opensearch-project/OpenSearch

Length of output: 55


🏁 Script executed:

# Check if TopDocsCollectorContextTests.java exists anywhere
find . -name "TopDocsCollectorContextTests.java" -o -name "TopDocsCollectorContextTest.java"

Repository: opensearch-project/OpenSearch

Length of output: 55


🏁 Script executed:

# Read the source file to understand the context and methods
cat -n ./server/src/main/java/org/opensearch/search/query/TopDocsCollectorContext.java | head -100

Repository: opensearch-project/OpenSearch

Length of output: 4854


🏁 Script executed:

# Get line count to understand file size
wc -l ./server/src/main/java/org/opensearch/search/query/TopDocsCollectorContext.java

Repository: opensearch-project/OpenSearch

Length of output: 156


🏁 Script executed:

# Look specifically at the methods mentioned in lines 827-852
sed -n '800,860p' ./server/src/main/java/org/opensearch/search/query/TopDocsCollectorContext.java

Repository: opensearch-project/OpenSearch

Length of output: 2835


🏁 Script executed:

# Search for where deferShortcutTotalHitCount is called
grep -n "deferShortcutTotalHitCount" ./server/src/main/java/org/opensearch/search/query/TopDocsCollectorContext.java

Repository: opensearch-project/OpenSearch

Length of output: 615


🏁 Script executed:

# Search for shortcutTotalHitCountSupplier
grep -n "shortcutTotalHitCountSupplier" ./server/src/main/java/org/opensearch/search/query/TopDocsCollectorContext.java

Repository: opensearch-project/OpenSearch

Length of output: 284


🏁 Script executed:

# Get context around the usage of these methods
sed -n '700,850p' ./server/src/main/java/org/opensearch/search/query/TopDocsCollectorContext.java

Repository: opensearch-project/OpenSearch

Length of output: 6920


🏁 Script executed:

# Check if there are any tests that specifically test the three methods
grep -r "shortcutTotalHitCount\|removeWrappersFromQuery" ./server/src/test --include="*.java" -l

Repository: opensearch-project/OpenSearch

Length of output: 55


🏁 Script executed:

# Look at the shortcutTotalHitCountSupplier method definition
sed -n '520,540p' ./server/src/main/java/org/opensearch/search/query/TopDocsCollectorContext.java

Repository: opensearch-project/OpenSearch

Length of output: 987


Add unit tests for hit count optimization logic.

The helper methods deferShortcutTotalHitCount, removeWrappersFromQuery, and shortcutTotalHitCountSupplier handle critical search path optimization for hit counting but lack dedicated unit test coverage. Unit tests should verify:

  1. deferShortcutTotalHitCount returns true only when there is no filter collector, the reader has no deletions, and the query is a TermQuery once its wrappers have been removed
  2. removeWrappersFromQuery properly unwraps nested ConstantScoreQuery, BoostQuery, and ApproximateScoreQuery combinations
  3. shortcutTotalHitCountSupplier correctly defers or computes hit counts based on the deferred flag


/**
* Creates a {@link TopDocsCollectorContext} from the provided <code>searchContext</code>.
* @param hasFilterCollector True if the collector chain contains at least one collector that can filter documents.
Expand Down
Loading